Skip to content

Commit

Permalink
add decimal validator
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jul 20, 2023
1 parent f38bea6 commit e88e436
Show file tree
Hide file tree
Showing 12 changed files with 605 additions and 7 deletions.
9 changes: 6 additions & 3 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from __future__ import annotations as _annotations

import decimal
import importlib.util
import re
from collections.abc import Callable
Expand All @@ -23,6 +24,7 @@
UnionType = Union[TypingUnionType, TypesUnionType]

except ImportError:
TypesUnionType = None
UnionType = TypingUnionType


Expand All @@ -46,8 +48,8 @@
def get_schema(obj) -> core_schema.CoreSchema:
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str):
return {'type': obj.__name__}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj)
elif obj == Any or obj == type:
Expand All @@ -57,7 +59,7 @@ def get_schema(obj) -> core_schema.CoreSchema:

origin = get_origin(obj)
assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py'
if origin is Union:
if origin is Union or origin is TypesUnionType:
return union_schema(obj)
elif obj is Callable or origin is Callable:
return {'type': 'callable'}
Expand All @@ -79,6 +81,7 @@ def get_schema(obj) -> core_schema.CoreSchema:
# can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators
return {'type': 'any'}
else:
print(origin)
# debug(obj)
raise TypeError(f'Unknown type: {obj!r}')

Expand Down
69 changes: 69 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -658,6 +659,72 @@ def float_schema(
)


class DecimalSchema(TypedDict, total=False):
type: Required[Literal['decimal']]
gt: int | Decimal
ge: int | Decimal
lt: int | Decimal
le: int | Decimal
max_digits: int
decimal_places: int
multiple_of: int | Decimal
allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: False
check_digits: bool # FIXME document. default: False
strict: bool
ref: str
metadata: Any
serialization: SerSchema


def decimal_schema(
*,
gt: int | Decimal | None = None,
ge: int | Decimal | None = None,
lt: int | Decimal | None = None,
le: int | Decimal | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
multiple_of: int | Decimal | None = None,
allow_inf_nan: bool = None,
check_digits: bool = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> DecimalSchema:
"""
Returns a schema that matches a decimal value, e.g.:
```py
from decimal import Decimal
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.decimal_schema(le=0.8, ge=0.2)
v = SchemaValidator(schema)
assert v.validate_python('0.5') == Decimal('0.5')
```
Args:
FIXME document
"""
return _dict_not_none(
type='decimal',
gt=gt,
ge=ge,
lt=lt,
le=le,
max_digits=max_digits,
decimal_places=decimal_places,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
check_digits=check_digits,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: str
Expand Down Expand Up @@ -3690,6 +3757,7 @@ def definition_reference_schema(
BoolSchema,
IntSchema,
FloatSchema,
DecimalSchema,
StringSchema,
BytesSchema,
DateSchema,
Expand Down Expand Up @@ -3743,6 +3811,7 @@ def definition_reference_schema(
'bool',
'int',
'float',
'decimal',
'str',
'bytes',
'date',
Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod value_exception;

pub use self::line_error::{InputValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::types::{list_all_errors, ErrorMode, ErrorType};
pub use self::types::{list_all_errors, ErrorMode, ErrorType, Number};
pub use self::validation_exception::ValidationError;
pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault};

Expand Down
13 changes: 13 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_float()
}

fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
if strict {
self.strict_decimal(decimal_type)
} else {
self.lax_decimal(decimal_type)
}
}
fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
self.strict_decimal(decimal_type)
}

fn validate_dict(&'a self, strict: bool) -> ValResult<GenericMapping<'a>> {
if strict {
self.strict_dict()
Expand Down
26 changes: 25 additions & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyType};
use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
use crate::PydanticCustomError;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -172,6 +173,25 @@ impl<'a> Input<'a> for JsonInput {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
match self {
JsonInput::String(str) => Ok(decimal_type.call1((str,))?),
_ =>
// FIXME make a known error
{
Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
decimal_type.py(),
"decimal_type".into(),
"Input should be a valid Decimal instance or decimal string in JSON".into(),
None,
)),
self,
))
}
}
}

fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
match self {
JsonInput::Object(dict) => Ok(dict.into()),
Expand Down Expand Up @@ -417,6 +437,10 @@ impl<'a> Input<'a> for String {
}
}

fn strict_decimal(&'a self, _decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
todo!()
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
Err(ValError::new(ErrorType::DictType, self))
Expand Down
60 changes: 59 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl, PydanticCustomError};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
Expand Down Expand Up @@ -352,6 +352,64 @@ impl<'a> Input<'a> for PyAny {
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
let py = self.py();

// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

// Try subclasses of decimals, they will be upcast to Decimal
if self.is_instance(decimal_type)? {
match decimal_type.call1((self,)) {
Ok(decimal) => return Ok(decimal),
Err(e) if e.is_instance(py, py.import("decimal")?.getattr("DecimalException")?) => {}
Err(e) => return Err(ValError::InternalErr(e)),
}
}

// FIXME make a known error
Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
self.py(),
"decimal_type".into(),
"Input should be a valid Decimal instance or decimal string in JSON".into(),
None,
)),
self,
))
}

fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
let py = self.py();

// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self);
}

// The decimal constructor will accept list and tuple types, but we don't want
// to accept them.
if !(self.is_instance_of::<PyTuple>() || self.is_instance_of::<PyList>()) {
match decimal_type.call1((self,)) {
Ok(decimal) => return Ok(decimal),
Err(e) if e.is_instance(py, py.import("decimal")?.getattr("DecimalException")?) => {}
Err(e) => return Err(ValError::InternalErr(e)),
}
};

Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
py,
"decimal_parsing".into(),
"input should be a valid decimal".into(),
None,
)),
self,
))
}

fn strict_dict(&'a self) -> ValResult<GenericMapping<'a>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(dict.into())
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ combined_serializer! {
Int: super::type_serializers::simple::IntSerializer;
Bool: super::type_serializers::simple::BoolSerializer;
Float: super::type_serializers::simple::FloatSerializer;
Decimal: super::type_serializers::decimal::DecimalSerializer;
Str: super::type_serializers::string::StrSerializer;
Bytes: super::type_serializers::bytes::BytesSerializer;
Datetime: super::type_serializers::datetime_etc::DatetimeSerializer;
Expand Down Expand Up @@ -228,6 +229,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Decimal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit),
Expand Down
Loading

0 comments on commit e88e436

Please sign in to comment.