Skip to content

Commit

Permalink
Add a test echo server
Browse files Browse the repository at this point in the history
  • Loading branch information
MOZGIII committed Feb 1, 2024
1 parent 2aba431 commit 8430062
Show file tree
Hide file tree
Showing 18 changed files with 661 additions and 38 deletions.
273 changes: 273 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions crates/xwt-test-assets-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "xwt-test-assets-build"
version = "0.1.0"
edition = "2021"
resolver = "2"
license = "MIT"
description = """
A static assets generation utility.
"""
repository = "https://github.com/MOZGIII/xwt"

[dependencies]
rcgen = "0.12"
tokio = { version = "1", default-features = false, features = ["macros", "rt-multi-thread", "fs"], optional = true }

[features]
default = ["tokio"]
tokio = ["dep:tokio"]
97 changes: 97 additions & 0 deletions crates/xwt-test-assets-build/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::{
io::ErrorKind,
path::{Path, PathBuf},
};

pub fn generate() -> rcgen::Certificate {
rcgen::generate_simple_self_signed(["localhost".into(), "127.0.0.1".into(), "::1".into()])
.unwrap()
}

pub fn state_dir() -> PathBuf {
let mut dir = PathBuf::from(std::env::var_os("CARGO_MANIFEST_DIR").unwrap());
dir.push("assets");
println!("{}", dir.display());
dir
}

#[cfg(feature = "tokio")]
pub async fn save_tokio(certificate: rcgen::Certificate, dir: impl AsRef<Path>) {
use tokio::io::AsyncWriteExt;

let dir = dir.as_ref();

tokio::fs::create_dir_all(dir).await.unwrap();

let open = |file| async move {
tokio::fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(file)
.await
};

let results = (
open(dir.join("cert.der")).await,
open(dir.join("key.der")).await,
);

match results {
(Ok(mut cert_file), Ok(mut key_file)) => {
cert_file
.write_all(&certificate.serialize_der().unwrap())
.await
.unwrap();
key_file
.write_all(&certificate.serialize_private_key_der())
.await
.unwrap();
cert_file.flush().await.unwrap();
key_file.flush().await.unwrap();
}
(Err(cert_err), Err(key_err))
if cert_err.kind() == ErrorKind::AlreadyExists
&& key_err.kind() == ErrorKind::AlreadyExists => {}
(cert_res, key_res) => {
let _ = cert_res.unwrap();
let _ = key_res.unwrap();
}
}
}

pub fn save(certificate: rcgen::Certificate, dir: impl AsRef<Path>) {
use std::io::Write;

let dir = dir.as_ref();

std::fs::create_dir_all(dir).unwrap();

let open = |file| {
std::fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(file)
};

let results = (open(dir.join("cert.der")), open(dir.join("key.der")));

match results {
(Ok(mut cert_file), Ok(mut key_file)) => {
cert_file
.write_all(&certificate.serialize_der().unwrap())
.unwrap();
key_file
.write_all(&certificate.serialize_private_key_der())
.unwrap();
cert_file.flush().unwrap();
key_file.flush().unwrap();
}
(Err(cert_err), Err(key_err))
if cert_err.kind() == ErrorKind::AlreadyExists
&& key_err.kind() == ErrorKind::AlreadyExists => {}
(cert_res, key_res) => {
let _ = cert_res.unwrap();
let _ = key_res.unwrap();
}
}
}
16 changes: 16 additions & 0 deletions crates/xwt-test-assets/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "xwt-test-assets"
version = "0.1.0"
edition = "2021"
resolver = "2"
license = "MIT"
description = """
Static test assets for xwt.
"""
repository = "https://github.com/MOZGIII/xwt"
private = true

[dependencies]

[build-dependencies]
xwt-test-assets-build = { version = "0.1", path = "../xwt-test-assets-build", default-features = false }
2 changes: 2 additions & 0 deletions crates/xwt-test-assets/assets/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
6 changes: 6 additions & 0 deletions crates/xwt-test-assets/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main() {
xwt_test_assets_build::save(
xwt_test_assets_build::generate(),
xwt_test_assets_build::state_dir(),
);
}
2 changes: 2 additions & 0 deletions crates/xwt-test-assets/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub const CERT: &[u8] = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/cert.der"));
pub const KEY: &[u8] = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/key.der"));
22 changes: 22 additions & 0 deletions crates/xwt-test-echo-server/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "xwt-test-echo-server"
version = "0.2.0"
edition = "2021"
resolver = "2"
license = "MIT"
description = """
The echo server to use for xwt testing.
Not intended to be wasm-compatible.
"""
repository = "https://github.com/MOZGIII/xwt"

