diff --git a/flag_engine/identities/traits/models.py b/flag_engine/identities/traits/models.py index 712df77..ee2f249 100644 --- a/flag_engine/identities/traits/models.py +++ b/flag_engine/identities/traits/models.py @@ -1,11 +1,21 @@ -from pydantic import BaseModel +from typing import Any, get_args + +from pydantic import BaseModel, validator from flag_engine.identities.traits.types import TraitValue +TRAIT_VALUE_TYPES = get_args(TraitValue) + class TraitModel(BaseModel): trait_key: str trait_value: TraitValue = ... + @validator("trait_value", pre=True) + def convert_trait_value(cls, value: Any) -> TraitValue: + if isinstance(value, TRAIT_VALUE_TYPES): + return value + return str(value) + class Config: smart_union: bool = True diff --git a/tests/unit/identities/test_identities_models.py b/tests/unit/identities/test_identities_models.py index 0a221e2..11a15dd 100644 --- a/tests/unit/identities/test_identities_models.py +++ b/tests/unit/identities/test_identities_models.py @@ -1,4 +1,5 @@ from typing import Any + import pytest from flag_engine.features.models import FeatureModel, FeatureStateModel @@ -186,6 +187,9 @@ def test_get_hash_key_with_use_identity_composite_key_for_hashing_disabled(ident (False, False), (0.0, 0.0), (0, 0), + (None, None), + ([], "[]"), + (["SUPERADMIN"], "['SUPERADMIN']"), ], ) def test_trait_model__deserialize__expected_trait_value(