From 1ba59c035177c54777098e7a330368784e04684b Mon Sep 17 00:00:00 2001 From: Chrislearn Young Date: Fri, 11 Aug 2023 21:13:50 +0800 Subject: [PATCH] feat: Allow custom url path and query getter (#364) * feat: Allow custom url rest getter * public encode_url_path * Format Rust code using rustfmt * split to path and query getter * use encode_url_path * encode_url_path --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- crates/proxy/src/lib.rs | 74 ++++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/crates/proxy/src/lib.rs b/crates/proxy/src/lib.rs index 6b8769ed8..937503d09 100644 --- a/crates/proxy/src/lib.rs +++ b/crates/proxy/src/lib.rs @@ -24,6 +24,7 @@ use tokio::io::copy_bidirectional; type HyperRequest = hyper::Request; type HyperResponse = hyper::Response; +/// Encode url path. This can be used when build your custom url path getter. #[inline] pub(crate) fn encode_url_path(path: &str) -> String { path.split('/') @@ -78,6 +79,23 @@ where } } +/// Url part getter. You can use this to get the proxied url path or query. +pub type UrlPartGetter = Box Option + Send + Sync + 'static>; + +/// Default url path getter. This getter will get the url path from request wildcard param, like `<*rest>`, `<**rest>`. +pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option { + let param = req.params().iter().find(|(key, _)| key.starts_with('*')); + if let Some((_, rest)) = param { + Some(encode_url_path(rest)) + } else { + None + } +} +/// Default url query getter. This getter just return the query string from request uri. +pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option { + req.uri().query().map(Into::into) +} + /// Proxy #[non_exhaustive] @@ -86,6 +104,10 @@ pub struct Proxy { pub upstreams: U, /// [`Client`] for proxy. pub client: Client, + /// Url path getter. + pub url_path_getter: UrlPartGetter, + /// Url query getter. + pub url_query_getter: UrlPartGetter, } impl Proxy @@ -98,11 +120,38 @@ where Proxy { upstreams, client: Client::new(), + url_path_getter: Box::new(default_url_path_getter), + url_query_getter: Box::new(default_url_query_getter), } } /// Create new `Proxy` with upstreams list and [`Client`]. pub fn with_client(upstreams: U, client: Client) -> Self { - Proxy { upstreams, client } + Proxy { + upstreams, + client, + url_path_getter: Box::new(default_url_path_getter), + url_query_getter: Box::new(default_url_query_getter), + } + } + + /// Set url path getter. + #[inline] + pub fn url_path_getter(mut self, url_path_getter: G) -> Self + where + G: Fn(&Request, &Depot) -> Option + Send + Sync + 'static, + { + self.url_path_getter = Box::new(url_path_getter); + self + } + + /// Set url query getter. + #[inline] + pub fn url_query_getter(mut self, url_query_getter: G) -> Self + where + G: Fn(&Request, &Depot) -> Option + Send + Sync + 'static, + { + self.url_query_getter = Box::new(url_query_getter); + self } /// Get upstreams list. @@ -128,23 +177,24 @@ where } #[inline] - fn build_proxied_request(&self, req: &mut Request) -> Result { + fn build_proxied_request(&self, req: &mut Request, depot: &Depot) -> Result { let upstream = self.upstreams.elect().map_err(Error::other)?; if upstream.is_empty() { tracing::error!("upstreams is empty"); return Err(Error::other("upstreams is empty")); } - let param = req.params().iter().find(|(key, _)| key.starts_with('*')); - let mut rest = if let Some((_, rest)) = param { - encode_url_path(rest) + let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default()); + let query = (self.url_query_getter)(req, depot); + let rest = if let Some(query) = query { + if query.starts_with('?') { + format!("{}{}", path, query) + } else { + format!("{}?{}", path, query) + } } else { - "".into() + path }; - if let Some(query) = req.uri().query() { - rest = format!("{}?{}", rest, query); - } - let forward_url = if upstream.ends_with('/') && rest.starts_with('/') { format!("{}{}", upstream.trim_end_matches('/'), rest) } else if upstream.ends_with('/') || rest.starts_with('/') { @@ -252,8 +302,8 @@ where U::Error: Into, { #[inline] - async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { - match self.build_proxied_request(req) { + async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { + match self.build_proxied_request(req, depot) { Ok(proxied_request) => { match self .call_proxied_server(proxied_request, req.extensions_mut().remove())