Skip to content

Commit

Permalink
feat: Allow custom url path and query getter (#364)
Browse files Browse the repository at this point in the history
* feat: Allow custom url rest getter

* public encode_url_path

* Format Rust code using rustfmt

* split to path and query getter

* use encode_url_path

* encode_url_path

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
chrislearn and github-actions[bot] authored Aug 11, 2023
1 parent 9def4b2 commit 1ba59c0
Showing 1 changed file with 62 additions and 12 deletions.
74 changes: 62 additions & 12 deletions crates/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use tokio::io::copy_bidirectional;
type HyperRequest = hyper::Request<ReqBody>;
type HyperResponse = hyper::Response<ResBody>;

/// Encode url path. This can be used when build your custom url path getter.
#[inline]
pub(crate) fn encode_url_path(path: &str) -> String {
path.split('/')
Expand Down Expand Up @@ -78,6 +79,23 @@ where
}
}

/// Url part getter. You can use this to get the proxied url path or query.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Default url path getter. This getter will get the url path from request wildcard param, like `<*rest>`, `<**rest>`.
pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
if let Some((_, rest)) = param {
Some(encode_url_path(rest))
} else {
None
}
}
/// Default url query getter. This getter just return the query string from request uri.
pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
req.uri().query().map(Into::into)
}

/// Proxy
#[non_exhaustive]

Expand All @@ -86,6 +104,10 @@ pub struct Proxy<U> {
pub upstreams: U,
/// [`Client`] for proxy.
pub client: Client,
/// Url path getter.
pub url_path_getter: UrlPartGetter,
/// Url query getter.
pub url_query_getter: UrlPartGetter,
}

impl<U> Proxy<U>
Expand All @@ -98,11 +120,38 @@ where
Proxy {
upstreams,
client: Client::new(),
url_path_getter: Box::new(default_url_path_getter),
url_query_getter: Box::new(default_url_query_getter),
}
}
/// Create new `Proxy` with upstreams list and [`Client`].
pub fn with_client(upstreams: U, client: Client) -> Self {
Proxy { upstreams, client }
Proxy {
upstreams,
client,
url_path_getter: Box::new(default_url_path_getter),
url_query_getter: Box::new(default_url_query_getter),
}
}

/// Set url path getter.
#[inline]
pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
where
G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
{
self.url_path_getter = Box::new(url_path_getter);
self
}

/// Set url query getter.
#[inline]
pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
where
G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
{
self.url_query_getter = Box::new(url_query_getter);
self
}

/// Get upstreams list.
Expand All @@ -128,23 +177,24 @@ where
}

#[inline]
fn build_proxied_request(&self, req: &mut Request) -> Result<HyperRequest, Error> {
fn build_proxied_request(&self, req: &mut Request, depot: &Depot) -> Result<HyperRequest, Error> {
let upstream = self.upstreams.elect().map_err(Error::other)?;
if upstream.is_empty() {
tracing::error!("upstreams is empty");
return Err(Error::other("upstreams is empty"));
}

let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
let mut rest = if let Some((_, rest)) = param {
encode_url_path(rest)
let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
let query = (self.url_query_getter)(req, depot);
let rest = if let Some(query) = query {
if query.starts_with('?') {
format!("{}{}", path, query)
} else {
format!("{}?{}", path, query)
}
} else {
"".into()
path
};
if let Some(query) = req.uri().query() {
rest = format!("{}?{}", rest, query);
}

let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
format!("{}{}", upstream.trim_end_matches('/'), rest)
} else if upstream.ends_with('/') || rest.starts_with('/') {
Expand Down Expand Up @@ -252,8 +302,8 @@ where
U::Error: Into<BoxedError>,
{
#[inline]
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
match self.build_proxied_request(req) {
async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
match self.build_proxied_request(req, depot) {
Ok(proxied_request) => {
match self
.call_proxied_server(proxied_request, req.extensions_mut().remove())
Expand Down

0 comments on commit 1ba59c0

Please sign in to comment.