diff --git a/src/database.rs b/src/database.rs index b46bb3c..b685743 100644 --- a/src/database.rs +++ b/src/database.rs @@ -26,10 +26,14 @@ use StateGroupEntry; /// Fetch the entries in state_groups_state (and their prev groups) for the /// given `room_id` by connecting to the postgres database at `db_url`. -pub fn get_data_from_db(db_url: &str, room_id: &str) -> BTreeMap { +pub fn get_data_from_db( + db_url: &str, + room_id: &str, + max_state_group: Option, +) -> BTreeMap { let conn = Connection::connect(db_url, TlsMode::None).unwrap(); - let mut state_group_map = get_initial_data_from_db(&conn, room_id); + let mut state_group_map = get_initial_data_from_db(&conn, room_id, max_state_group); println!("Got initial state from database. Checking for any missing state groups..."); @@ -69,21 +73,36 @@ pub fn get_data_from_db(db_url: &str, room_id: &str) -> BTreeMap BTreeMap { - let stmt = conn - .prepare( - r#" - SELECT m.id, prev_state_group, type, state_key, s.event_id - FROM state_groups AS m - LEFT JOIN state_groups_state AS s ON (m.id = s.state_group) - LEFT JOIN state_group_edges AS e ON (m.id = e.state_group) - WHERE m.room_id = $1 - "#, - ) - .unwrap(); +fn get_initial_data_from_db( + conn: &Connection, + room_id: &str, + max_state_group: Option, +) -> BTreeMap { + let sql = format!( + r#" + SELECT m.id, prev_state_group, type, state_key, s.event_id + FROM state_groups AS m + LEFT JOIN state_groups_state AS s ON (m.id = s.state_group) + LEFT JOIN state_group_edges AS e ON (m.id = e.state_group) + WHERE m.room_id = $1 {} + "#, + if max_state_group.is_some() { + "AND state_group <= $2" + } else { + "" + } + ); + + let stmt = conn.prepare(&sql).unwrap(); let trans = conn.transaction().unwrap(); - let mut rows = stmt.lazy_query(&trans, &[&room_id], 1000).unwrap(); + + let mut rows = if let Some(s) = max_state_group { + stmt.lazy_query(&trans, &[&room_id, &s], 1000) + } else { + stmt.lazy_query(&trans, &[&room_id], 1000) + } + .unwrap(); let mut state_group_map: BTreeMap = BTreeMap::new(); diff --git a/src/main.rs b/src/main.rs index 3c0eb5a..2bbdb22 100644 --- a/src/main.rs +++ b/src/main.rs @@ -124,6 +124,12 @@ fn main() { .help("The room to process") .takes_value(true) .required(true), + ).arg( + Arg::with_name("max_state_group") + .short("s") + .help("The maximum state group to process up to") + .takes_value(true) + .required(false), ).arg( Arg::with_name("output_file") .short("o") @@ -161,17 +167,22 @@ fn main() { let mut output_file = matches .value_of("output_file") .map(|path| File::create(path).unwrap()); + let room_id = matches .value_of("room_id") .expect("room_id should be required since no file"); + let max_state_group = matches + .value_of("max_state_group") + .map(|s| s.parse().expect("max_state_group must be an integer")); + let transactions = matches.is_present("transactions"); let level_sizes = value_t_or_exit!(matches, "level_sizes", LevelSizes); // First we need to get the current state groups println!("Fetching state from DB for room '{}'...", room_id); - let state_group_map = database::get_data_from_db(db_url, room_id); + let state_group_map = database::get_data_from_db(db_url, room_id, max_state_group); println!("Number of state groups: {}", state_group_map.len());