diff --git a/generate_self_schema.py b/generate_self_schema.py index 3494af8a9..0b3b9feab 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -6,6 +6,7 @@ """ from __future__ import annotations as _annotations +import decimal import importlib.util import re from collections.abc import Callable @@ -23,6 +24,7 @@ UnionType = Union[TypingUnionType, TypesUnionType] except ImportError: + TypesUnionType = None UnionType = TypingUnionType @@ -46,8 +48,8 @@ def get_schema(obj) -> core_schema.CoreSchema: if isinstance(obj, str): return {'type': obj} - elif obj in (datetime, timedelta, date, time, bool, int, float, str): - return {'type': obj.__name__} + elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal): + return {'type': obj.__name__.lower()} elif is_typeddict(obj): return type_dict_schema(obj) elif obj == Any or obj == type: @@ -57,7 +59,7 @@ def get_schema(obj) -> core_schema.CoreSchema: origin = get_origin(obj) assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py' - if origin is Union: + if origin is Union or origin is TypesUnionType: return union_schema(obj) elif obj is Callable or origin is Callable: return {'type': 'callable'} @@ -79,6 +81,7 @@ def get_schema(obj) -> core_schema.CoreSchema: # can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators return {'type': 'any'} else: + print(origin) # debug(obj) raise TypeError(f'Unknown type: {obj!r}') diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 82ba5887d..dd2795a29 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -8,6 +8,7 @@ import sys from collections.abc import Mapping from datetime import date, datetime, time, timedelta +from decimal import Decimal from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union if sys.version_info < (3, 
class DecimalSchema(TypedDict, total=False):
    # Core schema describing validation of `decimal.Decimal` values.
    # Only `type` is required; every other key is an optional constraint.
    type: Required[Literal['decimal']]
    gt: Union[int, Decimal]  # exclusive lower bound
    ge: Union[int, Decimal]  # inclusive lower bound
    lt: Union[int, Decimal]  # exclusive upper bound
    le: Union[int, Decimal]  # inclusive upper bound
    max_digits: int  # maximum number of digits in total
    decimal_places: int  # maximum number of digits after the decimal point
    multiple_of: Union[int, Decimal]
    allow_inf_nan: bool  # whether 'NaN', '+inf', '-inf' are allowed. default: False
    # NOTE(review): the validator build in this change derives digit checking from
    # `max_digits` / `decimal_places` and does not appear to read this key - confirm
    # whether it is still needed. default: False
    check_digits: bool
    strict: bool
    ref: str
    metadata: Any
    serialization: SerSchema


def decimal_schema(
    *,
    gt: int | Decimal | None = None,
    ge: int | Decimal | None = None,
    lt: int | Decimal | None = None,
    le: int | Decimal | None = None,
    max_digits: int | None = None,
    decimal_places: int | None = None,
    multiple_of: int | Decimal | None = None,
    allow_inf_nan: bool | None = None,
    check_digits: bool | None = None,
    strict: bool | None = None,
    ref: str | None = None,
    metadata: Any = None,
    serialization: SerSchema | None = None,
) -> DecimalSchema:
    """
    Returns a schema that matches a decimal value, e.g.:

    ```py
    from decimal import Decimal
    from pydantic_core import SchemaValidator, core_schema

    schema = core_schema.decimal_schema(le=Decimal('0.8'), ge=Decimal('0.2'))
    v = SchemaValidator(schema)
    assert v.validate_python('0.5') == Decimal('0.5')
    ```

    Arguments left as `None` are omitted from the returned schema.

    Args:
        gt: The value must be strictly greater than this number
        ge: The value must be greater than or equal to this number
        lt: The value must be strictly less than this number
        le: The value must be less than or equal to this number
        max_digits: The maximum number of digits allowed in total
        decimal_places: The maximum number of decimal places allowed
        multiple_of: The value must be a multiple of this number
        allow_inf_nan: Whether to allow inf and nan values (treated as `False` when unset); building a
            validator fails if this is `True` together with `max_digits` or `decimal_places`
        check_digits: Passed through to the schema unchanged. NOTE(review): the validator build in this
            change enables digit checks whenever `max_digits` or `decimal_places` is set and does not
            appear to read this key - confirm before relying on it
        strict: Whether to validate in strict mode; in strict mode Python input must already be a
            `Decimal` (subclasses are upcast) and JSON input must be a string
        ref: Optional identifier for the schema
        metadata: Any other information to attach to the schema
        serialization: Custom serialization schema
    """
    return _dict_not_none(
        type='decimal',
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        max_digits=max_digits,
        decimal_places=decimal_places,
        multiple_of=multiple_of,
        allow_inf_nan=allow_inf_nan,
        check_digits=check_digits,
        strict=strict,
        ref=ref,
        metadata=metadata,
        serialization=serialization,
    )
'date', @@ -3907,6 +3976,11 @@ def definition_reference_schema( 'uuid_type', 'uuid_parsing', 'uuid_version', + 'decimal_type', + 'decimal_parsing', + 'decimal_max_digits', + 'decimal_max_places', + 'decimal_whole_digits', ] diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 0983108e8..b81d1174e 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -8,7 +8,7 @@ mod value_exception; pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; pub use self::location::LocItem; -pub use self::types::{list_all_errors, ErrorMode, ErrorType}; +pub use self::types::{list_all_errors, ErrorMode, ErrorType, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/errors/types.rs b/src/errors/types.rs index 67a4da39b..a0f696141 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -108,8 +108,7 @@ pub enum ErrorType { // None errors NoneRequired, // --------------------- - // generic comparison errors - used for all inequality comparisons except int and float which have their - // own type, bounds arguments are Strings so they can be created from any type + // generic comparison errors GreaterThan { gt: Number, }, @@ -316,6 +315,18 @@ pub enum ErrorType { UuidVersion { expected_version: usize, }, + // Decimal errors + DecimalType, + DecimalParsing, + DecimalMaxDigits { + max_digits: u64, + }, + DecimalMaxPlaces { + decimal_places: u64, + }, + DecimalWholeDigits { + max_whole_digits: u64, + }, } macro_rules! render { @@ -463,6 +474,9 @@ impl ErrorType { Self::UuidVersion { .. } => { extract_context!(UuidVersion, ctx, expected_version: usize) } + Self::DecimalMaxDigits { .. } => extract_context!(DecimalMaxDigits, ctx, max_digits: u64), + Self::DecimalMaxPlaces { .. } => extract_context!(DecimalMaxPlaces, ctx, decimal_places: u64), + Self::DecimalWholeDigits { .. 
} => extract_context!(DecimalWholeDigits, ctx, max_whole_digits: u64), _ => { if ctx.is_some() { py_err!(PyTypeError; "'{}' errors do not require context", value) @@ -569,7 +583,12 @@ impl ErrorType { Self::UrlScheme {..} => "URL scheme should be {expected_schemes}", Self::UuidType => "UUID input should be a string, bytes or UUID object", Self::UuidParsing { .. } => "Input should be a valid UUID, {error}", - Self::UuidVersion { .. } => "UUID version {expected_version} expected" + Self::UuidVersion { .. } => "UUID version {expected_version} expected", + Self::DecimalType => "Decimal input should be an integer, float, string or Decimal object", + Self::DecimalParsing => "Input should be a valid decimal", + Self::DecimalMaxDigits { .. } => "Decimal input should have no more than {max_digits} digits in total", + Self::DecimalMaxPlaces { .. } => "Decimal input should have no more than {decimal_places} decimal places", + Self::DecimalWholeDigits { .. } => "Decimal input should have no more than {max_whole_digits} digits before the decimal point", } } @@ -692,6 +711,9 @@ impl ErrorType { Self::UrlScheme { expected_schemes } => render!(tmpl, expected_schemes), Self::UuidParsing { error } => render!(tmpl, error), Self::UuidVersion { expected_version } => to_string_render!(tmpl, expected_version), + Self::DecimalMaxDigits { max_digits } => to_string_render!(tmpl, max_digits), + Self::DecimalMaxPlaces { decimal_places } => to_string_render!(tmpl, decimal_places), + Self::DecimalWholeDigits { max_whole_digits } => to_string_render!(tmpl, max_whole_digits), _ => Ok(tmpl.to_string()), } } @@ -755,6 +777,9 @@ impl ErrorType { Self::UuidParsing { error } => py_dict!(py, error), Self::UuidVersion { expected_version } => py_dict!(py, expected_version), + Self::DecimalMaxDigits { max_digits } => py_dict!(py, max_digits), + Self::DecimalMaxPlaces { decimal_places } => py_dict!(py, decimal_places), + Self::DecimalWholeDigits { max_whole_digits } => py_dict!(py, max_whole_digits), _ => 
Ok(None), } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 91332e552..80ebc059c 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -152,6 +152,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_float() } + fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + if strict { + self.strict_decimal(decimal_type) + } else { + self.lax_decimal(decimal_type) + } + } + fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>; + #[cfg_attr(has_no_coverage, no_coverage)] + fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + self.strict_decimal(decimal_type) + } + fn validate_dict(&'a self, strict: bool) -> ValResult> { if strict { self.strict_dict() diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 218baf6b6..da2b3e4d4 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyDict, PyType}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; @@ -172,6 +172,23 @@ impl<'a> Input<'a> for JsonInput { } } + fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + match self { + JsonInput::String(str) => Ok(decimal_type.call1((str,))?), + _ => Err(ValError::new(ErrorType::DecimalType, self)), + } + } + + fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + match self { + JsonInput::Float(f) => Ok(decimal_type.call1((f.to_string(),))?), + JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => { + Ok(decimal_type.call1((self.to_object(decimal_type.py()),))?) 
+ } + _ => Err(ValError::new(ErrorType::DecimalType, self)), + } + } + fn validate_dict(&'a self, _strict: bool) -> ValResult> { match self { JsonInput::Object(dict) => Ok(dict.into()), @@ -417,6 +434,10 @@ impl<'a> Input<'a> for String { } } + fn strict_decimal(&'a self, _decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + todo!() + } + #[cfg_attr(has_no_coverage, no_coverage)] fn validate_dict(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorType::DictType, self)) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 4874e0d84..d260aa226 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::str::from_utf8; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, @@ -13,6 +14,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; +use crate::validators::decimal::create_decimal; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; use super::datetime::{ @@ -352,6 +354,56 @@ impl<'a> Input<'a> for PyAny { } } + fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + let py = decimal_type.py(); + + // Fast path for existing decimal objects + if self.is_exact_instance(decimal_type) { + return Ok(self); + } + + // Try subclasses of decimals, they will be upcast to Decimal + if self.is_instance(decimal_type)? 
{ + match decimal_type.call1((self,)) { + Ok(decimal) => return Ok(decimal), + Err(e) + if e.matches( + py, + PyTuple::new( + py, + [ + py.import("decimal")?.getattr("DecimalException")?, + PyTypeError::type_object(py), + ], + ), + ) => {} + Err(e) => return Err(ValError::InternalErr(e)), + } + } + + Err(ValError::new( + ErrorType::IsInstanceOf { + class: decimal_type.name().unwrap_or("Decimal").to_string(), + }, + self, + )) + } + + fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + // Fast path for existing decimal objects + if self.is_exact_instance(decimal_type) { + return Ok(self); + } + + if self.is_instance_of::() || self.is_instance_of::() { + create_decimal(self, self, decimal_type) + } else if self.is_instance_of::() { + create_decimal(self.str()?, self, decimal_type) + } else { + Err(ValError::new(ErrorType::DecimalType, self)) + } + } + fn strict_dict(&'a self) -> ValResult> { if let Ok(dict) = self.downcast::() { Ok(dict.into()) diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index c49098d40..d19a2f607 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -114,6 +114,7 @@ combined_serializer! 
{ Int: super::type_serializers::simple::IntSerializer; Bool: super::type_serializers::simple::BoolSerializer; Float: super::type_serializers::simple::FloatSerializer; + Decimal: super::type_serializers::decimal::DecimalSerializer; Str: super::type_serializers::string::StrSerializer; Bytes: super::type_serializers::bytes::BytesSerializer; Datetime: super::type_serializers::datetime_etc::DatetimeSerializer; @@ -229,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Decimal(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit), diff --git a/src/serializers/type_serializers/decimal.rs b/src/serializers/type_serializers/decimal.rs new file mode 100644 index 000000000..1fa44f6c1 --- /dev/null +++ b/src/serializers/type_serializers/decimal.rs @@ -0,0 +1,106 @@ +use std::borrow::Cow; + +use pyo3::types::{PyDict, PyType}; +use pyo3::{intern, prelude::*}; + +use crate::definitions::DefinitionsBuilder; + +use super::{ + infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra, + TypeSerializer, +}; + +#[derive(Debug, Clone)] +pub struct DecimalSerializer { + decimal_type: Py, +} + +impl BuildSerializer for DecimalSerializer { + const EXPECTED_TYPE: &'static str = "decimal"; + + fn build( + schema: &PyDict, + _config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + Ok(Self { + decimal_type: py + .import(intern!(py, "decimal"))? + .getattr(intern!(py, "Decimal"))? 
+ .extract()?, + } + .into()) + } +} + +enum OutputValue { + Ok, + Fallback, +} + +impl DecimalSerializer { + fn check(&self, value: &PyAny, extra: &Extra) -> PyResult { + if extra.check.enabled() { + if value.is_instance(self.decimal_type.as_ref(value.py()))? { + Ok(OutputValue::Ok) + } else { + Ok(OutputValue::Fallback) + } + } else { + Ok(OutputValue::Ok) + } + } +} + +impl_py_gc_traverse!(DecimalSerializer { decimal_type }); + +impl TypeSerializer for DecimalSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let _py = value.py(); + match self.check(value, extra)? { + OutputValue::Ok => infer_to_python(value, include, exclude, extra), + OutputValue::Fallback => { + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + } + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + match self.check(key, extra)? { + OutputValue::Ok => infer_json_key(key, extra), + OutputValue::Fallback => { + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; + infer_json_key(key, extra) + } + } + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + match self.check(value, extra).map_err(py_err_se_err)? 
{ + OutputValue::Ok => infer_serialize(value, serializer, include, exclude, extra), + OutputValue::Fallback => { + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + } + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index 250eb4982..b942b5b86 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -2,6 +2,7 @@ pub mod any; pub mod bytes; pub mod dataclass; pub mod datetime_etc; +pub mod decimal; pub mod definitions; pub mod dict; pub mod format; diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs new file mode 100644 index 000000000..d9992b67c --- /dev/null +++ b/src/validators/decimal.rs @@ -0,0 +1,255 @@ +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::types::{PyDict, PyTuple, PyType}; +use pyo3::{intern, AsPyPointer}; +use pyo3::{prelude::*, PyTypeInfo}; + +use crate::build_tools::{is_strict, schema_or_config_same}; +use crate::errors::Number; +use crate::errors::ValError; +use crate::errors::ValResult; +use crate::errors::{ErrorType, InputValue}; +use crate::input::Input; +use crate::recursion_guard::RecursionGuard; +use crate::tools::SchemaDict; + +use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; + +#[derive(Debug, Clone)] +pub struct DecimalValidator { + strict: bool, + allow_inf_nan: bool, + check_digits: bool, + multiple_of: Option>, + le: Option>, + lt: Option>, + ge: Option>, + gt: Option>, + max_digits: Option, + decimal_places: Option, + decimal_type: Py, +} + +impl BuildValidator for DecimalValidator { + const EXPECTED_TYPE: &'static str = "decimal"; + fn build( + schema: &PyDict, + config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let allow_inf_nan = 
schema_or_config_same(schema, config, intern!(py, "allow_inf_nan"))?.unwrap_or(false); + let decimal_places = schema.get_as(intern!(py, "decimal_places"))?; + let max_digits = schema.get_as(intern!(py, "max_digits"))?; + if allow_inf_nan && (decimal_places.is_some() || max_digits.is_some()) { + return Err(PyValueError::new_err( + "allow_inf_nan=True cannot be used with max_digits or decimal_places", + )); + } + Ok(Self { + strict: is_strict(schema, config)?, + allow_inf_nan, + check_digits: decimal_places.is_some() || max_digits.is_some(), + decimal_places, + multiple_of: schema.get_as(intern!(py, "multiple_of"))?, + le: schema.get_as(intern!(py, "le"))?, + lt: schema.get_as(intern!(py, "lt"))?, + ge: schema.get_as(intern!(py, "ge"))?, + gt: schema.get_as(intern!(py, "gt"))?, + max_digits, + decimal_type: py + .import(intern!(py, "decimal"))? + .getattr(intern!(py, "Decimal"))? + .extract()?, + } + .into()) + } +} + +impl_py_gc_traverse!(DecimalValidator { + multiple_of, + le, + lt, + ge, + gt, + decimal_type +}); + +impl Validator for DecimalValidator { + fn validate<'s, 'data>( + &'s self, + py: Python<'data>, + input: &'data impl Input<'data>, + extra: &Extra, + _definitions: &'data Definitions, + _recursion_guard: &'s mut RecursionGuard, + ) -> ValResult<'data, PyObject> { + let decimal = input.validate_decimal( + extra.strict.unwrap_or(self.strict), + // Safety: self and py both outlive this call + unsafe { py.from_borrowed_ptr(self.decimal_type.as_ptr()) }, + )?; + + if !self.allow_inf_nan || self.check_digits { + if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? 
{ + return Err(ValError::new(ErrorType::FiniteNumber, input)); + } + + if self.check_digits { + let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal); + let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = + normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?; + + // finite values have numeric exponent, we checked is_finite above + let exponent: i64 = exponent.extract()?; + let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; + let decimals; + if exponent >= 0 { + // A positive exponent adds that many trailing zeros. + digits += exponent as u64; + decimals = 0; + } else { + // If the absolute value of the negative exponent is larger than the + // number of digits, then it's the same as the number of digits, + // because it'll consume all the digits in digit_tuple and then + // add abs(exponent) - len(digit_tuple) leading zeros after the + // decimal point. + decimals = exponent.unsigned_abs(); + digits = digits.max(decimals); + } + + if let Some(max_digits) = self.max_digits { + if digits > max_digits { + return Err(ValError::new(ErrorType::DecimalMaxDigits { max_digits }, input)); + } + } + + if let Some(decimal_places) = self.decimal_places { + if decimals > decimal_places { + return Err(ValError::new(ErrorType::DecimalMaxPlaces { decimal_places }, input)); + } + + if let Some(max_digits) = self.max_digits { + let whole_digits = digits.saturating_sub(decimals); + let max_whole_digits = max_digits.saturating_sub(decimal_places); + if whole_digits > max_whole_digits { + return Err(ValError::new(ErrorType::DecimalWholeDigits { max_whole_digits }, input)); + } + } + } + } + } + + if let Some(multiple_of) = &self.multiple_of { + // fraction = (decimal / multiple_of) % 1 + let fraction: &PyAny = unsafe { + let division = PyObject::from_owned_ptr_or_err( + py, + pyo3::ffi::PyNumber_TrueDivide(decimal.as_ptr(), multiple_of.as_ptr()), + )?; + let one = 
1.to_object(py); + py.from_owned_ptr_or_err(pyo3::ffi::PyNumber_Remainder(division.as_ptr(), one.as_ptr()))? + }; + let zero = 0.to_object(py); + if !fraction.eq(&zero)? { + return Err(ValError::new( + ErrorType::MultipleOf { + multiple_of: multiple_of.to_string().into(), + }, + input, + )); + } + } + + if let Some(le) = &self.le { + if !decimal.le(le)? { + return Err(ValError::new( + ErrorType::LessThanEqual { + le: Number::String(le.to_string()), + }, + input, + )); + } + } + if let Some(lt) = &self.lt { + if !decimal.lt(lt)? { + return Err(ValError::new( + ErrorType::LessThan { + lt: Number::String(lt.to_string()), + }, + input, + )); + } + } + if let Some(ge) = &self.ge { + if !decimal.ge(ge)? { + return Err(ValError::new( + ErrorType::GreaterThanEqual { + ge: Number::String(ge.to_string()), + }, + input, + )); + } + } + if let Some(gt) = &self.gt { + if !decimal.gt(gt)? { + return Err(ValError::new( + ErrorType::GreaterThan { + gt: Number::String(gt.to_string()), + }, + input, + )); + } + } + + Ok(decimal.into()) + } + + fn different_strict_behavior( + &self, + _definitions: Option<&DefinitionsBuilder>, + _ultra_strict: bool, + ) -> bool { + true + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } + + fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + Ok(()) + } +} + +pub(crate) fn create_decimal<'a>( + arg: &'a PyAny, + input: &'a impl Input<'a>, + decimal_type: &'a PyType, +) -> ValResult<'a, &'a PyAny> { + decimal_type.call1((arg,)).map_err(|e| { + let decimal_exception = match arg + .py() + .import("decimal") + .and_then(|decimal_module| decimal_module.getattr("DecimalException")) + { + Ok(decimal_exception) => decimal_exception, + Err(e) => return ValError::InternalErr(e), + }; + handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception) + }) +} + +fn handle_decimal_new_error<'a>( + py: Python<'a>, + input: InputValue<'a>, + error: PyErr, + decimal_exception: &'a PyAny, +) -> ValError<'a> { + if 
error.matches(py, decimal_exception) { + ValError::new_custom_input(ErrorType::DecimalParsing, input) + } else if error.matches(py, PyTypeError::type_object(py)) { + ValError::new_custom_input(ErrorType::DecimalType, input) + } else { + ValError::InternalErr(error) + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 676fbbb9e..9a1b2eaa1 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -27,6 +27,7 @@ mod custom_error; mod dataclass; mod date; mod datetime; +pub(crate) mod decimal; mod definitions; mod dict; mod float; @@ -451,6 +452,8 @@ pub fn build_validator<'a>( bool::BoolValidator, // floats float::FloatBuilder, + // decimals + decimal::DecimalValidator, // tuples tuple::TuplePositionalValidator, tuple::TupleVariableValidator, @@ -597,6 +600,8 @@ pub enum CombinedValidator { // floats Float(float::FloatValidator), ConstrainedFloat(float::ConstrainedFloatValidator), + // decimals + Decimal(decimal::DecimalValidator), // lists List(list::ListValidator), // sets - unique lists diff --git a/tests/benchmarks/complete_schema.py b/tests/benchmarks/complete_schema.py index edebf92fc..69a27277b 100644 --- a/tests/benchmarks/complete_schema.py +++ b/tests/benchmarks/complete_schema.py @@ -1,3 +1,6 @@ +from decimal import Decimal + + def schema(*, strict: bool = False) -> dict: class MyModel: # __slots__ is not required, but it avoids __pydantic_fields_set__ falling into __dict__ @@ -31,6 +34,7 @@ def wrap_function(input_value, validator, info): 'type': 'model-field', 'schema': {'type': 'float', 'ge': 1.0, 'le': 10.0, 'multiple_of': 0.5}, }, + 'field_decimal': {'type': 'model-field', 'schema': {'type': 'decimal'}}, 'field_bool': {'type': 'model-field', 'schema': {'type': 'bool'}}, 'field_bytes': {'type': 'model-field', 'schema': {'type': 'bytes'}}, 'field_bytes_con': { @@ -218,6 +222,7 @@ def input_data_lax(): 'field_int_con': 8, 'field_float': 1.0, 'field_float_con': 10.0, + 'field_decimal': 42.0, 'field_bool': True, 'field_bytes': 
b'foobar', 'field_bytes_con': b'foobar', @@ -276,6 +281,7 @@ def input_data_strict(): field_datetime=datetime(2020, 1, 1, 12, 13, 14), field_datetime_con=datetime(2020, 1, 1), field_uuid=UUID('12345678-1234-5678-1234-567812345678'), + field_decimal=Decimal('42.0'), ) return input_data @@ -288,6 +294,7 @@ def input_data_wrong(): 'field_int_con': 11, 'field_float': False, 'field_float_con': 10.1, + 'field_decimal': 'wrong', 'field_bool': 4, 'field_bytes': 42, 'field_bytes_con': b'foo', diff --git a/tests/benchmarks/test_complete_benchmark.py b/tests/benchmarks/test_complete_benchmark.py index 3657cefeb..8ed07d37c 100644 --- a/tests/benchmarks/test_complete_benchmark.py +++ b/tests/benchmarks/test_complete_benchmark.py @@ -3,6 +3,7 @@ """ import json from datetime import date, datetime, time +from decimal import Decimal from uuid import UUID import pytest @@ -18,7 +19,7 @@ def test_complete_valid(): lax_validator = SchemaValidator(lax_schema) output = lax_validator.validate_python(input_data_lax()) assert isinstance(output, cls) - assert len(output.__pydantic_fields_set__) == 40 + assert len(output.__pydantic_fields_set__) == 41 output_dict = output.__dict__ assert output_dict == { 'field_str': 'fo', @@ -27,6 +28,7 @@ def test_complete_valid(): 'field_int_con': 8, 'field_float': 1.0, 'field_float_con': 10.0, + 'field_decimal': Decimal('42.0'), 'field_bool': True, 'field_bytes': b'foobar', 'field_bytes_con': b'foobar', @@ -81,7 +83,7 @@ def test_complete_invalid(): lax_validator = SchemaValidator(lax_schema) with pytest.raises(ValidationError) as exc_info: lax_validator.validate_python(input_data_wrong()) - assert len(exc_info.value.errors(include_url=False)) == 738 + assert len(exc_info.value.errors(include_url=False)) == 739 @pytest.mark.benchmark(group='complete') diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index f1c0afe17..4ac2d30ab 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ 
class TestBenchmarkDecimal:
    """Benchmarks for building a `Decimal` from the string '123.456789'.

    Three variants share the 'decimal from str' benchmark group: the `{'type': 'decimal'}`
    core schema, an equivalent validator composed from other core schemas plus Python
    callbacks, and a direct call to `decimal.Decimal`.
    """

    @pytest.fixture(scope='class')
    def validator(self):
        """Validator built from the plain `decimal` core schema."""
        return SchemaValidator({'type': 'decimal'})

    @pytest.fixture(scope='class')
    def pydantic_validator(self):
        """Validator composed from union / function schemas with Python-level conversion."""
        Decimal = decimal.Decimal

        def to_decimal(v: str) -> decimal.Decimal:
            # Convert decimal-module failures into a custom validation error.
            try:
                return Decimal(v)
            except decimal.DecimalException as e:
                raise PydanticCustomError('decimal_parsing', 'Input should be a valid decimal') from e

        primitive_schema = core_schema.union_schema(
            [
                # if it's an int keep it like that and pass it straight to Decimal
                # but if it's not make it a string
                # we don't use JSON -> float because parsing to any float will cause
                # loss of precision
                core_schema.int_schema(strict=True),
                core_schema.str_schema(strict=True, strip_whitespace=True),
                core_schema.no_info_plain_validator_function(str),
            ]
        )
        json_schema = core_schema.no_info_after_validator_function(to_decimal, primitive_schema)
        schema = core_schema.json_or_python_schema(
            json_schema=json_schema,
            python_schema=core_schema.lax_or_strict_schema(
                lax_schema=core_schema.union_schema([core_schema.is_instance_schema(decimal.Decimal), json_schema]),
                strict_schema=core_schema.is_instance_schema(decimal.Decimal),
            ),
            serialization=core_schema.to_string_ser_schema(when_used='json'),
        )

        def check_finite(value: decimal.Decimal) -> decimal.Decimal:
            # Reject NaN / infinities after conversion.
            if not value.is_finite():
                raise PydanticKnownError('finite_number')
            return value

        schema = core_schema.no_info_after_validator_function(check_finite, schema)

        return SchemaValidator(schema)

    @pytest.mark.benchmark(group='decimal from str')
    def test_decimal_from_string_core(self, benchmark, validator):
        """Benchmark the `decimal` core schema validator."""
        benchmark(validator.validate_python, '123.456789')

    @pytest.mark.benchmark(group='decimal from str')
    def test_decimal_from_string_pyd(self, benchmark, pydantic_validator):
        """Benchmark the composed validator from the `pydantic_validator` fixture."""
        benchmark(pydantic_validator.validate_python, '123.456789')

    @pytest.mark.benchmark(group='decimal from str')
    def test_decimal_from_string_limit(self, benchmark):
        """Benchmark calling `decimal.Decimal` directly, with no validation layer."""
        benchmark(decimal.Decimal, '123.456789')
def args(*args, **kwargs): {'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass, 'slots': True}, ), (core_schema.uuid_schema, args(), {'type': 'uuid'}), + (core_schema.decimal_schema, args(), {'type': 'decimal'}), + (core_schema.decimal_schema, args(multiple_of=5, gt=1.2), {'type': 'decimal', 'multiple_of': 5, 'gt': 1.2}), ]