Skip to content

Commit

Permalink
refactor(sync): combine server channels
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanm-starkware committed Jul 24, 2024
1 parent 3bcfcc6 commit 6a41137
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 341 deletions.
81 changes: 29 additions & 52 deletions crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use futures::channel::mpsc::{Receiver, SendError, Sender};
use futures::channel::oneshot;
use futures::future::{ready, BoxFuture, Ready};
use futures::sink::With;
use futures::stream::{self, BoxStream, FuturesUnordered, Map, Stream};
use futures::stream::{self, FuturesUnordered, Map, Stream};
use futures::{pin_mut, FutureExt, Sink, SinkExt, StreamExt};
use libp2p::gossipsub::{SubscriptionError, TopicHash};
use libp2p::swarm::SwarmEvent;
Expand All @@ -36,9 +36,8 @@ pub enum NetworkError {
pub struct GenericNetworkManager<SwarmT: SwarmTrait> {
swarm: SwarmT,
inbound_protocol_to_buffer_size: HashMap<StreamProtocol, usize>,
sqmr_inbound_response_receivers:
StreamHashMap<InboundSessionId, BoxStream<'static, Option<Bytes>>>,
sqmr_inbound_query_senders: HashMap<StreamProtocol, Sender<(Bytes, Sender<Bytes>)>>,
sqmr_inbound_response_receivers: StreamHashMap<InboundSessionId, ResponsesReceiverForNetwork>,
sqmr_inbound_payload_senders: HashMap<StreamProtocol, SqmrServerSender>,

sqmr_outbound_payload_receivers: StreamHashMap<StreamProtocol, SqmrClientReceiver>,
sqmr_outbound_response_senders: HashMap<OutboundSessionId, ResponsesSenderForNetwork>,
Expand Down Expand Up @@ -79,7 +78,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
swarm,
inbound_protocol_to_buffer_size: HashMap::new(),
sqmr_inbound_response_receivers: StreamHashMap::new(HashMap::new()),
sqmr_inbound_query_senders: HashMap::new(),
sqmr_inbound_payload_senders: HashMap::new(),
sqmr_outbound_payload_receivers: StreamHashMap::new(HashMap::new()),
sqmr_outbound_response_senders: HashMap::new(),
sqmr_outbound_report_receivers: HashMap::new(),
Expand All @@ -96,10 +95,11 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
&mut self,
protocol: String,
buffer_size: usize,
) -> SqmrQueryReceiver<Query, Response>
) -> SqmrServerReceiver<Query, Response>
where
Bytes: From<Response>,
Query: TryFrom<Bytes>,
Response: 'static,
{
let protocol = StreamProtocol::try_from_owned(protocol)
.expect("Could not parse protocol into StreamProtocol.");
Expand All @@ -109,19 +109,18 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
{
panic!("Protocol '{}' has already been registered as a server.", protocol);
}
let (inbound_query_sender, inbound_query_receiver) =
let (inbound_payload_sender, inbound_payload_receiver) =
futures::channel::mpsc::channel(buffer_size);
let result = self.sqmr_inbound_query_senders.insert(protocol.clone(), inbound_query_sender);
if result.is_some() {
let insert_result = self
.sqmr_inbound_payload_senders
.insert(protocol.clone(), Box::new(inbound_payload_sender));
if insert_result.is_some() {
panic!("Protocol '{}' has already been registered as a server.", protocol);
}

inbound_query_receiver.map(|(query_bytes, response_bytes_sender)| {
(
Query::try_from(query_bytes),
response_bytes_sender.with(|response| ready(Ok(Bytes::from(response)))),
)
})
let inbound_payload_receiver = inbound_payload_receiver
.map(|payload: SqmrServerPayloadForNetwork| SqmrServerPayload::from(payload));
Box::new(inbound_payload_receiver)
}

/// TODO: Support multiple protocols where they're all different versions of the same protocol
Expand Down Expand Up @@ -155,8 +154,6 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
};
let payload_sender = payload_sender.with(payload_fn);

// let response_fn: ReceivedMessagesConverterFn<Response> = |x| Response::try_from(x);

