Skip to content

Commit

Permalink
add decimal validator
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jul 19, 2023
1 parent f38bea6 commit e8254f0
Show file tree
Hide file tree
Showing 11 changed files with 532 additions and 6 deletions.
9 changes: 6 additions & 3 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from __future__ import annotations as _annotations

import decimal
import importlib.util
import re
from collections.abc import Callable
Expand All @@ -23,6 +24,7 @@
UnionType = Union[TypingUnionType, TypesUnionType]

except ImportError:
TypesUnionType = None
UnionType = TypingUnionType


Expand All @@ -46,8 +48,8 @@
def get_schema(obj) -> core_schema.CoreSchema:
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str):
return {'type': obj.__name__}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj)
elif obj == Any or obj == type:
Expand All @@ -57,7 +59,7 @@ def get_schema(obj) -> core_schema.CoreSchema:

origin = get_origin(obj)
assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py'
if origin is Union:
if origin is Union or origin is TypesUnionType:
return union_schema(obj)
elif obj is Callable or origin is Callable:
return {'type': 'callable'}
Expand All @@ -79,6 +81,7 @@ def get_schema(obj) -> core_schema.CoreSchema:
# can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators
return {'type': 'any'}
else:
print(origin)
# debug(obj)
raise TypeError(f'Unknown type: {obj!r}')

Expand Down
69 changes: 69 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
Expand Down Expand Up @@ -658,6 +659,72 @@ def float_schema(
)


class DecimalSchema(TypedDict, total=False):
type: Required[Literal['decimal']]
gt: int | Decimal
ge: int | Decimal
lt: int | Decimal
le: int | Decimal
max_digits: int
decimal_places: int
multiple_of: int | Decimal
allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: False
check_digits: bool # FIXME document. default: False
strict: bool
ref: str
metadata: Any
serialization: SerSchema


def decimal_schema(
*,
gt: int | Decimal | None = None,
ge: int | Decimal | None = None,
lt: int | Decimal | None = None,
le: int | Decimal | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
multiple_of: int | Decimal | None = None,
allow_inf_nan: bool = None,
check_digits: bool = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> DecimalSchema:
"""
Returns a schema that matches a decimal value, e.g.:
```py
from decimal import Decimal
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.decimal_schema(le=0.8, ge=0.2)
v = SchemaValidator(schema)
assert v.validate_python('0.5') == Decimal('0.5')
```
Args:
FIXME document
"""
return _dict_not_none(
type='decimal',
gt=gt,
ge=ge,
lt=lt,
le=le,
max_digits=max_digits,
decimal_places=decimal_places,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
check_digits=check_digits,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: str
Expand Down Expand Up @@ -3690,6 +3757,7 @@ def definition_reference_schema(
BoolSchema,
IntSchema,
FloatSchema,
DecimalSchema,
StringSchema,
BytesSchema,
DateSchema,
Expand Down Expand Up @@ -3743,6 +3811,7 @@ def definition_reference_schema(
'bool',
'int',
'float',
'decimal',
'str',
'bytes',
'date',
Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod value_exception;

pub use self::line_error::{InputValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::types::{list_all_errors, ErrorMode, ErrorType};
pub use self::types::{list_all_errors, ErrorMode, ErrorType, Number};
pub use self::validation_exception::ValidationError;
pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault};

Expand Down
13 changes: 13 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_float()
}

fn validate_decimal(&'a self, strict: bool, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
if strict {
self.strict_decimal(decimal_type)
} else {
self.lax_decimal(decimal_type)
}
}
fn strict_decimal(&'a self, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_decimal(&'a self, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
self.strict_decimal(decimal_type)
}

fn validate_dict(&'a self, strict: bool) -> ValResult<GenericMapping<'a>> {
if strict {
self.strict_dict()
Expand Down
28 changes: 27 additions & 1 deletion src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyType};
use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
use crate::PydanticCustomError;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -172,6 +173,27 @@ impl<'a> Input<'a> for JsonInput {
}
}

fn strict_decimal(&'a self, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
// FIXME pure evil!
let py = unsafe { Python::assume_gil_acquired() };
match self {
JsonInput::String(str) => Ok(unsafe { std::mem::transmute(decimal_type.as_ref(py).call1((str,))?) }),
_ =>
// FIXME make a known error
{
Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
py,
"decimal_type".into(),
"Input should be a valid Decimal instance or decimal string in JSON".into(),
None,
)),
self,
))
}
}
}

fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
match self {
JsonInput::Object(dict) => Ok(dict.into()),
Expand Down Expand Up @@ -417,6 +439,10 @@ impl<'a> Input<'a> for String {
}
}

fn strict_decimal(&'a self, _decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
todo!()
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn validate_dict(&'a self, _strict: bool) -> ValResult<GenericMapping<'a>> {
Err(ValError::new(ErrorType::DictType, self))
Expand Down
54 changes: 53 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl, PydanticCustomError};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
Expand Down Expand Up @@ -352,6 +352,58 @@ impl<'a> Input<'a> for PyAny {
}
}

fn strict_decimal(&'a self, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
if self.is_instance(decimal_type.as_ref(self.py()))? {
Ok(self)
} else {
// FIXME make a known error
Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
self.py(),
"decimal_type".into(),
"Input should be a valid Decimal instance or decimal string in JSON".into(),
None,
)),
self,
))
}
}

fn lax_decimal(&'a self, decimal_type: &Py<PyType>) -> ValResult<&'a PyAny> {
let py = self.py();

if self.is_instance(decimal_type.as_ref(self.py()))? {
return Ok(self);
}

let str = if let Ok(str) = self.downcast::<PyString>() {
str
} else if let Ok(float) = self.downcast::<PyFloat>() {
float.str()?
} else if let Ok(int) = self.downcast::<PyInt>() {
int.str()?
} else {
self.str().map_or(self, |s| s.as_ref())
};

// Fixme use most efficient calling convention
match decimal_type.as_ref(py).call1((str,)) {
// Fixme lifetime hack a bit evil
Ok(decimal) => Ok(decimal.to_object(py).into_ref(py)),
// Fixme cache exception instance... on the validator?
Err(e) if e.is_instance(py, py.import("decimal")?.getattr("DecimalException")?) => Err(ValError::new(
ErrorType::new_custom_error(PydanticCustomError::py_new(
py,
"decimal_parsing".into(),
"input should be a valid decimal".into(),
None,
)),
self,
)),
Err(e) => Err(ValError::InternalErr(e)),
}
}

fn strict_dict(&'a self) -> ValResult<GenericMapping<'a>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(dict.into())
Expand Down
1 change: 1 addition & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ combined_serializer! {
Int: super::type_serializers::simple::IntSerializer;
Bool: super::type_serializers::simple::BoolSerializer;
Float: super::type_serializers::simple::FloatSerializer;
Decimal: super::type_serializers::decimal::DecimalSerializer;
Str: super::type_serializers::string::StrSerializer;
Bytes: super::type_serializers::bytes::BytesSerializer;
Datetime: super::type_serializers::datetime_etc::DatetimeSerializer;
Expand Down
Loading

0 comments on commit e8254f0

Please sign in to comment.