Skip to content

Commit

Permalink
Add optional boundaries to getregistrationreceipt
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Mar 30, 2023
1 parent 8bf3194 commit 6046848
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 30 deletions.
46 changes: 29 additions & 17 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,44 +289,56 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
match value {
serde_json::Value::Array(a) => {
let param_count = a.len();
if param_count == 2{
Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()))
if param_count == 2 {
Err(GetRegistrationReceiptError::InvalidFormat((
"Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()
))
} else if param_count != 1 && param_count != 3 {
Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
)))
} else{
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
})
} else {
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
"tower_id must be a hex encoded string".to_owned(),
))
}?;

let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1).and_then(|v| v.as_i64()), a.get(2).and_then(|v| v.as_i64())) {
let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){
let start = start.as_i64().ok_or_else(|| {
GetRegistrationReceiptError::InvalidFormat(
"Subscription_start must be a positive integer".to_owned(),
)
})?;

let expire = expire.as_i64().ok_or_else(|| {
GetRegistrationReceiptError::InvalidFormat(
"Subscription_expire must be a positive integer".to_owned(),
)
})?;

if start >= 0 && expire > start {
(Some(start as u32), Some(expire as u32))
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
));
"subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
));
}
} else if a.get(1).is_some() || a.get(2).is_some() {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(),
));
} else {
(None, None)
};

Ok(Self {
tower_id,
subscription_start,
subscription_expiry,
})
Ok(
Self {
tower_id,
subscription_start,
subscription_expiry,
}
)
}
},
serde_json::Value::Object(mut m) => {
Expand Down
25 changes: 14 additions & 11 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl DBM {
user_id: UserId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Vec<RegistrationReceipt> {
) -> Option<Vec<RegistrationReceipt>> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

let tower_id_encoded = tower_id.to_vec();
Expand All @@ -233,8 +233,8 @@ impl DBM {
}
let mut stmt = self.connection.prepare(&query).unwrap();

let receipts = stmt
.query_map(params.as_slice(), |row| {
Some(
stmt.query_map(params.as_slice(), |row| {
let slots: u32 = row.get(0)?;
let start: u32 = row.get(1)?;
let expiry: u32 = row.get(2)?;
Expand All @@ -245,10 +245,9 @@ impl DBM {
))
})
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap_or_default();

receipts
.map(|r| r.unwrap())
.collect(),
)
}

/// Removes a tower record from the database.
Expand Down Expand Up @@ -747,7 +746,8 @@ mod tests {
receipt.user_id(),
subscription_start,
subscription_expiry
)[0],
)
.unwrap()[0],
receipt
);

Expand All @@ -765,7 +765,8 @@ mod tests {
latest_receipt.user_id(),
subscription_start,
latest_subscription_expiry
),
)
.unwrap(),
vec![receipt, latest_receipt.clone()]
);

Expand All @@ -774,7 +775,8 @@ mod tests {
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)[0],
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)
.unwrap()[0],
latest_receipt
);
}
Expand All @@ -799,7 +801,8 @@ mod tests {
receipt.user_id(),
subscription_start,
subscription_expiry
)[0],
)
.unwrap()[0],
receipt
);

Expand Down
2 changes: 1 addition & 1 deletion watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async fn get_registration_receipt(

let response =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
if response.is_empty() {
if response.clone().unwrap().is_empty() {
if state.towers.contains_key(&tower_id) {
Err(anyhow!("No registration receipt found for {tower_id}"))
} else {
Expand Down
2 changes: 1 addition & 1 deletion watchtower-plugin/src/wt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl WTClient {
tower_id: TowerId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Vec<RegistrationReceipt> {
) -> Option<Vec<RegistrationReceipt>> {
self.dbm.load_registration_receipt(
tower_id,
self.user_id,
Expand Down

0 comments on commit 6046848

Please sign in to comment.