diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs index 96bdfe77..d49889ca 100644 --- a/watchtower-plugin/src/convert.rs +++ b/watchtower-plugin/src/convert.rs @@ -200,7 +200,7 @@ impl TryFrom for GetAppointmentParams { if param_count != 2 { Err(GetAppointmentError::InvalidFormat(format!( "Unexpected request format. The request needs 2 parameter. Received: {param_count}" - ))) + ))) } else { let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { TowerId::from_str(s).map_err(|_| { @@ -289,7 +289,9 @@ impl TryFrom for GetRegistrationReceiptParams { match value { serde_json::Value::Array(a) => { let param_count = a.len(); - if param_count != 1 && param_count != 3 { + if param_count == 2{ + Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string())) + } else if param_count != 1 && param_count != 3 { Err(GetRegistrationReceiptError::InvalidFormat(format!( "Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}" ))) @@ -303,33 +305,23 @@ impl TryFrom for GetRegistrationReceiptParams { "tower_id must be a hex encoded string".to_owned(), )) }?; - let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) { - if start >= 0 { - Some(start as u32) - } else { - return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription-start must be a positive integer".to_owned(), - )); - } - } else { - None - }; - let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) { - if expire > subscription_start.unwrap() as i64 { - Some(expire as u32) + + let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1).and_then(|v| v.as_i64()), a.get(2).and_then(|v| v.as_i64())) { + if start >= 0 && expire > start { + (Some(start as u32), Some(expire as u32)) } else { return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription-expire must be a positive integer and greater than subscription_start".to_owned(), - )); + "Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(), + )); } + } else if a.get(1).is_some() || a.get(2).is_some() { + return Err(GetRegistrationReceiptError::InvalidFormat( + "Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(), + )); } else { - None + (None, None) }; - if subscription_start.is_some() != subscription_expiry.is_some() { - return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription-start and subscription-expiry must be provided together".to_owned(), - )); - } + Ok(Self { tower_id, subscription_start, @@ -354,11 +346,12 @@ impl TryFrom for GetRegistrationReceiptParams { params.push(v); } } + GetRegistrationReceiptParams::try_from(json!(params)) } }, _ => Err(GetRegistrationReceiptError::InvalidFormat(format!( - "Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'" + "Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'" ))), } } diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index e58c9585..d7cfcb0e 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use std::path::PathBuf; use std::str::FromStr; -use rusqlite::{params, Connection, Error as SqliteError}; +use rusqlite::{params, Connection, Error as SqliteError, ToSql}; use bitcoin::secp256k1::SecretKey; @@ -218,28 +218,37 @@ impl DBM { user_id: UserId, subscription_start: Option, subscription_expiry: Option, - ) -> Option { - let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1 AND (subscription_start >=?2 OR ?2 is NULL) AND (subscription_expiry <=?3 OR ?3 is NULL)".to_string(); + ) -> Vec { + let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string(); - if subscription_expiry == None { - query.push_str(" OR subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)") - }; + let tower_id_encoded = tower_id.to_vec(); + let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded]; + + if subscription_expiry.is_none() { + query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)") + } else { + query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3"); + params.push(&subscription_start); + params.push(&subscription_expiry) + } let mut stmt = self.connection.prepare(&query).unwrap(); - stmt.query_row( - params![tower_id.to_vec(), subscription_start, subscription_expiry], - |row| { - let slots: u32 = row.get(0).unwrap(); - let start: u32 = row.get(1).unwrap(); - let expiry: u32 = row.get(2).unwrap(); - let signature: String = row.get(3).unwrap(); + let receipts = stmt + .query_map(params.as_slice(), |row| { + let slots: u32 = row.get(0)?; + let start: u32 = row.get(1)?; + let expiry: u32 = row.get(2)?; + let signature: String = row.get(3)?; Ok(RegistrationReceipt::with_signature( user_id, slots, start, expiry, signature, )) - }, - ) - .ok() + }) + .unwrap() + .collect::, _>>() + .unwrap_or_default(); + + receipts } /// Removes a tower record from the database. @@ -650,8 +659,8 @@ mod tests { use teos_common::cryptography::get_random_keypair; use teos_common::test_utils::{ - generate_random_appointment, get_random_int, get_random_registration_receipt, - get_random_user_id, get_registration_receipt_from_previous, + generate_random_appointment, get_random_registration_receipt, get_random_user_id, + get_registration_receipt_from_previous, }; impl DBM { @@ -738,15 +747,16 @@ mod tests { receipt.user_id(), subscription_start, subscription_expiry - ) - .unwrap(), + )[0], receipt ); - // Add another receipt for the same tower with a higher expiry and check this last one is loaded + // Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts let middle_receipt = get_registration_receipt_from_previous(&receipt); let latest_receipt = get_registration_receipt_from_previous(&middle_receipt); + let latest_subscription_expiry = Some(latest_receipt.subscription_expiry()); + dbm.store_tower_record(tower_id, net_addr, &latest_receipt) .unwrap(); assert_eq!( @@ -754,23 +764,17 @@ mod tests { tower_id, latest_receipt.user_id(), subscription_start, - subscription_expiry - ) - .unwrap(), - latest_receipt + latest_subscription_expiry + ), + vec![receipt, latest_receipt.clone()] ); - // Add a final one with a lower expiry and check the last is still loaded + // Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry + // params are not passed dbm.store_tower_record(tower_id, net_addr, &middle_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt( - tower_id, - latest_receipt.user_id(), - subscription_start, - subscription_expiry - ) - .unwrap(), + dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)[0], latest_receipt ); } @@ -783,8 +787,8 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); - let subscription_start = get_random_int(); - let subscription_expiry = get_random_int(); + let subscription_start = Some(receipt.subscription_start()); + let subscription_expiry = Some(receipt.subscription_expiry()); // Store it once dbm.store_tower_record(tower_id, net_addr, &receipt) @@ -795,8 +799,7 @@ mod tests { receipt.user_id(), subscription_start, subscription_expiry - ) - .unwrap(), + )[0], receipt ); diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index d0548e9c..f6f19c1b 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -141,14 +141,18 @@ async fn get_registration_receipt( let subscription_expiry = params.subscription_expiry; let state = plugin.state().lock().unwrap(); - if let Some(response) = - state.get_registration_receipt(tower_id, subscription_start, subscription_expiry) - { - Ok(json!(response)) + let response = + state.get_registration_receipt(tower_id, subscription_start, subscription_expiry); + if response.is_empty() { + if state.towers.contains_key(&tower_id) { + Err(anyhow!("No registration receipt found for {tower_id}")) + } else { + Err(anyhow!( + "Cannot find {tower_id} within the known towers. Have you registered ?" + )) + } } else { - Err(anyhow!( - "Cannot find {tower_id} within the known towers. Have you registered?" - )) + Ok(json!(response)) } } diff --git a/watchtower-plugin/src/wt_client.rs b/watchtower-plugin/src/wt_client.rs index 454b073c..8287628d 100644 --- a/watchtower-plugin/src/wt_client.rs +++ b/watchtower-plugin/src/wt_client.rs @@ -185,7 +185,7 @@ impl WTClient { tower_id: TowerId, subscription_start: Option, subscription_expiry: Option, - ) -> Option { + ) -> Vec { self.dbm.load_registration_receipt( tower_id, self.user_id,