From 1d74dd38ac79044f269b213fb6dd750e5e9e596d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 16 Feb 2022 10:38:51 +0000 Subject: [PATCH 1/3] Auto apply dataclass decorator to subclasses --- coqpit/coqpit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 07cd423..105f941 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -578,6 +578,9 @@ class Coqpit(Serializable, MutableMapping): _initialized = False + def __init_subclass__(cls, **kwargs): + return dataclass(_cls=cls, **kwargs) + def _is_initialized(self): """Check if Coqpit is initialized. Useful to prevent running some aux functions at the initialization when no attribute has been defined.""" From 30fc47ae30d5eb263ca2bc50d714e00c5d706287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 16 Feb 2022 10:41:19 +0000 Subject: [PATCH 2/3] Add test --- tests/test_metaing.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 tests/test_metaing.py diff --git a/tests/test_metaing.py b/tests/test_metaing.py new file mode 100644 index 0000000..c8671b5 --- /dev/null +++ b/tests/test_metaing.py @@ -0,0 +1,10 @@ +import dataclasses +from coqpit.coqpit import Coqpit + + +class SimpleConstructConfig(Coqpit): + a: int = 20 + + +def test_copying(): + assert dataclasses.is_dataclass(SimpleConstructConfig()) From f0cf5e5da93b69829e4a79b038a723a51f913355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 16 Feb 2022 10:44:56 +0000 Subject: [PATCH 3/3] Fix linter --- coqpit/coqpit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 105f941..55a53ff 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -408,7 +408,7 @@ def deserialize(self, data: dict) -> "Serializable": init_kwargs[field.name] = value continue if value == MISSING: - raise ValueError("deserialized with unknown value for {} in {}".format(field.name, self.__name__)) + raise ValueError(f"deserialized with unknown value for {field.name} in {self.__name__}") value = _deserialize(value, field.type) init_kwargs[field.name] = value for k, v in init_kwargs.items(): @@ -438,7 +438,7 @@ def deserialize_immutable(cls, data: dict) -> "Serializable": init_kwargs[field.name] = value continue if value == MISSING: - raise ValueError("Deserialized with unknown value for {} in {}".format(field.name, cls.__name__)) + raise ValueError(f"Deserialized with unknown value for {field.name} in {cls.__name__}") value = _deserialize(value, field.type) init_kwargs[field.name] = value return cls(**init_kwargs)