Skip to content

Commit

Permalink
Save and use router param without wildcard (#844)
Browse files Browse the repository at this point in the history
* Save and use router param without wildcard

* wip

* Format Rust code using rustfmt

* cargo clippy

* cargo fmt

* wip

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
chrislearn and github-actions[bot] authored Jul 29, 2024
1 parent 8184414 commit 286a34e
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 66 deletions.
12 changes: 6 additions & 6 deletions crates/core/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ pub use http::request::Parts;
use http::uri::{Scheme, Uri};
use http::Extensions;
use http_body_util::{BodyExt, Limited};
use indexmap::IndexMap;
use multimap::MultiMap;
use parking_lot::RwLock;
use serde::de::Deserialize;
Expand All @@ -25,6 +24,7 @@ use crate::fuse::TransProto;
use crate::http::body::ReqBody;
use crate::http::form::{FilePart, FormData};
use crate::http::{Mime, ParseError, Version};
use crate::routing::PathParams;
use crate::serde::{from_request, from_str_map, from_str_multi_map, from_str_multi_val, from_str_val};
use crate::Error;

Expand Down Expand Up @@ -61,7 +61,7 @@ pub struct Request {
#[cfg(feature = "cookie")]
pub(crate) cookies: CookieJar,

pub(crate) params: IndexMap<String, String>,
pub(crate) params: PathParams,

// accept: Option<Vec<Mime>>,
pub(crate) queries: OnceLock<MultiMap<String, String>>,
Expand Down Expand Up @@ -110,7 +110,7 @@ impl Request {
method: Method::default(),
#[cfg(feature = "cookie")]
cookies: CookieJar::default(),
params: IndexMap::new(),
params: PathParams::new(),
queries: OnceLock::new(),
form_data: tokio::sync::OnceCell::new(),
payload: tokio::sync::OnceCell::new(),
Expand Down Expand Up @@ -171,7 +171,7 @@ impl Request {
#[cfg(feature = "cookie")]
cookies,
// accept: None,
params: IndexMap::new(),
params: PathParams::new(),
form_data: tokio::sync::OnceCell::new(),
payload: tokio::sync::OnceCell::new(),
// multipart: OnceLock::new(),
Expand Down Expand Up @@ -567,12 +567,12 @@ impl Request {
}
/// Get params reference.
#[inline]
pub fn params(&self) -> &IndexMap<String, String> {
pub fn params(&self) -> &PathParams {
&self.params
}
/// Get params mutable reference.
#[inline]
pub fn params_mut(&mut self) -> &mut IndexMap<String, String> {
pub fn params_mut(&mut self) -> &mut PathParams {
&mut self.params
}

Expand Down
21 changes: 8 additions & 13 deletions crates/core/src/routing/filters/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ impl PathWisp for CharsWisp {
}
if chars.len() == max_width {
state.forward(max_width);
state.params.insert(self.name.clone(), chars.into_iter().collect());
state.params.insert(&self.name, chars.into_iter().collect());
return true;
}
}
if chars.len() >= self.min_width {
state.forward(chars.len());
state.params.insert(self.name.clone(), chars.into_iter().collect());
state.params.insert(&self.name, chars.into_iter().collect());
true
} else {
false
Expand All @@ -274,7 +274,7 @@ impl PathWisp for CharsWisp {
}
if chars.len() >= self.min_width {
state.forward(chars.len());
state.params.insert(self.name.clone(), chars.into_iter().collect());
state.params.insert(&self.name, chars.into_iter().collect());
true
} else {
false
Expand All @@ -298,7 +298,7 @@ impl CombWisp {
impl PathWisp for CombWisp {
#[inline]
fn detect<'a>(&self, state: &mut PathState) -> bool {
let mut offline = if let Some(part) = state.parts.get_mut(state.cursor.0) {
let mut offline = if let Some(part) = state.parts.get(state.cursor.0) {
part.clone()
} else {
return false;
Expand Down Expand Up @@ -403,7 +403,7 @@ impl PathWisp for NamedWisp {
}
if !rest.is_empty() || !self.0.starts_with("*+") {
let rest = rest.to_string();
state.params.insert(self.0.clone(), rest);
state.params.insert(&self.0, rest);
state.cursor.0 = state.parts.len();
true
} else {
Expand All @@ -416,7 +416,7 @@ impl PathWisp for NamedWisp {
}
let picked = picked.expect("picked should not be `None`").to_owned();
state.forward(picked.len());
state.params.insert(self.0.clone(), picked);
state.params.insert(&self.0, picked);
true
}
}
Expand Down Expand Up @@ -456,7 +456,7 @@ impl PathWisp for RegexWisp {
if let Some(cap) = cap {
let cap = cap.as_str().to_owned();
state.forward(cap.len());
state.params.insert(self.name.clone(), cap);
state.params.insert(&self.name, cap);
true
} else {
false
Expand All @@ -472,7 +472,7 @@ impl PathWisp for RegexWisp {
if let Some(cap) = cap {
let cap = cap.as_str().to_owned();
state.forward(cap.len());
state.params.insert(self.name.clone(), cap);
state.params.insert(&self.name, cap);
true
} else {
false
Expand Down Expand Up @@ -930,11 +930,6 @@ mod tests {
let segments = PathParser::new("/").parse().unwrap();
assert!(segments.is_empty());
}
#[test]
fn test_parse_rest_without_name() {
let segments = PathParser::new("/hello/<**>").parse().unwrap();
assert_eq!(format!("{:?}", segments), r#"[ConstWisp("hello"), NamedWisp("**")]"#);
}

#[test]
fn test_parse_single_const() {
Expand Down
56 changes: 53 additions & 3 deletions crates/core/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@
//!
//! #[handler]
//! fn serve_file(req: &mut Request) {
//! let rest_path = req.param::<i64>("**rest_path");
//! let rest_path = req.param::<i64>("rest_path");
//! }
//! ```
//!
Expand Down Expand Up @@ -375,6 +375,7 @@ mod router;
pub use router::Router;

use std::borrow::Cow;
use std::ops::Deref;
use std::sync::Arc;

use indexmap::IndexMap;
Expand All @@ -388,8 +389,57 @@ pub struct DetectMatched {
pub goal: Arc<dyn Handler>,
}

#[doc(hidden)]
pub type PathParams = IndexMap<String, String>;
/// The path parameters.
#[derive(Clone, Default, Debug, Eq, PartialEq)]
pub struct PathParams {
inner: IndexMap<String, String>,
greedy: bool,
}
impl Deref for PathParams {
type Target = IndexMap<String, String>;

fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl PathParams {
/// Create new `PathParams`.
pub fn new() -> Self {
PathParams::default()
}
/// If there is a wildcard param, it's value is `true`.
pub fn greedy(&self) -> bool {
self.greedy
}
/// Get the last param starts with '*', for example: <**rest>, <*?rest>.
pub fn tail(&self) -> Option<&str> {
if self.greedy {
self.inner.last().map(|(_, v)| &**v)
} else {
None
}
}

/// Insert new param.
pub fn insert(&mut self, name: &str, value: String) {
#[cfg(debug_assertions)]
{
if self.greedy {
panic!("only one wildcard param is allowed and it must be the last one.");
}
}
if name.starts_with("*+") || name.starts_with("*?") || name.starts_with("**") {
self.inner.insert(name[2..].to_owned(), value);
self.greedy = true;
} else if let Some(name) = name.strip_prefix('*') {
self.inner.insert(name.to_owned(), value);
self.greedy = true;
} else {
self.inner.insert(name.to_owned(), value);
}
}
}

#[doc(hidden)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PathState {
Expand Down
10 changes: 5 additions & 5 deletions crates/core/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,11 @@ impl fmt::Debug for Router {
} else {
format!("{prefix}{SYMBOL_TEE}{SYMBOL_RIGHT}{SYMBOL_RIGHT}")
};
let hd = if let Some(goal) = &router.goal {
format!(" -> {}", goal.type_name())
} else {
"".into()
};
let hd = router
.goal
.as_ref()
.map(|goal| format!(" -> {}", goal.type_name()))
.unwrap_or_default();
if !others.is_empty() {
writeln!(f, "{cp}{path}[{}]{hd}", others.join(","))?;
} else {
Expand Down
24 changes: 12 additions & 12 deletions crates/core/src/serde/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,16 @@ impl<'de> RequestDeserializer<'de> {
return false;
};

let field_name: Cow<'_, str> = if let Some(rename) = field.rename {
Cow::from(rename)
let field_name = if let Some(rename) = field.rename {
rename
} else if let Some(serde_rename) = field.serde_rename {
Cow::from(serde_rename)
serde_rename
} else if let Some(rename_all) = self.metadata.rename_all {
rename_all.apply_to_field(field.decl_name).into()
&*rename_all.apply_to_field(field.decl_name)
} else if let Some(serde_rename_all) = self.metadata.serde_rename_all {
serde_rename_all.apply_to_field(field.decl_name).into()
&*serde_rename_all.apply_to_field(field.decl_name)
} else {
field.decl_name.into()
field.decl_name
};

for source in sources {
Expand All @@ -237,7 +237,7 @@ impl<'de> RequestDeserializer<'de> {
}
}
SourceFrom::Query => {
let mut value = self.queries.get_vec(field_name.as_ref());
let mut value = self.queries.get_vec(field_name);
if value.is_none() {
for alias in &field.aliases {
value = self.queries.get_vec(*alias);
Expand All @@ -254,8 +254,8 @@ impl<'de> RequestDeserializer<'de> {
}
SourceFrom::Header => {
let mut value = None;
if self.headers.contains_key(field_name.as_ref()) {
value = Some(self.headers.get_all(field_name.as_ref()))
if self.headers.contains_key(field_name) {
value = Some(self.headers.get_all(field_name))
} else {
for alias in &field.aliases {
if self.headers.contains_key(*alias) {
Expand Down Expand Up @@ -301,7 +301,7 @@ impl<'de> RequestDeserializer<'de> {
if let Some(payload) = &self.payload {
match payload {
Payload::FormData(form_data) => {
let mut value = form_data.fields.get(field_name.as_ref());
let mut value = form_data.fields.get(field_name);
if value.is_none() {
for alias in &field.aliases {
value = form_data.fields.get(*alias);
Expand All @@ -318,7 +318,7 @@ impl<'de> RequestDeserializer<'de> {
return false;
}
Payload::JsonMap(ref map) => {
let mut value = map.get(field_name.as_ref());
let mut value = map.get(field_name);
if value.is_none() {
for alias in &field.aliases {
value = map.get(alias);
Expand Down Expand Up @@ -346,7 +346,7 @@ impl<'de> RequestDeserializer<'de> {
}
SourceParser::MultiMap => {
if let Some(Payload::FormData(form_data)) = self.payload {
let mut value = form_data.fields.get_vec(field_name.as_ref());
let mut value = form_data.fields.get_vec(field_name);
if value.is_none() {
for alias in &field.aliases {
value = form_data.fields.get_vec(*alias);
Expand Down
2 changes: 1 addition & 1 deletion crates/oapi/src/extract/parameter/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ mod tests {
let req = TestClient::get("http://127.0.0.1:5801").build_hyper();
let schema = req.uri().scheme().cloned().unwrap();
let mut req = Request::from_hyper(req, schema);
req.params_mut().insert("param".to_string(), "param".to_string());
req.params_mut().insert("param", "param".to_string());
let result = PathParam::<String>::extract_with_arg(&mut req, "param").await;
assert_eq!(result.unwrap().0, "param");
}
Expand Down
10 changes: 7 additions & 3 deletions crates/oapi/src/swagger_ui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,16 @@ pub(crate) fn redirect_to_dir_url(req_uri: &Uri, res: &mut Response) {
#[async_trait]
impl Handler for SwaggerUi {
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) {
let path = req.params().get("**").map(|s| &**s).unwrap_or_default();
// Redirect to dir url if path is empty and not end with '/'
if path.is_empty() && !req.uri().path().ends_with('/') {
// Redirect to dir url if path is not end with '/'
if !req.uri().path().ends_with('/') {
redirect_to_dir_url(req.uri(), res);
return;
}
let Some(path) = req.params().tail() else {
res.render(StatusError::not_found().detail("The router params is incorrect. The params should ending with a wildcard."));
return;
};

let keywords = self
.keywords
.as_ref()
Expand Down
12 changes: 5 additions & 7 deletions crates/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,12 @@ where
/// Url part getter. You can use this to get the proxied url path or query.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Default url path getter. This getter will get the url path from request wildcard param, like `<**rest>`, `<*+rest>`.
/// Default url path getter.
///
/// This getter will get the last param as the rest url path from request.
/// In most case you should use wildcard param, like `<**rest>`, `<*+rest>`.
pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
if let Some((_, rest)) = param {
Some(encode_url_path(rest))
} else {
None
}
req.params().tail().map(encode_url_path)
}
/// Default url query getter. This getter just return the query string from request uri.
pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
Expand Down
9 changes: 4 additions & 5 deletions crates/serve-static/src/dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,13 @@ impl DirInfo {
#[async_trait]
impl Handler for StaticDir {
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) {
let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
let req_path = req.uri().path();
let rel_path = if let Some((_, value)) = param {
value.clone()
let rel_path = if let Some(rest) = req.params().tail() {
rest
} else {
decode_url_path_safely(req_path)
&*decode_url_path_safely(req_path)
};
let rel_path = format_url_path_safely(&rel_path);
let rel_path = format_url_path_safely(rel_path);
let mut files: HashMap<String, Metadata> = HashMap::new();
let mut dirs: HashMap<String, Metadata> = HashMap::new();
let is_dot_file = Path::new(&rel_path)
Expand Down
7 changes: 3 additions & 4 deletions crates/serve-static/src/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ where
T: RustEmbed + Send + Sync + 'static,
{
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) {
let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
let req_path = if let Some((_, value)) = param {
value.clone()
let req_path = if let Some(rest) = req.params().tail() {
rest
} else {
decode_url_path_safely(req.uri().path())
&*decode_url_path_safely(req.uri().path())
};
let req_path = format_url_path_safely(&req_path);
let mut key_path = Cow::Borrowed(&*req_path);
Expand Down
Loading

0 comments on commit 286a34e

Please sign in to comment.