Skip to content

Commit

Permalink
solution update of talaia-labs#199
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Mar 30, 2023
1 parent 47da3c8 commit 8bf3194
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 70 deletions.
43 changes: 18 additions & 25 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
if param_count != 2 {
Err(GetAppointmentError::InvalidFormat(format!(
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
)))
)))
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
Expand Down Expand Up @@ -289,7 +289,9 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
match value {
serde_json::Value::Array(a) => {
let param_count = a.len();
if param_count != 1 && param_count != 3 {
if param_count == 2{
Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()))
} else if param_count != 1 && param_count != 3 {
Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
)))
Expand All @@ -303,33 +305,23 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
"tower_id must be a hex encoded string".to_owned(),
))
}?;
let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) {
if start >= 0 {
Some(start as u32)
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-start must be a positive integer".to_owned(),
));
}
} else {
None
};
let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) {
if expire > subscription_start.unwrap() as i64 {
Some(expire as u32)

let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1).and_then(|v| v.as_i64()), a.get(2).and_then(|v| v.as_i64())) {
if start >= 0 && expire > start {
(Some(start as u32), Some(expire as u32))
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-expire must be a positive integer and greater than subscription_start".to_owned(),
));
"Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
));
}
} else if a.get(1).is_some() || a.get(2).is_some() {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(),
));
} else {
None
(None, None)
};
if subscription_start.is_some() != subscription_expiry.is_some() {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-start and subscription-expiry must be provided together".to_owned(),
));
}

Ok(Self {
tower_id,
subscription_start,
Expand All @@ -354,11 +346,12 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
params.push(v);
}
}

GetRegistrationReceiptParams::try_from(json!(params))
}
},
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'"
"Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'"
))),
}
}
Expand Down
77 changes: 40 additions & 37 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FromIterator;
use std::path::PathBuf;
use std::str::FromStr;

use rusqlite::{params, Connection, Error as SqliteError};
use rusqlite::{params, Connection, Error as SqliteError, ToSql};

use bitcoin::secp256k1::SecretKey;

Expand Down Expand Up @@ -218,28 +218,37 @@ impl DBM {
user_id: UserId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<RegistrationReceipt> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1 AND (subscription_start >=?2 OR ?2 is NULL) AND (subscription_expiry <=?3 OR ?3 is NULL)".to_string();
) -> Vec<RegistrationReceipt> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

if subscription_expiry == None {
query.push_str(" OR subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
};
let tower_id_encoded = tower_id.to_vec();
let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded];

if subscription_expiry.is_none() {
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
} else {
query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3");
params.push(&subscription_start);
params.push(&subscription_expiry)
}
let mut stmt = self.connection.prepare(&query).unwrap();

stmt.query_row(
params![tower_id.to_vec(), subscription_start, subscription_expiry],
|row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
let signature: String = row.get(3).unwrap();
let receipts = stmt
.query_map(params.as_slice(), |row| {
let slots: u32 = row.get(0)?;
let start: u32 = row.get(1)?;
let expiry: u32 = row.get(2)?;
let signature: String = row.get(3)?;

Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
},
)
.ok()
})
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap_or_default();

receipts
}

/// Removes a tower record from the database.
Expand Down Expand Up @@ -650,8 +659,8 @@ mod tests {

use teos_common::cryptography::get_random_keypair;
use teos_common::test_utils::{
generate_random_appointment, get_random_int, get_random_registration_receipt,
get_random_user_id, get_registration_receipt_from_previous,
generate_random_appointment, get_random_registration_receipt, get_random_user_id,
get_registration_receipt_from_previous,
};

impl DBM {
Expand Down Expand Up @@ -738,39 +747,34 @@ mod tests {
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
)[0],
receipt
);

// Add another receipt for the same tower with a higher expiry and check this last one is loaded
// Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts
let middle_receipt = get_registration_receipt_from_previous(&receipt);
let latest_receipt = get_registration_receipt_from_previous(&middle_receipt);

let latest_subscription_expiry = Some(latest_receipt.subscription_expiry());

dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
latest_receipt
latest_subscription_expiry
),
vec![receipt, latest_receipt.clone()]
);

// Add a final one with a lower expiry and check the last is still loaded
// Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry
// params are not passed
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)[0],
latest_receipt
);
}
Expand All @@ -783,8 +787,8 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = get_random_int();
let subscription_expiry = get_random_int();
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Store it once
dbm.store_tower_record(tower_id, net_addr, &receipt)
Expand All @@ -795,8 +799,7 @@ mod tests {
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
)[0],
receipt
);

Expand Down
18 changes: 11 additions & 7 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,18 @@ async fn get_registration_receipt(
let subscription_expiry = params.subscription_expiry;
let state = plugin.state().lock().unwrap();

if let Some(response) =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry)
{
Ok(json!(response))
let response =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
if response.is_empty() {
if state.towers.contains_key(&tower_id) {
Err(anyhow!("No registration receipt found for {tower_id}"))
} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered ?"
))
}
} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered?"
))
Ok(json!(response))
}
}

Expand Down
2 changes: 1 addition & 1 deletion watchtower-plugin/src/wt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl WTClient {
tower_id: TowerId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<RegistrationReceipt> {
) -> Vec<RegistrationReceipt> {
self.dbm.load_registration_receipt(
tower_id,
self.user_id,
Expand Down

0 comments on commit 8bf3194

Please sign in to comment.