[dependencies]
xwt-test-assets = { version = "0.1", path = "../xwt-test-assets" }

color-eyre = "0.6"
envfury = "0.2"
thiserror = "1"
tokio = { version = "1", default-features = false, features = ["macros", "rt-multi-thread"] }
tracing = "0.1"
tracing-subscriber = "0.3"
wtransport = "0.1.8"
154 changes: 154 additions & 0 deletions crates/xwt-test-echo-server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
//! The echo server to use for xwt testing.
//! Not intended to be wasm-compatible.

use std::sync::Arc;

#[derive(Default)]
pub struct EndpointParams {
pub addr: Option<std::net::SocketAddr>,
pub cert: Option<wtransport::tls::Certificate>,
}

pub async fn endpoint(
params: EndpointParams,
) -> Result<wtransport::Endpoint<wtransport::endpoint::endpoint_side::Server>, std::io::Error> {
let EndpointParams { addr, cert } = params;

let cert = cert.unwrap_or_else(|| {
wtransport::tls::Certificate::new(
vec![xwt_test_assets::CERT.into()],
xwt_test_assets::KEY.into(),
)
});

let server_config =
wtransport::ServerConfig::builder()
.with_bind_address(addr.unwrap_or(std::net::SocketAddr::V4(
std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 0),
)))
.with_certificate(cert)
.build();

let endpoint = wtransport::Endpoint::server(server_config)?;

Ok(endpoint)
}

pub async fn serve_endpoint(
endpoint: wtransport::Endpoint<wtransport::endpoint::endpoint_side::Server>,
) -> Result<std::convert::Infallible, std::io::Error> {
let bind_addr = endpoint.local_addr()?;
tracing::info!(message = "serving endpoint", %bind_addr);

let mut joinset = tokio::task::JoinSet::new();
loop {
let incoming_session = endpoint.accept().await;
joinset.spawn(async move {
if let Err(error) = serve_incoming_session(incoming_session).await {
tracing::error!(message = "error while serving incoming session", %error);
}
});
}
}

pub async fn serve_incoming_session(
incoming_session: wtransport::endpoint::IncomingSession,
) -> Result<(), wtransport::error::ConnectionError> {
tracing::info!(message = "got an incoming session");

let session_request = incoming_session.await?;

tracing::info!(message = "accepting incoming session");

let connection = session_request.accept().await?;

tracing::info!(message = "new connection accepted");

let connection = Arc::new(connection);

let mut joinset = tokio::task::JoinSet::new();

{
let connection = Arc::clone(&connection);
joinset.spawn(async move {
if let Err(error) = serve_streams(connection).await {
tracing::error!(message = "error while serving streams", %error);
}
});
}
{
let connection = Arc::clone(&connection);
joinset.spawn(async move {
if let Err(error) = serve_datagrams(connection).await {
tracing::error!(message = "error while serving datagrams", %error);
}
});
}

connection.closed().await;

tracing::info!(message = "connection is closing");

Ok(())
}

pub async fn serve_streams(
connection: impl AsRef<wtransport::Connection>,
) -> Result<std::convert::Infallible, wtransport::error::ConnectionError> {
let connection = connection.as_ref();
let mut joinset = tokio::task::JoinSet::new();
loop {
let stream = connection.accept_bi().await?;
joinset.spawn(async move {
if let Err(error) = serve_stream(stream).await {
tracing::error!(message = "error while serving stream", %error);
}
});
}
}

#[derive(Debug, thiserror::Error)]
pub enum StreamError {
#[error("read: {0}")]
Read(wtransport::error::StreamReadError),
#[error("write: {0}")]
Write(wtransport::error::StreamWriteError),
}

pub async fn serve_stream(
stream: (wtransport::SendStream, wtransport::RecvStream),
) -> Result<(), StreamError> {
let (mut tx, mut rx) = stream;
let mut buf = vec![0; 1024];
loop {
let Some(len) = rx.read(&mut buf).await.map_err(StreamError::Read)? else {
return Ok(());
};
tx.write_all(&buf[..len])
.await
.map_err(StreamError::Write)?;
}
}

