Skip to content

Commit

Permalink
Validate bytes based on ser_json_bytes (#1308)
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-newman authored Aug 1, 2024
1 parent 40b8a94 commit 57e6991
Show file tree
Hide file tree
Showing 19 changed files with 223 additions and 22 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ num-bigint = "0.4.6"
python3-dll-a = "0.2.10"
uuid = "1.9.1"
jiter = { version = "0.5", features = ["python"] }
hex = "0.4.3"

[lib]
name = "_pydantic_core"
Expand Down
8 changes: 4 additions & 4 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def to_json(
exclude_none: bool = False,
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
Expand All @@ -373,7 +373,7 @@ def to_json(
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
Expand Down Expand Up @@ -427,7 +427,7 @@ def to_jsonable_python(
exclude_none: bool = False,
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
Expand All @@ -448,7 +448,7 @@ def to_jsonable_python(
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
Expand Down
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
val_json_bytes: The validation option for `bytes` values, complementing ser_json_bytes. Default is 'utf8'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -107,6 +108,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
Expand Down Expand Up @@ -3904,6 +3906,7 @@ def definition_reference_schema(
'bytes_type',
'bytes_too_short',
'bytes_too_long',
'bytes_invalid_encoding',
'value_error',
'assertion_error',
'literal_error',
Expand Down
10 changes: 10 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ error_types! {
BytesTooLong {
max_length: {ctx_type: usize, ctx_fn: field_from_context},
},
BytesInvalidEncoding {
encoding: {ctx_type: String, ctx_fn: field_from_context},
encoding_error: {ctx_type: String, ctx_fn: field_from_context},
},
// ---------------------
// python errors from functions
ValueError {
Expand Down Expand Up @@ -515,6 +519,7 @@ impl ErrorType {
Self::BytesType {..} => "Input should be a valid bytes",
Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}",
Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}",
Self::BytesInvalidEncoding { .. } => "Data should be valid {encoding}: {encoding_error}",
Self::ValueError {..} => "Value error, {error}",
Self::AssertionError {..} => "Assertion failed, {error}",
Self::CustomError {..} => "", // custom errors are handled separately
Expand Down Expand Up @@ -664,6 +669,11 @@ impl ErrorType {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::BytesInvalidEncoding {
encoding,
encoding_error,
..
} => render!(tmpl, encoding, encoding_error),
Self::ValueError { error, .. } => {
let error = &error
.as_ref()
Expand Down
3 changes: 2 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pyo3::{intern, prelude::*};
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;
use crate::validators::ValBytesMode;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
Expand Down Expand Up @@ -71,7 +72,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;

fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;
fn validate_bytes<'a>(&'a self, strict: bool, mode: ValBytesMode) -> ValMatch<EitherBytes<'a, 'py>>;

fn validate_bool(&self, strict: bool) -> ValMatch<bool>;

Expand Down
23 changes: 19 additions & 4 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use strum::EnumMessage;
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -106,9 +107,16 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())),
JsonValue::Str(s) => match mode.deserialize_string(s) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Err(e) => Err(ValError::new(e, self)),
},
_ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}
Expand Down Expand Up @@ -342,8 +350,15 @@ impl<'py> Input<'py> for str {
Ok(ValidationMatch::strict(self.into()))
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
Ok(ValidationMatch::strict(self.as_bytes().into()))
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match mode.deserialize_string(self) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Err(e) => Err(ValError::new(e, self)),
}
}

fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {
Expand Down
12 changes: 10 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
use crate::tools::{extract_i64, safe_repr};
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::ValBytesMode;
use crate::ArgsKwargs;

use super::datetime::{
Expand Down Expand Up @@ -174,7 +175,11 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::StringType, self))
}

fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
if let Ok(py_bytes) = self.downcast_exact::<PyBytes>() {
return Ok(ValidationMatch::exact(py_bytes.into()));
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
Expand All @@ -185,7 +190,10 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
if !strict {
return if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_string_str(py_str)?;
Ok(str.as_bytes().into())
match mode.deserialize_string(str) {
Ok(b) => Ok(b),
Err(e) => Err(ValError::new(e, self)),
}
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
Ok(py_byte_array.to_vec().into())
} else {
Expand Down
12 changes: 10 additions & 2 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
Expand Down Expand Up @@ -105,9 +106,16 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())),
Self::String(s) => py_string_str(s).and_then(|b| match mode.deserialize_string(b) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Err(e) => Err(ValError::new(e, self)),
}),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::sync::OnceLock;
use jiter::{map_json_error, PartialMode, PythonParse, StringCacheMode};
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};
use serializers::BytesMode;
use validators::ValBytesMode;

