diff --git a/crates/core/src/extract/metadata.rs b/crates/core/src/extract/metadata.rs index eeac85786..d15bcc091 100644 --- a/crates/core/src/extract/metadata.rs +++ b/crates/core/src/extract/metadata.rs @@ -49,6 +49,8 @@ pub enum SourceParser { MultiMap, /// Json parser. Json, + /// Url or Header parser. + Flat, /// Smart parser. Smart, } diff --git a/crates/core/src/serde/cow_value.rs b/crates/core/src/serde/cow_value.rs new file mode 100644 index 000000000..fc133d474 --- /dev/null +++ b/crates/core/src/serde/cow_value.rs @@ -0,0 +1,111 @@ +use std::borrow::Cow; + +use serde::de::value::Error as ValError; +use serde::de::{Deserializer, Error as DeError, IntoDeserializer, Visitor}; +use serde::forward_to_deserialize_any; + +use super::ValueEnumAccess; + +macro_rules! forward_cow_parsed_value { + ($($ty:ident => $method:ident,)*) => { + $( + fn $method(self, visitor: V) -> Result + where V: Visitor<'de> + { + match self.0.parse::<$ty>() { + Ok(val) => val.into_deserializer().$method(visitor), + Err(e) => Err(DeError::custom(e)) + } + } + )* + } +} + +#[derive(Debug)] +pub(super) struct CowValue<'de>(pub(super) Cow<'de, str>); +impl<'de> IntoDeserializer<'de> for CowValue<'de> { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de> Deserializer<'de> for CowValue<'de> { + type Error = ValError; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.0 { + Cow::Borrowed(value) => visitor.visit_borrowed_str(value), + Cow::Owned(value) => visitor.visit_string(value), + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + #[inline] + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(ValueEnumAccess(self.0)) + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + forward_to_deserialize_any! { + char + str + string + unit + bytes + byte_buf + unit_struct + tuple_struct + struct + identifier + tuple + ignored_any + seq + map + } + + forward_cow_parsed_value! { + bool => deserialize_bool, + u8 => deserialize_u8, + u16 => deserialize_u16, + u32 => deserialize_u32, + u64 => deserialize_u64, + i8 => deserialize_i8, + i16 => deserialize_i16, + i32 => deserialize_i32, + i64 => deserialize_i64, + f32 => deserialize_f32, + f64 => deserialize_f64, + } +} diff --git a/crates/core/src/serde/flat_value.rs b/crates/core/src/serde/flat_value.rs new file mode 100644 index 000000000..489c19340 --- /dev/null +++ b/crates/core/src/serde/flat_value.rs @@ -0,0 +1,242 @@ +use std::borrow::Cow; + +use serde::de::value::{Error as ValError, SeqDeserializer}; +use serde::de::{Deserializer, Error as DeError, IntoDeserializer, Visitor}; +use serde::forward_to_deserialize_any; + +use super::{CowValue, ValueEnumAccess}; + +macro_rules! forward_url_query_parsed_value { + ($($ty:ident => $method:ident,)*) => { + $( + fn $method(self, visitor: V) -> Result + where V: Visitor<'de> + { + if let Some(item) = self.0.into_iter().next() { + match item.0.parse::<$ty>() { + Ok(val) => val.into_deserializer().$method(visitor), + Err(e) => Err(DeError::custom(e)) + } + } else { + Err(DeError::custom("expected vec not empty")) + } + } + )* + } +} + +pub(super) struct FlatValue<'de>(pub(super) Vec>); +impl<'de> IntoDeserializer<'de> for FlatValue<'de> { + type Deserializer = Self; + + #[inline] + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de> Deserializer<'de> for FlatValue<'de> { + type Error = ValError; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if let Some(item) = self.0.into_iter().next() { + item.deserialize_any(visitor) + } else { + Err(DeError::custom("expected url query not empty")) + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + #[inline] + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if let Some(item) = self.0.into_iter().next() { + visitor.visit_enum(ValueEnumAccess(item.0)) + } else { + Err(DeError::custom("expected vec not empty")) + } + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + #[inline] + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + #[inline] + fn deserialize_seq(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let mut items = std::mem::take(&mut self.0); + let single_mode = if items.len() == 1 { + if let Some(item) = items.get(0) { + item.0.starts_with('[') && item.0.ends_with(']') + } else { + false + } + } else { + false + }; + if single_mode { + let parser = FlatParser::new(items.remove(0).0); + visitor.visit_seq(SeqDeserializer::new(parser.into_iter())) + } else { + visitor.visit_seq(SeqDeserializer::new(items.into_iter())) + } + } + + forward_to_deserialize_any! { + char + str + string + unit + bytes + byte_buf + unit_struct + struct + identifier + ignored_any + map + } + + forward_url_query_parsed_value! { + bool => deserialize_bool, + u8 => deserialize_u8, + u16 => deserialize_u16, + u32 => deserialize_u32, + u64 => deserialize_u64, + i8 => deserialize_i8, + i16 => deserialize_i16, + i32 => deserialize_i32, + i64 => deserialize_i64, + f32 => deserialize_f32, + f64 => deserialize_f64, + } +} + +struct FlatParser<'de> { + input: Cow<'de, str>, + start: usize, +} +impl<'de> FlatParser<'de> { + fn new(input: Cow<'de, str>) -> Self { + Self { input, start: 1 } + } +} +impl<'de> Iterator for FlatParser<'de> { + type Item = CowValue<'de>; + + fn next(&mut self) -> Option { + let mut quote = None; + let mut in_escape = false; + let mut end = self.start; + let mut in_next = false; + for c in self.input[self.start..].chars() { + if in_escape { + in_escape = false; + continue; + } + match c { + '\\' => { + in_escape = true; + in_next = true; + } + ' ' => { + if quote.is_none() { + self.start += 1; + } + } + '"' | '\'' => { + in_next = true; + if quote == Some(c) { + let item = Cow::Owned(self.input[self.start..end].to_string()); + self.start = end + 2; + return Some(CowValue(item)); + } else { + quote = Some(c); + self.start += 1; + } + } + ',' | ']' => { + if quote.is_none() && in_next { + let item = Cow::Owned(self.input[self.start..end].to_string()); + self.start = end + 1; + return Some(CowValue(item)); + } + } + _ => { + in_next = true; + } + } + end += 1; + } + None + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_flat_parser_1() { + let parser = super::FlatParser::new("[1,2, 3]".into()); + let mut iter = parser.into_iter(); + assert_eq!(iter.next().unwrap().0, "1"); + assert_eq!(iter.next().unwrap().0, "2"); + assert_eq!(iter.next().unwrap().0, "3"); + assert!(iter.next().is_none()); + } + #[test] + fn test_flat_parser_2() { + let parser = super::FlatParser::new(r#"['3', '2',"11","1,2"]"#.into()); + let mut iter = parser.into_iter(); + assert_eq!(iter.next().unwrap().0, "3"); + assert_eq!(iter.next().unwrap().0, "2"); + assert_eq!(iter.next().unwrap().0, "11"); + assert_eq!(iter.next().unwrap().0, "1,2"); + assert!(iter.next().is_none()); + } +} diff --git a/crates/core/src/serde/mod.rs b/crates/core/src/serde/mod.rs index d1b89b4d7..e2205f1ca 100644 --- a/crates/core/src/serde/mod.rs +++ b/crates/core/src/serde/mod.rs @@ -3,13 +3,18 @@ use std::hash::Hash; pub use serde::de::value::{Error as ValError, MapDeserializer, SeqDeserializer}; use serde::de::{ - Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as DeError, IntoDeserializer, - VariantAccess, Visitor, + Deserialize, DeserializeSeed, EnumAccess, Error as DeError, IntoDeserializer, VariantAccess, + Visitor, }; -use serde::forward_to_deserialize_any; mod request; pub use request::from_request; +mod cow_value; +use cow_value::CowValue; +mod vec_value; +use vec_value::VecValue; +mod flat_value; +use flat_value::FlatValue; #[inline] pub fn from_str_map<'de, I, T, K, V>(input: I) -> Result @@ -62,40 +67,6 @@ where T::deserialize(CowValue(input.into())) } -macro_rules! forward_cow_parsed_value { - ($($ty:ident => $method:ident,)*) => { - $( - fn $method(self, visitor: V) -> Result - where V: Visitor<'de> - { - match self.0.parse::<$ty>() { - Ok(val) => val.into_deserializer().$method(visitor), - Err(e) => Err(DeError::custom(e)) - } - } - )* - } -} - -macro_rules! forward_vec_parsed_value { - ($($ty:ident => $method:ident,)*) => { - $( - fn $method(self, visitor: V) -> Result - where V: Visitor<'de> - { - if let Some(item) = self.0.into_iter().next() { - match item.0.parse::<$ty>() { - Ok(val) => val.into_deserializer().$method(visitor), - Err(e) => Err(DeError::custom(e)) - } - } else { - Err(DeError::custom("expected vec not empty")) - } - } - )* - } -} - struct ValueEnumAccess<'de>(Cow<'de, str>); impl<'de> EnumAccess<'de> for ValueEnumAccess<'de> { @@ -151,219 +122,6 @@ impl<'de> VariantAccess<'de> for UnitOnlyVariantAccess { } } -#[derive(Debug)] -struct CowValue<'de>(Cow<'de, str>); -impl<'de> IntoDeserializer<'de> for CowValue<'de> { - type Deserializer = Self; - - fn into_deserializer(self) -> Self::Deserializer { - self - } -} - -impl<'de> Deserializer<'de> for CowValue<'de> { - type Error = ValError; - - #[inline] - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - match self.0 { - Cow::Borrowed(value) => visitor.visit_borrowed_str(value), - Cow::Owned(value) => visitor.visit_string(value), - } - } - - #[inline] - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_some(self) - } - - #[inline] - fn deserialize_enum( - self, - _name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - visitor.visit_enum(ValueEnumAccess(self.0)) - } - - #[inline] - fn deserialize_newtype_struct( - self, - _name: &'static str, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - forward_to_deserialize_any! { - char - str - string - unit - bytes - byte_buf - unit_struct - tuple_struct - struct - identifier - tuple - ignored_any - seq - map - } - - forward_cow_parsed_value! { - bool => deserialize_bool, - u8 => deserialize_u8, - u16 => deserialize_u16, - u32 => deserialize_u32, - u64 => deserialize_u64, - i8 => deserialize_i8, - i16 => deserialize_i16, - i32 => deserialize_i32, - i64 => deserialize_i64, - f32 => deserialize_f32, - f64 => deserialize_f64, - } -} - -struct VecValue(I); -impl<'de, I> IntoDeserializer<'de> for VecValue -where - I: Iterator>, -{ - type Deserializer = Self; - - #[inline] - fn into_deserializer(self) -> Self::Deserializer { - self - } -} - -impl<'de, I> Deserializer<'de> for VecValue -where - I: IntoIterator>, -{ - type Error = ValError; - - #[inline] - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - if let Some(item) = self.0.into_iter().next() { - item.deserialize_any(visitor) - } else { - Err(DeError::custom("expected vec not empty")) - } - } - - #[inline] - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_some(self) - } - - #[inline] - fn deserialize_enum( - self, - _name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - if let Some(item) = self.0.into_iter().next() { - visitor.visit_enum(ValueEnumAccess(item.0)) - } else { - Err(DeError::custom("expected vec not empty")) - } - } - - #[inline] - fn deserialize_newtype_struct( - self, - _name: &'static str, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - - #[inline] - fn deserialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(visitor) - } - #[inline] - fn deserialize_tuple(self, _len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(visitor) - } - #[inline] - fn deserialize_seq(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_seq(SeqDeserializer::new(self.0.into_iter())) - } - - forward_to_deserialize_any! { - char - str - string - unit - bytes - byte_buf - unit_struct - struct - identifier - ignored_any - map - } - - forward_vec_parsed_value! { - bool => deserialize_bool, - u8 => deserialize_u8, - u16 => deserialize_u16, - u32 => deserialize_u32, - u64 => deserialize_u64, - i8 => deserialize_i8, - i16 => deserialize_i16, - i32 => deserialize_i32, - i64 => deserialize_i64, - f32 => deserialize_f32, - f64 => deserialize_f64, - } -} - #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/crates/core/src/serde/request.rs b/crates/core/src/serde/request.rs index fe2c84798..c350a0d62 100644 --- a/crates/core/src/serde/request.rs +++ b/crates/core/src/serde/request.rs @@ -15,7 +15,7 @@ use crate::http::header::HeaderMap; use crate::http::ParseError; use crate::Request; -use super::{CowValue, VecValue}; +use super::{CowValue, FlatValue, VecValue}; pub async fn from_request<'de, T>( req: &'de mut Request, @@ -141,6 +141,8 @@ impl<'de> RequestDeserializer<'de> { } else { parser = SourceParser::MultiMap; } + } else if source.from == SourceFrom::Query || source.from == SourceFrom::Header { + parser = SourceParser::Flat; } else { parser = SourceParser::MultiMap; } @@ -158,7 +160,7 @@ impl<'de> RequestDeserializer<'de> { .fields .get(self.field_index as usize) .expect("field must exist."); - let metadata = field.metadata.expect("Field's metadata must exist"); + let metadata = field.metadata.expect("field's metadata must exist"); seed.deserialize(RequestDeserializer { params: self.params, queries: self.queries, @@ -192,7 +194,11 @@ impl<'de> RequestDeserializer<'de> { } else if let Some(value) = self.field_str_value.take() { seed.deserialize(CowValue(value.into())) } else if let Some(value) = self.field_vec_value.take() { - seed.deserialize(VecValue(value.into_iter())) + if source.from == SourceFrom::Query || source.from == SourceFrom::Header { + seed.deserialize(FlatValue(value)) + } else { + seed.deserialize(VecValue(value.into_iter())) + } } else { Err(ValError::custom("parse value error")) } @@ -668,6 +674,7 @@ mod tests { let data: RequestData = req.extract().await.unwrap(); assert_eq!(data, RequestData { p2: "921", b: true }); } + #[tokio::test] async fn test_de_request_with_json_str() { #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] @@ -691,6 +698,7 @@ mod tests { } ); } + #[tokio::test] async fn test_de_request_with_form_json_str() { #[derive(Deserialize, Eq, PartialEq, Debug)] @@ -721,6 +729,7 @@ mod tests { } ); } + #[tokio::test] async fn test_de_request_with_extract_rename_all() { #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] @@ -743,6 +752,7 @@ mod tests { } ); } + #[tokio::test] async fn test_de_request_with_serde_rename_all() { #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] @@ -766,6 +776,7 @@ mod tests { } ); } + #[tokio::test] async fn test_de_request_with_both_rename_all() { #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] @@ -789,4 +800,56 @@ mod tests { } ); } + + #[tokio::test] + async fn test_de_request_url_array() { + #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] + #[salvo(extract(default_source(from = "query")))] + struct RequestData { + ids: Vec, + } + let mut req = + TestClient::get("http://127.0.0.1:5800/test/1234/param2v?ids=[3,2,11]").build(); + let data: RequestData = req.extract().await.unwrap(); + assert_eq!( + data, + RequestData { + ids: vec!["3".to_string(), "2".to_string(), "11".to_string()] + } + ); + let mut req = TestClient::get( + r#"http://127.0.0.1:5800/test/1234/param2v?ids=['3', '2',"11","1,2"]"#, + ) + .build(); + let data: RequestData = req.extract().await.unwrap(); + assert_eq!( + data, + RequestData { + ids: vec![ + "3".to_string(), + "2".to_string(), + "11".to_string(), + "1,2".to_string() + ] + } + ); + } + + #[tokio::test] + async fn test_de_request_url_array2() { + #[derive(Deserialize, Extractible, Eq, PartialEq, Debug)] + #[salvo(extract(default_source(from = "query")))] + struct RequestData { + ids: Vec, + } + let mut req = + TestClient::get("http://127.0.0.1:5800/test/1234/param2v?ids=[3,2,11]").build(); + let data: RequestData = req.extract().await.unwrap(); + assert_eq!( + data, + RequestData { + ids: vec![3, 2, 11] + } + ); + } } diff --git a/crates/core/src/serde/vec_value.rs b/crates/core/src/serde/vec_value.rs new file mode 100644 index 000000000..46baa01ea --- /dev/null +++ b/crates/core/src/serde/vec_value.rs @@ -0,0 +1,148 @@ +use serde::de::value::{Error as ValError, SeqDeserializer}; +use serde::de::{Deserializer, Error as DeError, IntoDeserializer, Visitor}; +use serde::forward_to_deserialize_any; + +use super::{CowValue, ValueEnumAccess}; + +macro_rules! forward_vec_parsed_value { + ($($ty:ident => $method:ident,)*) => { + $( + fn $method(self, visitor: V) -> Result + where V: Visitor<'de> + { + if let Some(item) = self.0.into_iter().next() { + match item.0.parse::<$ty>() { + Ok(val) => val.into_deserializer().$method(visitor), + Err(e) => Err(DeError::custom(e)) + } + } else { + Err(DeError::custom("expected vec not empty")) + } + } + )* + } +} + +pub(super) struct VecValue(pub(super) I); +impl<'de, I> IntoDeserializer<'de> for VecValue +where + I: Iterator>, +{ + type Deserializer = Self; + + #[inline] + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de, I> Deserializer<'de> for VecValue +where + I: IntoIterator>, +{ + type Error = ValError; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if let Some(item) = self.0.into_iter().next() { + item.deserialize_any(visitor) + } else { + Err(DeError::custom("expected vec not empty")) + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + #[inline] + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if let Some(item) = self.0.into_iter().next() { + visitor.visit_enum(ValueEnumAccess(item.0)) + } else { + Err(DeError::custom("expected vec not empty")) + } + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + #[inline] + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + #[inline] + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + #[inline] + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer::new(self.0.into_iter())) + } + + forward_to_deserialize_any! { + char + str + string + unit + bytes + byte_buf + unit_struct + struct + identifier + ignored_any + map + } + + forward_vec_parsed_value! { + bool => deserialize_bool, + u8 => deserialize_u8, + u16 => deserialize_u16, + u32 => deserialize_u32, + u64 => deserialize_u64, + i8 => deserialize_i8, + i16 => deserialize_i16, + i32 => deserialize_i32, + i64 => deserialize_i64, + f32 => deserialize_f32, + f64 => deserialize_f64, + } +}