Box::new(payload_sender)
}

Expand Down Expand Up @@ -310,7 +307,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
sqmr::behaviour::ExternalEvent::NewInboundSession {
query,
inbound_session_id,
peer_id: _,
peer_id,
protocol_name,
} => {
info!(
Expand All @@ -321,30 +318,33 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
papyrus_metrics::PAPYRUS_NUM_ACTIVE_INBOUND_SESSIONS,
self.num_active_inbound_sessions as f64
);
let (report_sender, report_receiver) = oneshot::channel::<()>();
self.handle_new_report_receiver(peer_id, report_receiver);
// TODO: consider returning error instead of panic.
let Some(query_sender) = self.sqmr_inbound_query_senders.get_mut(&protocol_name)
let Some(query_sender) = self.sqmr_inbound_payload_senders.get_mut(&protocol_name)
else {
return;
};
let (response_sender, response_receiver) = futures::channel::mpsc::channel(
let (responses_sender, responses_receiver) = futures::channel::mpsc::channel(
*self.inbound_protocol_to_buffer_size.get(&protocol_name).expect(
"A protocol is registered in NetworkManager but it has no buffer size.",
),
);
let responses_sender = Box::new(responses_sender);
self.sqmr_inbound_response_receivers.insert(
inbound_session_id,
Box::new(responses_receiver.map(Some).chain(stream::once(ready(None)))),
);

// TODO(shahak): Close the inbound session if the buffer is full.
server_send_now(
send_now(
query_sender,
(query, response_sender),
SqmrServerPayloadForNetwork { query, report_sender, responses_sender },
format!(
"Received an inbound query while the buffer is full. Dropping query for \
session {inbound_session_id:?}"
),
);
self.sqmr_inbound_response_receivers.insert(
inbound_session_id,
response_receiver.map(Some).chain(stream::once(ready(None))).boxed(),
);
}
sqmr::behaviour::ExternalEvent::ReceivedResponse {
outbound_session_id,
Expand All @@ -360,7 +360,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
self.sqmr_outbound_response_senders.get_mut(&outbound_session_id)
{
// TODO(shahak): Close the channel if the buffer is full.
network_send_now(
send_now(
response_sender,
response,
format!(
Expand Down Expand Up @@ -506,11 +506,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
}
}

fn network_send_now<Item>(
sender: &mut GenericSender<Item>,
item: Item,
buffer_full_message: String,
) {
fn send_now<Item>(sender: &mut GenericSender<Item>, item: Item, buffer_full_message: String) {
pin_mut!(sender);
match sender.as_mut().send(item).now_or_never() {
Some(Ok(())) => {}
Expand All @@ -523,17 +519,6 @@ fn network_send_now<Item>(
}
}

fn server_send_now<Item>(sender: &mut Sender<Item>, item: Item, buffer_full_message: String) {
if let Err(error) = sender.try_send(item) {
if error.is_disconnected() {
panic!("Receiver was dropped. This should never happen.")
} else if error.is_full() {
// TODO(shahak): Consider doing something else rather than dropping the message.
error!(buffer_full_message);
}
}
}

pub type NetworkManager = GenericNetworkManager<Swarm<mixed_behaviour::MixedBehaviour>>;

impl NetworkManager {
Expand Down Expand Up @@ -625,6 +610,7 @@ type GenericSender<T> = Box<dyn Sink<T, Error = SendError> + Unpin + Send>;
type GenericReceiver<T> = Box<dyn Stream<Item = T> + Unpin + Send>;

type ResponsesSenderForNetwork = GenericSender<Bytes>;
type ResponsesReceiverForNetwork = GenericReceiver<Option<Bytes>>;
type ResponsesSender<Response> =
GenericSender<Result<Response, <Response as TryFrom<Bytes>>::Error>>;

Expand Down Expand Up @@ -696,15 +682,6 @@ where
}
}

pub type SqmrQueryReceiver<Query, Response> =
Map<Receiver<(Bytes, Sender<Bytes>)>, ReceivedQueryConverterFn<Query, Response>>;

type ReceivedQueryConverterFn<Query, Response> =
fn(
(Bytes, Sender<Bytes>),
)
-> (Result<Query, <Query as TryFrom<Bytes>>::Error>, BroadcastSubscriberSender<Response>);

// TODO(eitan): improve naming of final channel types
pub type BroadcastSubscriberSender<T> = With<
Sender<Bytes>,
Expand Down
6 changes: 3 additions & 3 deletions crates/papyrus_network/src/network_manager/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use super::swarm_trait::{Event, SwarmTrait};
use super::GenericNetworkManager;
use crate::gossipsub_impl::{self, Topic};
use crate::mixed_behaviour;
use crate::network_manager::SqmrClientPayload;
use crate::network_manager::{SqmrClientPayload, SqmrServerPayload};
use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError};
use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId};

Expand Down Expand Up @@ -283,7 +283,7 @@ async fn process_incoming_query() {

let mut network_manager = GenericNetworkManager::generic_new(mock_swarm);

let mut inbound_query_receiver = network_manager
let mut inbound_payload_receiver = network_manager
.register_sqmr_protocol_server::<Vec<u8>, Vec<u8>>(protocol.to_string(), BUFFER_SIZE);

let actual_protocol = get_supported_inbound_protocol_fut.next().await.unwrap();
Expand All @@ -292,7 +292,7 @@ async fn process_incoming_query() {
let responses_clone = responses.clone();
select! {
_ = async move {
let (query_got, mut responses_sender) = inbound_query_receiver.next().await.unwrap();
let SqmrServerPayload{query: query_got, report_sender: _report_sender, mut responses_sender} = inbound_payload_receiver.next().await.unwrap();
assert_eq!(query_got.unwrap(), query);
for response in responses_clone {
responses_sender.feed(response).await.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_network/src/peer_manager/behaviour_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ where
}
let res = self.report_peer(peer_id, super::ReputationModifier::Bad);
if res.is_err() {
error!("Dial failure of an unknow peer. peer id: {}", peer_id)
error!("Dial failure of an unknown peer. peer id: {}", peer_id)
}
// Re-assign a peer to the session so that a SessionAssgined Event will be emitted.
// TODO: test this case
Expand Down
78 changes: 23 additions & 55 deletions crates/papyrus_node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use futures::future::BoxFuture;
use futures::FutureExt;
use papyrus_base_layer::ethereum_base_layer_contract::EthereumBaseLayerConfig;
use papyrus_common::metrics::COLLECT_PROFILING_METRICS;
use papyrus_common::pending_classes::{ApiContractClass, PendingClasses};
use papyrus_common::pending_classes::PendingClasses;
use papyrus_common::BlockHashAndNumber;
use papyrus_config::presentation::get_config_presentation;
use papyrus_config::validators::config_validate;
Expand All @@ -21,33 +21,19 @@ use papyrus_consensus::papyrus_consensus_context::PapyrusConsensusContext;
use papyrus_consensus::types::ConsensusError;
use papyrus_monitoring_gateway::MonitoringServer;
use papyrus_network::gossipsub_impl::Topic;
use papyrus_network::network_manager::{
BroadcastSubscriberChannels,
NetworkError,
SqmrQueryReceiver,
};
use papyrus_network::network_manager::{BroadcastSubscriberChannels, NetworkError};
use papyrus_network::{network_manager, NetworkConfig};
use papyrus_node::config::NodeConfig;
use papyrus_node::version::VERSION_FULL;
use papyrus_p2p_sync::client::{
P2PClientSyncError,
P2PSyncClient,
P2PSyncClientChannels,
P2PSyncClientConfig,
P2PSyncError,
};
use papyrus_p2p_sync::server::P2PSyncServer;
use papyrus_p2p_sync::server::{P2PSyncServer, P2PSyncServerChannels};
use papyrus_p2p_sync::{Protocol, BUFFER_SIZE};
use papyrus_protobuf::consensus::ConsensusMessage;
use papyrus_protobuf::sync::{
ClassQuery,
DataOrFin,
EventQuery,
HeaderQuery,
SignedBlockHeader,
StateDiffChunk,
StateDiffQuery,
TransactionQuery,
};
#[cfg(feature = "rpc")]
use papyrus_rpc::run_server;
use papyrus_storage::{open_storage, update_storage_metrics, StorageReader, StorageWriter};
Expand All @@ -57,7 +43,6 @@ use papyrus_sync::sources::pending::PendingSource;
use papyrus_sync::{StateSync, StateSyncError, SyncConfig};
use starknet_api::block::BlockHash;
use starknet_api::felt;
use starknet_api::transaction::{Event, Transaction, TransactionHash, TransactionOutput};
use starknet_client::reader::objects::pending_data::{PendingBlock, PendingBlockOrDeprecated};
use starknet_client::reader::PendingData;
use tokio::sync::RwLock;
Expand Down Expand Up @@ -187,21 +172,9 @@ async fn run_threads(config: NodeConfig) -> anyhow::Result<()> {

// P2P Sync Server task.
let p2p_sync_server_future = match maybe_sync_server_channels {
Some((
header_server_channel,
state_diff_server_channel,
transaction_server_channel,
class_server_channel,
event_server_channel,
)) => {
let p2p_sync_server = P2PSyncServer::new(
storage_reader.clone(),
header_server_channel,
state_diff_server_channel,
transaction_server_channel,
class_server_channel,
event_server_channel,
);
Some(p2p_sync_server_channels) => {
let p2p_sync_server =
P2PSyncServer::new(storage_reader.clone(), p2p_sync_server_channels);
p2p_sync_server.run().boxed()
}
None => pending().boxed(),
Expand Down Expand Up @@ -321,7 +294,7 @@ async fn run_threads(config: NodeConfig) -> anyhow::Result<()> {
storage_reader: StorageReader,
storage_writer: StorageWriter,
p2p_sync_client_channels: P2PSyncClientChannels,
) -> Result<(), P2PSyncError> {
) -> Result<(), P2PClientSyncError> {
let p2p_sync = P2PSyncClient::new(
p2p_sync_client_config,
storage_reader,
Expand All @@ -335,13 +308,7 @@ async fn run_threads(config: NodeConfig) -> anyhow::Result<()> {
type NetworkRunReturn = (
BoxFuture<'static, Result<(), NetworkError>>,
Option<P2PSyncClientChannels>,
Option<(
SqmrQueryReceiver<HeaderQuery, DataOrFin<SignedBlockHeader>>,
SqmrQueryReceiver<StateDiffQuery, DataOrFin<StateDiffChunk>>,
SqmrQueryReceiver<TransactionQuery, DataOrFin<(Transaction, TransactionOutput)>>,
SqmrQueryReceiver<ClassQuery, DataOrFin<ApiContractClass>>,
SqmrQueryReceiver<EventQuery, DataOrFin<(Event, TransactionHash)>>,
)>,
Option<P2PSyncServerChannels>,
Option<BroadcastSubscriberChannels<ConsensusMessage>>,
String,
);
Expand Down Expand Up @@ -380,22 +347,23 @@ fn run_network(
),
None => None,
};
let p2p_sync_channels = P2PSyncClientChannels {
header_payload_sender: header_client_sender,
state_diff_payload_sender: state_diff_client_sender,
transaction_payload_sender: transaction_client_sender,
};
let p2p_sync_client_channels = P2PSyncClientChannels::new(
header_client_sender,
state_diff_client_sender,
transaction_client_sender,
);
let p2p_sync_server_channels = P2PSyncServerChannels::new(
header_server_channel,
state_diff_server_channel,
transaction_server_channel,
class_server_channel,
event_server_channel,
);

Ok((
network_manager.run().boxed(),
Some(p2p_sync_channels),
Some((
header_server_channel,
state_diff_server_channel,
transaction_server_channel,
class_server_channel,
event_server_channel,
)),
Some(p2p_sync_client_channels),
Some(p2p_sync_server_channels),
consensus_channels,
local_peer_id,
))
Expand Down
Loading

0 comments on commit 6a41137

Please sign in to comment.