#[derive(Debug, thiserror::Error)]
pub enum DatagramError {
#[error("receive: {0}")]
Receive(wtransport::error::ConnectionError),
#[error("send: {0}")]
Send(wtransport::error::SendDatagramError),
}

pub async fn serve_datagrams(
connection: impl AsRef<wtransport::Connection>,
) -> Result<(), DatagramError> {
let connection = connection.as_ref();
loop {
let datagram = connection
.receive_datagram()
.await
.map_err(DatagramError::Receive)?;
connection
.send_datagram(datagram.payload())
.map_err(DatagramError::Send)?;
}
}
15 changes: 15 additions & 0 deletions crates/xwt-test-echo-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#[tokio::main]
async fn main() -> color_eyre::eyre::Result<()> {
tracing_subscriber::fmt::init();
color_eyre::install()?;

let addr = envfury::or_parse("ADDR", "127.0.0.1:8080")?;
let endpoint = xwt_test_echo_server::endpoint(xwt_test_echo_server::EndpointParams {
addr: Some(addr),
cert: None,
})
.await?;
xwt_test_echo_server::serve_endpoint(endpoint).await?;

Ok(())
}
1 change: 1 addition & 0 deletions crates/xwt-tests/src/consts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub const ECHO_SERVER_URL: &str = "https://127.0.0.1:8080";
1 change: 1 addition & 0 deletions crates/xwt-tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod consts;
pub mod tests;
pub mod utils;
14 changes: 9 additions & 5 deletions crates/xwt-tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ where
BadData(Vec<u8>),
}

pub async fn echo<Endpoint>(endpoint: Endpoint) -> Result<(), EchoError<Endpoint>>
pub async fn echo<Endpoint>(endpoint: Endpoint, url: &str) -> Result<(), EchoError<Endpoint>>
where
Endpoint: xwt_core::EndpointConnect + std::fmt::Debug,
Endpoint::Connecting: std::fmt::Debug,
EndpointConnectConnectionFor<Endpoint>: xwt_core::OpenBiStream + std::fmt::Debug,
{
let connection = crate::utils::connect(endpoint, "https://echo.webtransport.day")
let connection = crate::utils::connect(endpoint, url)
.await
.map_err(EchoError::Connect)?;

Expand Down Expand Up @@ -81,6 +81,7 @@ where

pub async fn echo_chunks<Endpoint, WriteChunk, ReadChunk>(
endpoint: Endpoint,
url: &str,
) -> Result<(), EchoChunksError<Endpoint, WriteChunk, ReadChunk>>
where
Endpoint: xwt_core::EndpointConnect + std::fmt::Debug,
Expand All @@ -96,7 +97,7 @@ where
SendStreamFor<EndpointConnectConnectionFor<Endpoint>>: xwt_core::WriteChunk<WriteChunk>,
RecvStreamFor<EndpointConnectConnectionFor<Endpoint>>: xwt_core::ReadChunk<ReadChunk>,
{
let connection = crate::utils::connect(endpoint, "https://echo.webtransport.day")
let connection = crate::utils::connect(endpoint, url)
.await
.map_err(EchoChunksError::Connect)?;

Expand Down Expand Up @@ -140,14 +141,17 @@ where
BadData(ReceiveDatagramFor<EndpointConnectConnectionFor<Endpoint>>),
}

pub async fn echo_datagrams<Endpoint>(endpoint: Endpoint) -> Result<(), EchoDatagrmsError<Endpoint>>
pub async fn echo_datagrams<Endpoint>(
endpoint: Endpoint,
url: &str,
) -> Result<(), EchoDatagrmsError<Endpoint>>
where
Endpoint: xwt_core::EndpointConnect + std::fmt::Debug,
Endpoint::Connecting: std::fmt::Debug,
EndpointConnectConnectionFor<Endpoint>: xwt_core::datagram::Datagrams + std::fmt::Debug,
ReceiveDatagramFor<EndpointConnectConnectionFor<Endpoint>>: std::fmt::Debug,
{
let connection = crate::utils::connect(endpoint, "https://echo.webtransport.day")
let connection = crate::utils::connect(endpoint, url)
.await
.map_err(EchoDatagrmsError::Connect)?;

Expand Down
Loading

0 comments on commit 8430062

Please sign in to comment.