Skip to content

Commit

Permalink
Make validating assignment work properly with allowed extra (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu authored Jul 13, 2023
1 parent f5b804b commit 3f7c010
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 22 deletions.
5 changes: 4 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ class SchemaValidator:
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
) -> dict[str, Any]: ...
) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]:
"""
ModelValidator and ModelFieldsValidator will return a tuple of (fields data, extra data, fields set)
"""
def get_default_value(self, *, strict: bool | None = None, context: Any = None) -> Some | None: ...

_IncEx: TypeAlias = set[int] | set[str] | dict[int, _IncEx] | dict[str, _IncEx] | None
Expand Down
25 changes: 17 additions & 8 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,32 +183,41 @@ impl Validator for ModelValidator {
Ok(model.into_py(py))
};
}
let dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;
let old_dict: &PyDict = model.getattr(intern!(py, DUNDER_DICT))?.downcast()?;

let new_dict = dict.copy()?;
new_dict.set_item(field_name, field_value)?;
let input_dict = old_dict.copy()?;
let old_extra: Option<&PyDict> = model.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?.downcast().ok();
if let Some(old_extra) = old_extra {
input_dict.update(old_extra.as_mapping())?;
}
input_dict.set_item(field_name, field_value)?;

let output = self.validator.validate_assignment(
py,
new_dict,
input_dict,
field_name,
field_value,
extra,
definitions,
recursion_guard,
)?;

let (output, _, updated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?;
let (validated_dict, validated_extra, validated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?;

if let Ok(fields_set) = model.getattr(intern!(py, DUNDER_FIELDS_SET_KEY)) {
let fields_set: &PySet = fields_set.downcast()?;
for field_name in updated_fields_set {
for field_name in validated_fields_set {
fields_set.add(field_name)?;
}
}
let output = output.to_object(py);

force_setattr(py, model, intern!(py, DUNDER_DICT), output)?;
force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
force_setattr(
py,
model,
intern!(py, DUNDER_MODEL_EXTRA_KEY),
validated_extra.to_object(py),
)?;
Ok(model.into_py(py))
}

Expand Down
26 changes: 21 additions & 5 deletions src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ impl Validator for ModelFieldsValidator {
) -> ValResult<'data, PyObject> {
let dict: &PyDict = obj.downcast()?;

let ok = |output: PyObject| {
let get_updated_dict = |output: PyObject| {
dict.set_item(field_name, output)?;
Ok(dict.to_object(py))
Ok(dict)
};

let prepare_result = |result: ValResult<'data, PyObject>| match result {
Ok(output) => ok(output),
Ok(output) => get_updated_dict(output),
Err(ValError::LineErrors(line_errors)) => {
let errors = line_errors
.into_iter()
Expand Down Expand Up @@ -358,7 +358,7 @@ impl Validator for ModelFieldsValidator {
Some(ref validator) => {
prepare_result(validator.validate(py, field_value, &extra, definitions, recursion_guard))
}
None => ok(field_value.to_object(py)),
None => get_updated_dict(field_value.to_object(py)),
},
ExtraBehavior::Forbid | ExtraBehavior::Ignore => {
return Err(ValError::new_with_loc(
Expand All @@ -372,8 +372,24 @@ impl Validator for ModelFieldsValidator {
}
}?;

let new_extra = match &self.extra_behavior {
ExtraBehavior::Allow => {
let non_extra_data = PyDict::new(py);
self.fields.iter().for_each(|f| {
let popped_value = PyAny::get_item(new_data, &f.name).unwrap();
new_data.del_item(&f.name).unwrap();
non_extra_data.set_item(&f.name, popped_value).unwrap();
});
let new_extra = new_data.copy()?;
new_data.clear();
new_data.update(non_extra_data.as_mapping())?;
new_extra.to_object(py)
}
_ => py.None(),
};

let fields_set: &PySet = PySet::new(py, &[field_name.to_string()])?;
Ok((new_data, py.None(), fields_set.to_object(py)).to_object(py))
Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py))
}

fn different_strict_behavior(
Expand Down
4 changes: 4 additions & 0 deletions tests/validators/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,9 @@ class Model:
__slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__'
field_a: str

def __init__(self):
self.__pydantic_extra__ = None # this attribute must be present for validate_assignment

v = SchemaValidator(
core_schema.no_info_after_validator_function(
f,
Expand All @@ -474,6 +477,7 @@ class Model:
assert m.field_a == 'test'
assert m.__pydantic_fields_set__ == {'field_a'}
assert m.__dict__ == {'field_a': 'test', 'more': 'foobar'}
assert m.__pydantic_extra__ is None

m2 = Model()
m2.field_a = 'test'
Expand Down
8 changes: 7 additions & 1 deletion tests/validators/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,9 @@ class MyModel:
field_a: str
field_b: int

def __init__(self):
self.__pydantic_extra__ = None

v = SchemaValidator(
{
'type': 'model',
Expand Down Expand Up @@ -1019,7 +1022,10 @@ def func(x, info):

def test_validate_assignment_no_fields_set():
class MyModel:
__slots__ = ('__dict__',)
__slots__ = ('__dict__', '__pydantic_extra__')

def __init__(self):
self.__pydantic_extra__ = None

v = SchemaValidator(
{
Expand Down
16 changes: 9 additions & 7 deletions tests/validators/test_model_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def test_validate_assignment_allow_extra():
assert v.validate_python({'field_a': 'test'}) == ({'field_a': 'test'}, {}, {'field_a'})

assert v.validate_assignment({'field_a': 'test'}, 'other_field', 456) == (
{'field_a': 'test', 'other_field': 456},
None,
{'field_a': 'test'},
{'other_field': 456},
{'other_field'},
)

Expand All @@ -364,8 +364,8 @@ def test_validate_assignment_allow_extra_validate():
)

assert v.validate_assignment({'field_a': 'test'}, 'other_field', '456') == (
{'field_a': 'test', 'other_field': 456},
None,
{'field_a': 'test'},
{'other_field': 456},
{'other_field'},
)

Expand Down Expand Up @@ -1682,10 +1682,12 @@ def test_extra_behavior_allow(
assert fields_set == {'f', 'extra_field'}

v.validate_assignment(m, 'f', 'y')
assert m['f'] == 'y'
assert m == {'f': 'y'}

v.validate_assignment(m, 'not_f', '123')
assert m['not_f'] == expected_extra_value
new_m, new_model_extra, new_fields_set = v.validate_assignment({**m, **model_extra}, 'not_f', '123')
assert new_m == {'f': 'y'}
assert new_model_extra == {'extra_field': expected_extra_value, 'not_f': expected_extra_value}
assert new_fields_set == {'not_f'}


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3f7c010

Please sign in to comment.