diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs index d49889ca..a2acd576 100644 --- a/watchtower-plugin/src/convert.rs +++ b/watchtower-plugin/src/convert.rs @@ -289,44 +289,56 @@ impl TryFrom for GetRegistrationReceiptParams { match value { serde_json::Value::Array(a) => { let param_count = a.len(); - if param_count == 2{ - Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string())) + if param_count == 2 { + Err(GetRegistrationReceiptError::InvalidFormat(( + "Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string() + )) } else if param_count != 1 && param_count != 3 { Err(GetRegistrationReceiptError::InvalidFormat(format!( "Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}" ))) - } else{ + } else { let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { TowerId::from_str(s).map_err(|_| { - GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned()) + GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned()) }) } else { Err(GetRegistrationReceiptError::InvalidId( - "tower_id must be a hex encoded string".to_owned(), + "tower_id must be a hex encoded string".to_owned(), )) }?; - let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1).and_then(|v| v.as_i64()), a.get(2).and_then(|v| v.as_i64())) { + let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){ + let start = start.as_i64().ok_or_else(|| { + GetRegistrationReceiptError::InvalidFormat( + "Subscription_start must be a positive integer".to_owned(), + ) + })?; + + let expire = expire.as_i64().ok_or_else(|| { + GetRegistrationReceiptError::InvalidFormat( + "Subscription_expire must be a positive integer".to_owned(), + ) + })?; + if start >= 0 && expire > start { (Some(start as u32), Some(expire as u32)) } else { return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(), - )); + "subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(), + )); } - } else if a.get(1).is_some() || a.get(2).is_some() { - return Err(GetRegistrationReceiptError::InvalidFormat( - "Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(), - )); } else { (None, None) }; - Ok(Self { - tower_id, - subscription_start, - subscription_expiry, - }) + Ok( + Self { + tower_id, + subscription_start, + subscription_expiry, + } + ) } }, serde_json::Value::Object(mut m) => { diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index d7cfcb0e..7ee9f388 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -218,7 +218,7 @@ impl DBM { user_id: UserId, subscription_start: Option, subscription_expiry: Option, - ) -> Vec { + ) -> Option> { let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string(); let tower_id_encoded = tower_id.to_vec(); @@ -233,8 +233,8 @@ impl DBM { } let mut stmt = self.connection.prepare(&query).unwrap(); - let receipts = stmt - .query_map(params.as_slice(), |row| { + Some( + stmt.query_map(params.as_slice(), |row| { let slots: u32 = row.get(0)?; let start: u32 = row.get(1)?; let expiry: u32 = row.get(2)?; @@ -245,10 +245,9 @@ impl DBM { )) }) .unwrap() - .collect::, _>>() - .unwrap_or_default(); - - receipts + .map(|r| r.unwrap()) + .collect(), + ) } /// Removes a tower record from the database. @@ -747,7 +746,8 @@ mod tests { receipt.user_id(), subscription_start, subscription_expiry - )[0], + ) + .unwrap()[0], receipt ); @@ -765,7 +765,8 @@ mod tests { latest_receipt.user_id(), subscription_start, latest_subscription_expiry - ), + ) + .unwrap(), vec![receipt, latest_receipt.clone()] ); @@ -774,7 +775,8 @@ mod tests { dbm.store_tower_record(tower_id, net_addr, &middle_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)[0], + dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None) + .unwrap()[0], latest_receipt ); } @@ -799,7 +801,8 @@ mod tests { receipt.user_id(), subscription_start, subscription_expiry - )[0], + ) + .unwrap()[0], receipt ); diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index f6f19c1b..d80fbbe4 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -143,7 +143,7 @@ async fn get_registration_receipt( let response = state.get_registration_receipt(tower_id, subscription_start, subscription_expiry); - if response.is_empty() { + if response.clone().unwrap().is_empty() { if state.towers.contains_key(&tower_id) { Err(anyhow!("No registration receipt found for {tower_id}")) } else { diff --git a/watchtower-plugin/src/wt_client.rs b/watchtower-plugin/src/wt_client.rs index 8287628d..59f6d68d 100644 --- a/watchtower-plugin/src/wt_client.rs +++ b/watchtower-plugin/src/wt_client.rs @@ -185,7 +185,7 @@ impl WTClient { tower_id: TowerId, subscription_start: Option, subscription_expiry: Option, - ) -> Vec { + ) -> Option> { self.dbm.load_registration_receipt( tower_id, self.user_id,