Skip to content

Commit

Permalink
add decimal validator
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jul 25, 2023
1 parent f5ef7af commit 0f7ca0d
Show file tree
Hide file tree
Showing 17 changed files with 654 additions and 11 deletions.
9 changes: 6 additions & 3 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from __future__ import annotations as _annotations

import decimal
import importlib.util
import re
from collections.abc import Callable
Expand All @@ -23,6 +24,7 @@
UnionType = Union[TypingUnionType, TypesUnionType]

except ImportError:
TypesUnionType = None
UnionType = TypingUnionType


Expand All @@ -46,8 +48,8 @@
def get_schema(obj) -> core_schema.CoreSchema:
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str):
return {'type': obj.__name__}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj)
elif obj == Any or obj == type:
Expand All @@ -57,7 +59,7 @@ def get_schema(obj) -> core_schema.CoreSchema:

origin = get_origin(obj)
assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py'
if origin is Union:
if origin is Union or origin is TypesUnionType:
return union_schema(obj)
elif obj is Callable or origin is Callable:
return {'type': 'callable'}
Expand All @@ -79,6 +81,7 @@ def get_schema(obj) -> core_schema.CoreSchema:
# can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators
return {'type': 'any'}
else:
print(origin)
# debug(obj)
raise TypeError(f'Unknown type: {obj!r}')

Expand Down
74 changes: 74 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -659,6 +660,72 @@ def float_schema(
)


class DecimalSchema(TypedDict, total=False):
type: Required[Literal['decimal']]
gt: Union[int, Decimal]
ge: Union[int, Decimal]
lt: Union[int, Decimal]
le: Union[int, Decimal]
max_digits: int
decimal_places: int
multiple_of: Union[int, Decimal]
allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: False
check_digits: bool # FIXME document. default: False
strict: bool
ref: str
metadata: Any
serialization: SerSchema


