Skip to content

Commit

Permalink
second update talaia-labs#199
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Mar 16, 2023
1 parent 0c95fa0 commit b089fc3
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 65 deletions.
1 change: 1 addition & 0 deletions watchtower-plugin/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR";
pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower";

/// Collections of plugin option names, default values and descriptions

pub const WT_PORT: &str = "watchtower-port";
pub const DEFAULT_WT_PORT: i64 = 9814;
pub const WT_PORT_DESC: &str = "tower API port";
Expand Down
91 changes: 61 additions & 30 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
let param_count = a.len();
if param_count != 2 {
Err(GetAppointmentError::InvalidFormat(format!(
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
)))
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
)))
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
Expand Down Expand Up @@ -288,51 +288,82 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
match value {
serde_json::Value::Array(a) => {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
let param_count = a.len();
if param_count != 1 && param_count != 3 {
Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
)))
} else{
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
})
} else {
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
))
}?;
let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) {
if start >= 0 {
Some(start as u32)
})
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
))
}?;
let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) {
if start >= 0 {
Some(start as u32)
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-start must be a positive integer".to_owned(),
));
}
} else {
None
};
let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) {
if expire >= 0 {
));
}
} else {
None
};
let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) {
if expire > subscription_start.unwrap() as i64 {
Some(expire as u32)
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-expire must be a positive integer and greater than subscription_start".to_owned(),
));
}
} else {
None
};
if subscription_start.is_some() != subscription_expiry.is_some() {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-expire must be a positive integer".to_owned(),
));
"Subscription-start and subscription-expiry must be provided together".to_owned(),
));
}
} else {
None
};
Ok(Self {
Ok(Self {
tower_id,
subscription_start,
subscription_expiry,
})
}
})
}
},
serde_json::Value::Object(mut m) => {
let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"];
let param_count = m.len();

if m.is_empty() || param_count > allowed_keys.len() {
Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}")))
} else if !m.contains_key(allowed_keys[0]){
Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0])))
} else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) {
Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned()))
} else {
let mut params = Vec::with_capacity(allowed_keys.len());
for k in allowed_keys {
if let Some(v) = m.remove(k) {
params.push(v);
}
}
GetRegistrationReceiptParams::try_from(json!(params))
}
},
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'"
))),
}
}
}


/// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentRevocation {
Expand Down
59 changes: 24 additions & 35 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FromIterator;
use std::path::PathBuf;
use std::str::FromStr;

use rusqlite::{params, Connection, Error as SqliteError, ToSql};
use rusqlite::{params, Connection, Error as SqliteError};

use bitcoin::secp256k1::SecretKey;

Expand Down Expand Up @@ -219,38 +219,27 @@ impl DBM {
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<RegistrationReceipt> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

let mut params = vec![tower_id.to_vec()];

if let Some(start) = subscription_start {
query.push_str(" AND subscription_start >= ?2");
params.push(start.to_be_bytes().to_vec());
} else {
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)");
}

if let Some(expiry) = subscription_expiry {
query.push_str(" AND subscription_expiry <= ?3");
params.push(expiry.to_be_bytes().to_vec());
}

//query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1");
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1 AND (subscription_start >=?2 OR ?2 is NULL) AND (subscription_expiry <=?3 OR ?3 is NULL)".to_string();

if subscription_expiry == None {
query.push_str(" OR subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
};
let mut stmt = self.connection.prepare(&query).unwrap();
let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect();

stmt.query_row(params.as_slice(), |row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
let signature: String = row.get(3).unwrap();
stmt.query_row(
params![tower_id.to_vec(), subscription_start, subscription_expiry],
|row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
let signature: String = row.get(3).unwrap();

Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
}).
ok()
Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
},
)
.ok()
}

/// Removes a tower record from the database.
Expand Down Expand Up @@ -661,8 +650,8 @@ mod tests {

use teos_common::cryptography::get_random_keypair;
use teos_common::test_utils::{
generate_random_appointment, get_random_registration_receipt, get_random_user_id,
get_registration_receipt_from_previous,
generate_random_appointment, get_random_int, get_random_registration_receipt,
get_random_user_id, get_registration_receipt_from_previous,
};

impl DBM {
Expand Down Expand Up @@ -737,8 +726,8 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = None;
let subscription_expiry = None;
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Check the receipt was stored
dbm.store_tower_record(tower_id, net_addr, &receipt)
Expand Down Expand Up @@ -794,8 +783,8 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = None;
let subscription_expiry = None;
let subscription_start = get_random_int();
let subscription_expiry = get_random_int();

// Store it once
dbm.store_tower_record(tower_id, net_addr, &receipt)
Expand Down

0 comments on commit b089fc3

Please sign in to comment.