Skip to content

Commit

Permalink
create PeerAddress as a wrapper of libp2p::PeerId
Browse files Browse the repository at this point in the history
Signed-off-by: onur-ozkan <[email protected]>
  • Loading branch information
onur-ozkan committed Sep 26, 2024
1 parent ad61f31 commit b4d63f6
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 84 deletions.
192 changes: 117 additions & 75 deletions mm2src/mm2_main/src/lp_healthcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use futures::channel::oneshot::{self, Receiver, Sender};
use instant::Duration;
use mm2_core::mm_ctx::{HealthcheckConfig, MmArc};
use mm2_err_handle::prelude::MmError;
use mm2_libp2p::{decode_message, encode_message, pub_sub_topic, Libp2pPublic, PeerId, TopicPrefix};
use mm2_libp2p::{decode_message, encode_message, pub_sub_topic, Libp2pPublic, TopicPrefix};
use mm2_net::p2p::P2PContext;
use ser_error_derive::SerializeErrorType;
use serde::{Deserialize, Serialize};
Expand All @@ -28,15 +28,80 @@ pub(crate) struct HealthcheckMessage {
data: HealthcheckData,
}

/// Wrapper of `libp2p::PeerId` with trait additional implementations.
///
/// TODO: This should be used as a replacement of `libp2p::PeerId` in the entire project.
#[derive(Clone, Copy, Debug, Display, PartialEq)]
pub struct PeerAddress(mm2_libp2p::PeerId);

impl From<mm2_libp2p::PeerId> for PeerAddress {
fn from(value: mm2_libp2p::PeerId) -> Self { Self(value) }
}

impl From<PeerAddress> for mm2_libp2p::PeerId {
fn from(value: PeerAddress) -> Self { value.0 }
}

impl Serialize for PeerAddress {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.0.to_string())
}
}

impl<'de> Deserialize<'de> for PeerAddress {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct PeerAddressVisitor;

impl<'de> serde::de::Visitor<'de> for PeerAddressVisitor {
type Value = PeerAddress;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representation of peer id.")
}

fn visit_str<E>(self, value: &str) -> Result<PeerAddress, E>
where
E: serde::de::Error,
{
if value.len() > 100 {
return Err(serde::de::Error::invalid_length(
value.len(),
&"peer id cannot exceed 100 characters.",
));
}

Ok(mm2_libp2p::PeerId::from_str(value)
.map_err(serde::de::Error::custom)?
.into())
}

fn visit_string<E>(self, value: String) -> Result<PeerAddress, E>
where
E: serde::de::Error,
{
self.visit_str(&value)
}
}

deserializer.deserialize_str(PeerAddressVisitor)
}
}

