Correctly escape strings

This commit is contained in:
Erik Johnston
2018-11-02 16:36:08 +00:00
parent 3b3ddc0ed0
commit 7218bb0c83
5 changed files with 62 additions and 8 deletions

View File

@@ -15,8 +15,12 @@
use fallible_iterator::FallibleIterator;
use indicatif::{ProgressBar, ProgressStyle};
use postgres::{Connection, TlsMode};
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fmt;
use StateGroupEntry;
@@ -47,7 +51,8 @@ pub fn get_data_from_db(db_url: &str, room_id: &str) -> BTreeMap<i64, StateGroup
} else {
None
}
}).collect();
})
.collect();
if missing_sgs.is_empty() {
break;
@@ -74,7 +79,8 @@ fn get_initial_data_from_db(conn: &Connection, room_id: &str) -> BTreeMap<i64, S
LEFT JOIN state_group_edges AS e ON (m.id = e.state_group)
WHERE m.room_id = $1
"#,
).unwrap();
)
.unwrap();
let trans = conn.transaction().unwrap();
let mut rows = stmt.lazy_query(&trans, &[&room_id], 1000).unwrap();
@@ -123,7 +129,8 @@ fn get_missing_from_db(conn: &Connection, missing_sgs: &[i64]) -> BTreeMap<i64,
FROM state_group_edges
WHERE state_group = ANY($1)
"#,
).unwrap();
)
.unwrap();
let trans = conn.transaction().unwrap();
let mut rows = stmt.lazy_query(&trans, &[&missing_sgs], 100).unwrap();
@@ -140,3 +147,38 @@ fn get_missing_from_db(conn: &Connection, missing_sgs: &[i64]) -> BTreeMap<i64,
state_group_map
}
/// Helper function that escapes the wrapped text when writing SQL
pub struct PGEscapse<'a>(pub &'a str);
impl<'a> fmt::Display for PGEscapse<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut delim = Cow::from("$$");
while self.0.contains(&delim as &str) {
let s: String = thread_rng().sample_iter(&Alphanumeric).take(10).collect();
delim = format!("${}$", s).into();
}
write!(f, "{}{}{}", delim, self.0, delim)
}
}
#[test]
fn test_pg_escape() {
let s = format!("{}", PGEscapse("test"));
assert_eq!(s, "$$test$$");
let dodgy_string = "test$$ing";
let s = format!("{}", PGEscapse(dodgy_string));
// prefix and suffixes should match
let start_pos = s.find(dodgy_string).expect("expected to find dodgy string");
let end_pos = start_pos + dodgy_string.len();
assert_eq!(s[..start_pos], s[end_pos..]);
// .. and they should start and end with '$'
assert_eq!(&s[0..1], "$");
assert_eq!(&s[start_pos - 1..start_pos], "$");
}