From 3f7c0105ebe2a5f042cf0585c0108cebcf5c7013 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 13 Jul 2023 11:59:45 -0600 Subject: [PATCH] Make validating assignment work properly with allowed extra (#766) --- python/pydantic_core/_pydantic_core.pyi | 5 ++++- src/validators/model.rs | 25 ++++++++++++++++-------- src/validators/model_fields.rs | 26 ++++++++++++++++++++----- tests/validators/test_function.py | 4 ++++ tests/validators/test_model.py | 8 +++++++- tests/validators/test_model_fields.py | 16 ++++++++------- 6 files changed, 62 insertions(+), 22 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index d1cf33796..11c4d611a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -102,7 +102,10 @@ class SchemaValidator: strict: bool | None = None, from_attributes: bool | None = None, context: 'dict[str, Any] | None' = None, - ) -> dict[str, Any]: ... + ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]: + """ + ModelValidator and ModelFieldsValidator will return a tuple of (fields data, extra data, fields set) + """ def get_default_value(self, *, strict: bool | None = None, context: Any = None) -> Some | None: ... _IncEx: TypeAlias = set[int] | set[str] | dict[int, _IncEx] | dict[str, _IncEx] | None diff --git a/src/validators/model.rs b/src/validators/model.rs index 51d0d4b33..a746bd38b 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -183,14 +183,18 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) }; } - let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?; + let old_dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?; - let new_dict = dict.copy()?; - new_dict.set_item(field_name, field_value)?; + let input_dict = old_dict.copy()?; + let old_extra: Option<&PyDict> = model.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?.downcast().ok(); + if let Some(old_extra) = old_extra { + input_dict.update(old_extra.as_mapping())?; + } + input_dict.set_item(field_name, field_value)?; let output = self.validator.validate_assignment( py, - new_dict, + input_dict, field_name, field_value, extra, @@ -198,17 +202,22 @@ impl Validator for ModelValidator { recursion_guard, )?; - let (output, _, updated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?; + let (validated_dict, validated_extra, validated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?; if let Ok(fields_set) = model.getattr(intern!(py, DUNDER_FIELDS_SET_KEY)) { let fields_set: &PySet = fields_set.downcast()?; - for field_name in updated_fields_set { + for field_name in validated_fields_set { fields_set.add(field_name)?; } } - let output = output.to_object(py); - force_setattr(py, model, intern!(py, DUNDER_DICT), output)?; + force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?; + force_setattr( + py, + model, + intern!(py, DUNDER_MODEL_EXTRA_KEY), + validated_extra.to_object(py), + )?; Ok(model.into_py(py)) } diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index c44a8910f..d517991d0 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -302,13 +302,13 @@ impl Validator for ModelFieldsValidator { ) -> ValResult<'data, PyObject> { let dict: &PyDict = obj.downcast()?; - let ok = |output: PyObject| { + let get_updated_dict = |output: PyObject| { dict.set_item(field_name, output)?; - Ok(dict.to_object(py)) + Ok(dict) }; let prepare_result = |result: ValResult<'data, PyObject>| match result { - Ok(output) => ok(output), + Ok(output) => get_updated_dict(output), Err(ValError::LineErrors(line_errors)) => { let errors = line_errors .into_iter() @@ -358,7 +358,7 @@ impl Validator for ModelFieldsValidator { Some(ref validator) => { prepare_result(validator.validate(py, field_value, &extra, definitions, recursion_guard)) } - None => ok(field_value.to_object(py)), + None => get_updated_dict(field_value.to_object(py)), }, ExtraBehavior::Forbid | ExtraBehavior::Ignore => { return Err(ValError::new_with_loc( @@ -372,8 +372,24 @@ impl Validator for ModelFieldsValidator { } }?; + let new_extra = match &self.extra_behavior { + ExtraBehavior::Allow => { + let non_extra_data = PyDict::new(py); + self.fields.iter().for_each(|f| { + let popped_value = PyAny::get_item(new_data, &f.name).unwrap(); + new_data.del_item(&f.name).unwrap(); + non_extra_data.set_item(&f.name, popped_value).unwrap(); + }); + let new_extra = new_data.copy()?; + new_data.clear(); + new_data.update(non_extra_data.as_mapping())?; + new_extra.to_object(py) + } + _ => py.None(), + }; + let fields_set: &PySet = PySet::new(py, &[field_name.to_string()])?; - Ok((new_data, py.None(), fields_set.to_object(py)).to_object(py)) + Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } fn different_strict_behavior( diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 03ffcf55d..fca21c558 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -461,6 +461,9 @@ class Model: __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' field_a: str + def __init__(self): + self.__pydantic_extra__ = None # this attribute must be present for validate_assignment + v = SchemaValidator( core_schema.no_info_after_validator_function( f, @@ -474,6 +477,7 @@ class Model: assert m.field_a == 'test' assert m.__pydantic_fields_set__ == {'field_a'} assert m.__dict__ == {'field_a': 'test', 'more': 'foobar'} + assert m.__pydantic_extra__ is None m2 = Model() m2.field_a = 'test' diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index d014e446c..b5afc83a4 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -943,6 +943,9 @@ class MyModel: field_a: str field_b: int + def __init__(self): + self.__pydantic_extra__ = None + v = SchemaValidator( { 'type': 'model', @@ -1019,7 +1022,10 @@ def func(x, info): def test_validate_assignment_no_fields_set(): class MyModel: - __slots__ = ('__dict__',) + __slots__ = ('__dict__', '__pydantic_extra__') + + def __init__(self): + self.__pydantic_extra__ = None v = SchemaValidator( { diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index c8b4b22ff..61d0c779b 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -347,8 +347,8 @@ def test_validate_assignment_allow_extra(): assert v.validate_python({'field_a': 'test'}) == ({'field_a': 'test'}, {}, {'field_a'}) assert v.validate_assignment({'field_a': 'test'}, 'other_field', 456) == ( - {'field_a': 'test', 'other_field': 456}, - None, + {'field_a': 'test'}, + {'other_field': 456}, {'other_field'}, ) @@ -364,8 +364,8 @@ def test_validate_assignment_allow_extra_validate(): ) assert v.validate_assignment({'field_a': 'test'}, 'other_field', '456') == ( - {'field_a': 'test', 'other_field': 456}, - None, + {'field_a': 'test'}, + {'other_field': 456}, {'other_field'}, ) @@ -1682,10 +1682,12 @@ def test_extra_behavior_allow( assert fields_set == {'f', 'extra_field'} v.validate_assignment(m, 'f', 'y') - assert m['f'] == 'y' + assert m == {'f': 'y'} - v.validate_assignment(m, 'not_f', '123') - assert m['not_f'] == expected_extra_value + new_m, new_model_extra, new_fields_set = v.validate_assignment({**m, **model_extra}, 'not_f', '123') + assert new_m == {'f': 'y'} + assert new_model_extra == {'extra_field': expected_extra_value, 'not_f': expected_extra_value} + assert new_fields_set == {'not_f'} @pytest.mark.parametrize(