Skip to content

Commit

Permalink
Add ServerHandle for stop server
Browse files Browse the repository at this point in the history
  • Loading branch information
astoring committed Dec 5, 2023
1 parent ab16556 commit 8f15146
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 200 deletions.
14 changes: 3 additions & 11 deletions crates/core/src/conn/joined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::time::Duration;

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::sync::CancellationToken;

use crate::async_trait;
use crate::conn::{Holding, HttpBuilder};
Expand Down Expand Up @@ -124,18 +123,11 @@ where
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
match self {
JoinedStream::A(a) => {
a.serve(handler, builder, server_shutdown_token, idle_connection_timeout)
.await
}
JoinedStream::B(b) => {
b.serve(handler, builder, server_shutdown_token, idle_connection_timeout)
.await
}
JoinedStream::A(a) => a.serve(handler, builder, idle_timeout).await,
JoinedStream::B(b) => b.serve(handler, builder, idle_timeout).await,
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions crates/core/src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ cfg_feature! {

use tokio_rustls::server::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;

use crate::async_trait;
use crate::service::HyperHandler;
Expand All @@ -89,10 +88,9 @@ cfg_feature! {
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
async fn serve(self, handler: HyperHandler, builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>) -> IoResult<()> {
idle_timeout: Option<Duration>) -> IoResult<()> {
builder
.serve_connection(self, handler, server_shutdown_token, idle_connection_timeout)
.serve_connection(self, handler, idle_timeout)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
}
Expand Down
6 changes: 2 additions & 4 deletions crates/core/src/conn/native_tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use futures_util::task::noop_waker_ref;
use http::uri::Scheme;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_native_tls::TlsStream;
use tokio_util::sync::CancellationToken;

use crate::async_trait;
use crate::conn::{Accepted, Acceptor, Holding, HttpBuilder, IntoConfigStream, Listener};
Expand Down Expand Up @@ -72,11 +71,10 @@ where
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
builder
.serve_connection(self, handler, server_shutdown_token, idle_connection_timeout)
.serve_connection(self, handler, idle_timeout)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
}
Expand Down
6 changes: 2 additions & 4 deletions crates/core/src/conn/openssl/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use openssl::ssl::{Ssl, SslAcceptor};
use tokio::io::ErrorKind;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_openssl::SslStream;
use tokio_util::sync::CancellationToken;

use super::SslAcceptorBuilder;

Expand Down Expand Up @@ -122,11 +121,10 @@ where
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
builder
.serve_connection(self, handler, server_shutdown_token, idle_connection_timeout)
.serve_connection(self, handler, idle_timeout)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
}
Expand Down
35 changes: 8 additions & 27 deletions crates/core/src/conn/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use http::{Request, Response, Version};
use hyper::service::Service;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::{oneshot, Notify};
use tokio::sync::Notify;
use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;

Expand Down Expand Up @@ -51,8 +51,7 @@ impl HttpBuilder {
&self,
socket: I,
#[allow(unused_variables)] service: S,
#[allow(unused_variables)] server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> Result<()>
where
S: Service<Request<HyperBody>, Response = Response<B>> + Send,
Expand All @@ -73,7 +72,7 @@ impl HttpBuilder {
#[cfg(all(not(feature = "http1"), feature = "http2"))]
let version = Version::HTTP_2;
#[allow(unused_variables)]
let socket = match idle_connection_timeout {
let socket = match idle_timeout {
Some(timeout) => Either::Left(ClosingInactiveConnection::new(socket, timeout, {
let conn_shutdown_token = conn_shutdown_token.clone();

Expand Down Expand Up @@ -106,7 +105,6 @@ impl HttpBuilder {
_ = conn_shutdown_token.cancelled() => {
tracing::info!("closing connection due to inactivity");
}
_ = server_shutdown_token.cancelled() => {}
}

// Init graceful shutdown for connection (`GOAWAY` for `HTTP/2` or disabling `keep-alive` for `HTTP/1`)
Expand All @@ -128,7 +126,6 @@ impl HttpBuilder {
_ = conn_shutdown_token.cancelled() => {
tracing::info!("closing connection due to inactivity");
}
_ = server_shutdown_token.cancelled() => {}
}

// Init graceful shutdown for connection (`GOAWAY` for `HTTP/2` or disabling `keep-alive` for `HTTP/1`)
Expand All @@ -152,7 +149,6 @@ struct ClosingInactiveConnection<T> {
#[pin]
alive: Arc<Notify>,
timeout: Duration,
stop_tx: oneshot::Sender<()>,
}

impl<T> AsyncRead for ClosingInactiveConnection<T>
Expand Down Expand Up @@ -213,33 +209,18 @@ impl<T> ClosingInactiveConnection<T> {
Fut: Future + Send + 'static,
{
let alive = Arc::new(Notify::new());
let (stop_tx, stop_rx) = oneshot::channel();
tokio::spawn({
let alive = alive.clone();

async move {
let check_timeout = async {
loop {
match tokio::time::timeout(timeout, alive.notified()).await {
Ok(()) => {}
Err(_) => {
f().await;
}
}
loop {
if tokio::time::timeout(timeout, alive.notified()).await.is_err() {
f().await;
break;
}
};
tokio::select! {
_ = stop_rx => {},
_ = check_timeout => {}
}
}
});
Self {
inner,
alive,
timeout,
stop_tx,
}
Self { inner, alive, timeout }
}
}

Expand Down
112 changes: 74 additions & 38 deletions crates/core/src/conn/quinn/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

Expand All @@ -11,6 +12,7 @@ use futures_util::Stream;
use salvo_http3::error::ErrorLevel;
use salvo_http3::ext::Protocol;
use salvo_http3::server::RequestStream;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;

use crate::http::body::{H3ReqBody, ReqBody};
Expand Down Expand Up @@ -49,57 +51,91 @@ impl Builder {
}
}

macro_rules! process_accepted {
($conn:expr, $accepted:expr, $hyper_handler:expr) => {
match $accepted {
Ok(Some((request, stream))) => {
tracing::debug!("new request: {:#?}", request);
let hyper_handler = $hyper_handler.clone();
match request.method() {
&Method::CONNECT
if request.extensions().get::<Protocol>() == Some(&Protocol::WEB_TRANSPORT) =>
{
if let Some(c) = process_web_transport($conn, request, stream, hyper_handler).await? {
$conn = c;
} else {
return Ok(());
}
}
_ => {
tokio::spawn(async move {
match process_request(request, stream, hyper_handler).await {
Ok(_) => {}
Err(e) => {
tracing::error!(error = ?e, "process request failed")
}
}
});
}
}
}
Ok(None) => {
break;
}
Err(e) => {
tracing::warn!(error = ?e, "accept failed");
match e.get_error_level() {
ErrorLevel::ConnectionError => break,
ErrorLevel::StreamError => continue,
}
}
}
}
}
impl Builder {
/// Serve HTTP3 connection.
pub async fn serve_connection(
&self,
conn: crate::conn::quinn::H3Connection,
hyper_handler: crate::service::HyperHandler,
_server_shutdown_token: CancellationToken, //TODO
_idle_connection_timeout: Option<Duration>, //TODO
idle_timeout: Option<Duration>, //TODO
) -> IoResult<()> {
let mut conn = self
.0
.build::<salvo_http3::http3_quinn::Connection, bytes::Bytes>(conn.into_inner())
.await
.map_err(|e| IoError::new(ErrorKind::Other, format!("invalid connection: {}", e)))?;

loop {
match conn.accept().await {
Ok(Some((request, stream))) => {
tracing::debug!("new request: {:#?}", request);
let hyper_handler = hyper_handler.clone();
match request.method() {
&Method::CONNECT
if request.extensions().get::<Protocol>() == Some(&Protocol::WEB_TRANSPORT) =>
{
if let Some(c) = process_web_transport(conn, request, stream, hyper_handler).await? {
conn = c;
} else {
return Ok(());
if let Some(idle_timeout) = idle_timeout {
let conn_shutdown_token = CancellationToken::new();
let alive = Arc::new(Notify::new());
tokio::spawn({
let alive = alive.clone();
let conn_shutdown_token = conn_shutdown_token.clone();
async move {
loop {
let timeout = tokio::time::timeout(idle_timeout, alive.notified());
if timeout.await.is_err() {
conn_shutdown_token.cancel();
break;
}
}
_ => {
tokio::spawn(async move {
match process_request(request, stream, hyper_handler).await {
Ok(_) => {}
Err(e) => {
tracing::error!(error = ?e, "process request failed")
}
}
});
}
}
}
Ok(None) => {
break;
}
Err(e) => {
tracing::warn!(error = ?e, "accept failed");
match e.get_error_level() {
ErrorLevel::ConnectionError => break,
ErrorLevel::StreamError => continue,
});
tokio::select! {
accepted = conn.accept() => {
alive.notify_waiters();
process_accepted!(conn, accepted, hyper_handler);
}
_ = conn_shutdown_token.cancelled() => {
tracing::info!("closing http3 connection due to inactivity");
break;
}
}
} else {
let accpeted = conn.accept().await;
process_accepted!(conn, accpeted, hyper_handler);
}
}
Ok(())
Expand Down Expand Up @@ -175,9 +211,9 @@ async fn process_web_transport(
if let Err(e) = stream.send_data(frame.into_data().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send data to connection peer");
}
} else if let Err(e) = stream.send_trailers(frame.into_trailers().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
} else if let Err(e) = stream.send_trailers(frame.into_trailers().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
}
Err(e) => {
tracing::error!(error = ?e, "unable to poll data from connection");
Expand Down Expand Up @@ -230,8 +266,8 @@ where
tracing::error!(error = ?e, "unable to send data to connection peer");
}
} else if let Err(e) = tx.send_trailers(frame.into_trailers().unwrap_or_default()).await {
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
tracing::error!(error = ?e, "unable to send trailers to connection peer");
}
}
Err(e) => {
tracing::error!(error = ?e, "unable to poll data from connection");
Expand Down
6 changes: 2 additions & 4 deletions crates/core/src/conn/quinn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use futures_util::Stream;
use salvo_http3::http3_quinn;
pub use salvo_http3::http3_quinn::ServerConfig;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::sync::CancellationToken;

use crate::async_trait;
use crate::conn::rustls::RustlsConfig;
Expand Down Expand Up @@ -78,12 +77,11 @@ impl HttpConnection for H3Connection {
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
builder
.quinn
.serve_connection(self, handler, server_shutdown_token, idle_connection_timeout)
.serve_connection(self, handler, idle_timeout)
.await
}
}
Expand Down
6 changes: 2 additions & 4 deletions crates/core/src/conn/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::time::Duration;
use std::vec;

use tokio::net::{TcpListener as TokioTcpListener, TcpStream, ToSocketAddrs};
use tokio_util::sync::CancellationToken;

use crate::async_trait;
use crate::conn::{Holding, HttpBuilder};
Expand Down Expand Up @@ -135,11 +134,10 @@ impl HttpConnection for TcpStream {
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
server_shutdown_token: CancellationToken,
idle_connection_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
builder
.serve_connection(self, handler, server_shutdown_token, idle_connection_timeout)
.serve_connection(self, handler, idle_timeout)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
}
Expand Down
Loading

0 comments on commit 8f15146

Please sign in to comment.