// parse this first to get access to the contained macro
#[macro_use]
Expand Down Expand Up @@ -55,7 +57,7 @@ pub fn from_json<'py>(
allow_partial: bool,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false)
.validate_bytes(false, ValBytesMode { ser: BytesMode::Utf8 })
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
Expand Down
6 changes: 2 additions & 4 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub trait FromConfig {
macro_rules! serialization_mode {
($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => {
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum $name {
pub enum $name {
#[default]
$($variant,)*
}
Expand Down Expand Up @@ -183,9 +183,7 @@ impl BytesMode {
Err(e) => Err(Error::custom(e.to_string())),
},
Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)),
Self::Hex => {
serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")))
}
Self::Hex => serializer.serialize_str(hex::encode(bytes).as_str()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pyo3::{PyTraverseError, PyVisit};
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::py_gc::PyGcTraverse;

pub(crate) use config::BytesMode;
use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionState, WarningsMode};
Expand Down
11 changes: 9 additions & 2 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ use crate::input::Input;

use crate::tools::SchemaDict;

use super::config::ValBytesMode;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BytesValidator {
strict: bool,
bytes_mode: ValBytesMode,
}

impl BuildValidator for BytesValidator {
Expand All @@ -31,6 +33,7 @@ impl BuildValidator for BytesValidator {
} else {
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: ValBytesMode::from_config(config)?,
}
.into())
}
Expand All @@ -47,7 +50,7 @@ impl Validator for BytesValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
input
.validate_bytes(state.strict_or(self.strict))
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)
.map(|m| m.unpack(state).into_py(py))
}

Expand All @@ -59,6 +62,7 @@ impl Validator for BytesValidator {
#[derive(Debug, Clone)]
pub struct BytesConstrainedValidator {
strict: bool,
bytes_mode: ValBytesMode,
max_length: Option<usize>,
min_length: Option<usize>,
}
Expand All @@ -72,7 +76,9 @@ impl Validator for BytesConstrainedValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state);
let either_bytes = input
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)?
.unpack(state);
let len = either_bytes.len()?;

if let Some(min_length) = self.min_length {
Expand Down Expand Up @@ -110,6 +116,7 @@ impl BytesConstrainedValidator {
let py = schema.py();
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: ValBytesMode::from_config(config)?,
min_length: schema.get_as(intern!(py, "min_length"))?,
max_length: schema.get_as(intern!(py, "max_length"))?,
}
Expand Down
49 changes: 49 additions & 0 deletions src/validators/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::borrow::Cow;
use std::str::FromStr;

use base64::Engine;
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, prelude::*};

use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValBytesMode {
pub ser: BytesMode,
}

impl ValBytesMode {
pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), "val_json_bytes"))?;
let ser_mode = raw_mode.map_or_else(|| Ok(BytesMode::default()), |raw| BytesMode::from_str(&raw.to_cow()?))?;
Ok(Self { ser: ser_mode })
}

pub fn deserialize_string<'py>(self, s: &str) -> Result<EitherBytes<'_, 'py>, ErrorType> {
match self.ser {
BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
BytesMode::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) {
Ok(bytes) => Ok(EitherBytes::from(bytes)),
Err(err) => Err(ErrorType::BytesInvalidEncoding {
encoding: "base64".to_string(),
encoding_error: err.to_string(),
context: None,
}),
},
BytesMode::Hex => match hex::decode(s) {
Ok(vec) => Ok(EitherBytes::from(vec)),
Err(err) => Err(ErrorType::BytesInvalidEncoding {
encoding: "hex".to_string(),
encoding_error: err.to_string(),
context: None,
}),
},
}
}
}
Loading

0 comments on commit 57e6991

Please sign in to comment.