def decimal_schema(
*,
gt: int | Decimal | None = None,
ge: int | Decimal | None = None,
lt: int | Decimal | None = None,
le: int | Decimal | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
multiple_of: int | Decimal | None = None,
allow_inf_nan: bool = None,
check_digits: bool = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> DecimalSchema:
"""
Returns a schema that matches a decimal value, e.g.:
```py
from decimal import Decimal
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.decimal_schema(le=0.8, ge=0.2)
v = SchemaValidator(schema)
assert v.validate_python('0.5') == Decimal('0.5')
```
Args:
FIXME document
"""
return _dict_not_none(
type='decimal',
gt=gt,
ge=ge,
lt=lt,
le=le,
max_digits=max_digits,
decimal_places=decimal_places,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
check_digits=check_digits,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: str
Expand Down Expand Up @@ -3713,6 +3780,7 @@ def definition_reference_schema(
BoolSchema,
IntSchema,
FloatSchema,
DecimalSchema,
StringSchema,
BytesSchema,
DateSchema,
Expand Down Expand Up @@ -3767,6 +3835,7 @@ def definition_reference_schema(
'bool',
'int',
'float',
'decimal',
'str',
'bytes',
'date',
Expand Down Expand Up @@ -3907,6 +3976,11 @@ def definition_reference_schema(
'uuid_type',
'uuid_parsing',
'uuid_version',
'decimal_type',
'decimal_parsing',
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
]


Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod value_exception;

pub use self::line_error::{InputValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::types::{list_all_errors, ErrorMode, ErrorType};
pub use self::types::{list_all_errors, ErrorMode, ErrorType, Number};
pub use self::validation_exception::ValidationError;
pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault};

Expand Down
31 changes: 28 additions & 3 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ pub enum ErrorType {
// None errors
NoneRequired,
// ---------------------
// generic comparison errors - used for all inequality comparisons except int and float which have their
// own type, bounds arguments are Strings so they can be created from any type
// generic comparison errors
GreaterThan {
gt: Number,
},
Expand Down Expand Up @@ -316,6 +315,18 @@ pub enum ErrorType {
UuidVersion {
expected_version: usize,
},
// Decimal errors
DecimalType,
DecimalParsing,
DecimalMaxDigits {
max_digits: u64,
},
DecimalMaxPlaces {
decimal_places: u64,
},
DecimalWholeDigits {
max_whole_digits: u64,
},
}

macro_rules! render {
Expand Down Expand Up @@ -463,6 +474,9 @@ impl ErrorType {
Self::UuidVersion { .. } => {
extract_context!(UuidVersion, ctx, expected_version: usize)
}
Self::DecimalMaxDigits { .. } => extract_context!(DecimalMaxDigits, ctx, max_digits: u64),
Self::DecimalMaxPlaces { .. } => extract_context!(DecimalMaxPlaces, ctx, decimal_places: u64),
Self::DecimalWholeDigits { .. } => extract_context!(DecimalWholeDigits, ctx, max_whole_digits: u64),
_ => {
if ctx.is_some() {
py_err!(PyTypeError; "'{}' errors do not require context", value)
Expand Down Expand Up @@ -569,7 +583,12 @@ impl ErrorType {
Self::UrlScheme {..} => "URL scheme should be {expected_schemes}",
Self::UuidType => "UUID input should be a string, bytes or UUID object",
Self::UuidParsing { .. } => "Input should be a valid UUID, {error}",
Self::UuidVersion { .. } => "UUID version {expected_version} expected"
Self::UuidVersion { .. } => "UUID version {expected_version} expected",
Self::DecimalType => "Decimal input should be an integer, float, string or Decimal object",
Self::DecimalParsing => "Input should be a valid decimal",
Self::DecimalMaxDigits { .. } => "Decimal input should have no more than {max_digits} digits in total",
Self::DecimalMaxPlaces { .. } => "Decimal input should have no more than {decimal_places} decimal places",
Self::DecimalWholeDigits { .. } => "Decimal input should have no more than {max_whole_digits} digits before the decimal point",

}
}
Expand Down Expand Up @@ -692,6 +711,9 @@ impl ErrorType {
Self::UrlScheme { expected_schemes } => render!(tmpl, expected_schemes),
Self::UuidParsing { error } => render!(tmpl, error),
Self::UuidVersion { expected_version } => to_string_render!(tmpl, expected_version),
Self::DecimalMaxDigits { max_digits } => to_string_render!(tmpl, max_digits),
Self::DecimalMaxPlaces { decimal_places } => to_string_render!(tmpl, decimal_places),
Self::DecimalWholeDigits { max_whole_digits } => to_string_render!(tmpl, max_whole_digits),
_ => Ok(tmpl.to_string()),
}
}
Expand Down Expand Up @@ -755,6 +777,9 @@ impl ErrorType {

Self::UuidParsing { error } => py_dict!(py, error),
Self::UuidVersion { expected_version } => py_dict!(py, expected_version),
Self::DecimalMaxDigits { max_digits } => py_dict!(py, max_digits),
Self::DecimalMaxPlaces { decimal_places } => py_dict!(py, decimal_places),
Self::DecimalWholeDigits { max_whole_digits } => py_dict!(py, max_whole_digits),
_ => Ok(None),
}
}
Expand Down
13 changes: 13 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_float()
}

fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
if strict {
self.strict_decimal(decimal_type)
} else {
self.lax_decimal(decimal_type)
}
}
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
self.strict_decimal(decimal_type)
}

fn validate_dict(&'a self, strict: bool) -> ValResult<GenericMapping<'a>> {
if strict {
self.strict_dict()
Expand Down
23 changes: 22 additions & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyType};
use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

Expand Down Expand Up @@ -172,6 +172,23 @@ impl<'a> Input<'a> for JsonInput {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
match self {
JsonInput::String(str) => Ok(decimal_type.call1((str,))?),
_ => Err(ValError::new(ErrorType::DecimalType, self)),
}
}

fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
match self {
JsonInput::Float(f) => Ok(decimal_type.call1((f.to_string(),))?),
JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => {
Ok(decimal_type.call1((self.to_object(decimal_type.py()),))?)
}
_ => Err(ValError::new(ErrorType::DecimalType, self)),
}
}

fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
match self {
JsonInput::Object(dict) => Ok(dict.into()),
Expand Down Expand Up @@ -417,6 +434,10 @@ impl<'a> Input<'a> for String {
}
}

fn strict_decimal(&'a self, _decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
todo!()
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
Err(ValError::new(ErrorType::DictType, self))
Expand Down
52 changes: 52 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::str::from_utf8;

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
Expand All @@ -13,6 +14,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::decimal::create_decimal;
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};

use super::datetime::{
Expand Down Expand Up @@ -352,6 +354,56 @@ impl<'a> Input<'a> for PyAny {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
let py = decimal_type.py();

// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

// Try subclasses of decimals, they will be upcast to Decimal
if self.is_instance(decimal_type)? {
match decimal_type.call1((self,)) {
Ok(decimal) => return Ok(decimal),
Err(e)
if e.matches(
py,
PyTuple::new(
py,
[
py.import("decimal")?.getattr("DecimalException")?,
PyTypeError::type_object(py),
],
),
) => {}
Err(e) => return Err(ValError::InternalErr(e)),
}
}

Err(ValError::new(
ErrorType::IsInstanceOf {
class: decimal_type.name().unwrap_or("Decimal").to_string(),
},
self,
))
}

fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

if self.is_instance_of::<PyString>() || self.is_instance_of::<PyInt>() {
create_decimal(self, self, decimal_type)
} else if self.is_instance_of::<PyFloat>() {
create_decimal(self.str()?, self, decimal_type)
} else {
Err(ValError::new(ErrorType::DecimalType, self))
}
}

fn strict_dict(&'a self) -> ValResult<GenericMapping<'a>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(dict.into())
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ combined_serializer! {
Int: super::type_serializers::simple::IntSerializer;
Bool: super::type_serializers::simple::BoolSerializer;
Float: super::type_serializers::simple::FloatSerializer;
Decimal: super::type_serializers::decimal::DecimalSerializer;
Str: super::type_serializers::string::StrSerializer;
Bytes: super::type_serializers::bytes::BytesSerializer;
Datetime: super::type_serializers::datetime_etc::DatetimeSerializer;
Expand Down Expand Up @@ -229,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Decimal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit),
Expand Down
Loading

0 comments on commit 0f7ca0d

Please sign in to comment.