impl HealthcheckMessage {
pub(crate) fn generate_message(
ctx: &MmArc,
target_peer: PeerId,
target_peer: PeerAddress,
is_a_reply: bool,
expires_in_seconds: i64,
) -> Result<Self, String> {
let p2p_ctx = P2PContext::fetch_from_mm_arc(ctx);
let sender_peer = p2p_ctx.peer_id();
let sender_peer = p2p_ctx.peer_id().into();
let keypair = p2p_ctx.keypair();
let sender_public_key = keypair.public().encode_protobuf();

Expand All @@ -53,7 +118,11 @@ impl HealthcheckMessage {
Ok(Self { signature, data })
}

pub(crate) fn is_received_message_valid(&self, my_peer_id: PeerId, healthcheck_config: &HealthcheckConfig) -> bool {
pub(crate) fn is_received_message_valid(
&self,
my_peer_address: PeerAddress,
healthcheck_config: &HealthcheckConfig,
) -> bool {
let now = Utc::now().timestamp();
let remaining_expiration_seconds = u64::try_from(self.data.expires_at - now).unwrap_or(0);

Expand All @@ -72,10 +141,10 @@ impl HealthcheckMessage {
return false;
}

if self.data.target_peer != my_peer_id {
if self.data.target_peer != my_peer_address {
log::debug!(
"`target_peer` doesn't match with our peer address. Our address: '{}', healthcheck `target_peer`: '{}'.",
my_peer_id,
my_peer_address,
self.data.target_peer
);
return false;
Expand All @@ -87,7 +156,7 @@ impl HealthcheckMessage {
return false
};

if self.data.sender_peer != public_key.to_peer_id() {
if self.data.sender_peer != public_key.to_peer_id().into() {
log::debug!("`sender_peer` and `sender_public_key` doesn't belong each other.");

return false;
Expand Down Expand Up @@ -117,18 +186,16 @@ impl HealthcheckMessage {
pub(crate) fn should_reply(&self) -> bool { !self.data.is_a_reply }

#[inline]
pub(crate) fn sender_peer(&self) -> PeerId { self.data.sender_peer }
pub(crate) fn sender_peer(&self) -> PeerAddress { self.data.sender_peer }
}

#[derive(Debug, Deserialize, Serialize)]
#[cfg_attr(any(test, target_arch = "wasm32"), derive(PartialEq))]
struct HealthcheckData {
#[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")]
sender_peer: PeerId,
sender_peer: PeerAddress,
#[serde(deserialize_with = "deserialize_bytes")]
sender_public_key: Vec<u8>,
#[serde(deserialize_with = "deserialize_peer_id", serialize_with = "serialize_peer_id")]
target_peer: PeerId,
target_peer: PeerAddress,
expires_at: i64,
is_a_reply: bool,
}
Expand All @@ -139,59 +206,13 @@ impl HealthcheckData {
}

#[inline]
pub fn peer_healthcheck_topic(peer_id: &PeerId) -> String {
pub_sub_topic(PEER_HEALTHCHECK_PREFIX, &peer_id.to_string())
pub fn peer_healthcheck_topic(peer_address: &PeerAddress) -> String {
pub_sub_topic(PEER_HEALTHCHECK_PREFIX, &peer_address.to_string())
}

#[derive(Deserialize)]
pub struct RequestPayload {
#[serde(deserialize_with = "deserialize_peer_id")]
peer_id: PeerId,
}

fn deserialize_peer_id<'de, D>(deserializer: D) -> Result<PeerId, D::Error>
where
D: serde::Deserializer<'de>,
{
struct PeerIdVisitor;

impl<'de> serde::de::Visitor<'de> for PeerIdVisitor {
type Value = PeerId;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string representation of peer id.")
}

fn visit_str<E>(self, value: &str) -> Result<PeerId, E>
where
E: serde::de::Error,
{
if value.len() > 100 {
return Err(serde::de::Error::invalid_length(
value.len(),
&"peer id cannot exceed 100 characters.",
));
}

PeerId::from_str(value).map_err(serde::de::Error::custom)
}

fn visit_string<E>(self, value: String) -> Result<PeerId, E>
where
E: serde::de::Error,
{
self.visit_str(&value)
}
}

deserializer.deserialize_str(PeerIdVisitor)
}

fn serialize_peer_id<S>(peer_id: &PeerId, s: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
s.serialize_str(&peer_id.to_string())
peer_address: PeerAddress,
}

fn deserialize_bytes<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
Expand Down Expand Up @@ -260,17 +281,17 @@ pub async fn peer_connection_healthcheck_rpc(
// This is unrelated to the timeout logic.
let address_record_exp = Duration::from_secs(ctx.health_checker.config.timeout_secs);

let target_peer_id = req.peer_id;
let target_peer_address = req.peer_address;

let p2p_ctx = P2PContext::fetch_from_mm_arc(&ctx);
if target_peer_id == p2p_ctx.peer_id() {
if target_peer_address == p2p_ctx.peer_id().into() {
// That's us, so return true.
return Ok(true);
}

let message = HealthcheckMessage::generate_message(
&ctx,
target_peer_id,
target_peer_address,
false,
ctx.health_checker
.config
Expand All @@ -289,10 +310,15 @@ pub async fn peer_connection_healthcheck_rpc(
{
let mut book = ctx.health_checker.response_handler.lock().unwrap();
book.clear_expired_entries();
book.insert(target_peer_id.to_string(), tx, address_record_exp);
book.insert(target_peer_address.to_string(), tx, address_record_exp);
}

broadcast_p2p_msg(&ctx, peer_healthcheck_topic(&target_peer_id), encoded_message, None);
broadcast_p2p_msg(
&ctx,
peer_healthcheck_topic(&target_peer_address),
encoded_message,
None,
);

let timeout_duration = Duration::from_secs(ctx.health_checker.config.timeout_secs);
Ok(rx.timeout(timeout_duration).await == Ok(Ok(())))
Expand All @@ -316,16 +342,16 @@ pub(crate) async fn process_p2p_healthcheck_message(ctx: &MmArc, message: mm2_li
"Couldn't decode healthcheck message"
);

let sender_peer = data.sender_peer().to_owned();
let sender_peer = data.sender_peer();

let ctx = ctx.clone();

// Pass the remaining work to another thread to free up this one as soon as possible,
// so KDF can handle a high amount of healthcheck messages more efficiently.
ctx.spawner().spawn(async move {
let my_peer_id = P2PContext::fetch_from_mm_arc(&ctx).peer_id();
let my_peer_address = P2PContext::fetch_from_mm_arc(&ctx).peer_id().into();

if !data.is_received_message_valid(my_peer_id, &ctx.health_checker.config) {
if !data.is_received_message_valid(my_peer_address, &ctx.health_checker.config) {
log::error!("Received an invalid healthcheck message.");
log::debug!("Message context: {:?}", data);
return;
Expand Down Expand Up @@ -389,9 +415,9 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
}

fn create_test_peer_id() -> PeerId {
fn create_test_peer_address() -> PeerAddress {
let keypair = mm2_libp2p::Keypair::generate_ed25519();
PeerId::from(keypair.public())
mm2_libp2p::PeerId::from(keypair.public()).into()
}

fn ctx() -> MmArc {
Expand All @@ -410,16 +436,32 @@ mod tests {
ctx
}

cross_test!(test_peer_address, {
#[derive(Deserialize, Serialize)]
struct PeerAddressTest {
peer_address: PeerAddress,
}

let address_str = "12D3KooWEtuv7kmgGCC7oAQ31hB7AR5KkhT3eEWB2bP2roo3M7rY";
let json_content = format!("{{\"peer_address\": \"{address_str}\"}}");
let address_struct: PeerAddressTest = serde_json::from_str(&json_content).unwrap();

let actual_peer_id = mm2_libp2p::PeerId::from_str(address_str).unwrap();
let deserialized_peer_id: mm2_libp2p::PeerId = address_struct.peer_address.into();

assert_eq!(deserialized_peer_id, actual_peer_id);
});

cross_test!(test_valid_message, {
let ctx = ctx();
let target_peer = create_test_peer_id();
let target_peer = create_test_peer_address();
let message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap();
assert!(message.is_received_message_valid(target_peer, &ctx.health_checker.config));
});

cross_test!(test_corrupted_messages, {
let ctx = ctx();
let target_peer = create_test_peer_id();
let target_peer = create_test_peer_address();

let mut message = HealthcheckMessage::generate_message(&ctx, target_peer, false, 5).unwrap();
message.data.expires_at += 1;
Expand All @@ -443,14 +485,14 @@ mod tests {

cross_test!(test_expired_message, {
let ctx = ctx();
let target_peer = create_test_peer_id();
let target_peer = create_test_peer_address();
let message = HealthcheckMessage::generate_message(&ctx, target_peer, false, -1).unwrap();
assert!(!message.is_received_message_valid(target_peer, &ctx.health_checker.config));
});

cross_test!(test_encode_decode, {
let ctx = ctx();
let target_peer = create_test_peer_id();
let target_peer = create_test_peer_address();

let original = HealthcheckMessage::generate_message(&ctx, target_peer, false, 10).unwrap();

Expand Down
4 changes: 2 additions & 2 deletions mm2src/mm2_main/src/lp_native_dex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use mm2_metrics::mm_gauge;
use mm2_net::network_event::NetworkEvent;
use mm2_net::p2p::P2PContext;
use rpc_task::RpcTaskError;
use serde_json::{self as json};
use serde_json as json;
use std::convert::TryInto;
use std::io;
use std::path::PathBuf;
Expand Down Expand Up @@ -646,7 +646,7 @@ pub async fn init_p2p(ctx: MmArc) -> P2PResult<()> {
ctx.spawner().spawn(fut);

// Listen for health check messages.
subscribe_to_topic(&ctx, peer_healthcheck_topic(&peer_id));
subscribe_to_topic(&ctx, peer_healthcheck_topic(&peer_id.into()));

Ok(())
}
Expand Down
Loading

0 comments on commit b4d63f6

Please sign in to comment.