diff --git a/cognite/client/exceptions.py b/cognite/client/exceptions.py index d9dfff4d9..69d63196d 100644 --- a/cognite/client/exceptions.py +++ b/cognite/client/exceptions.py @@ -16,6 +16,9 @@ class CogniteException(Exception): pass +class CogniteTypeError(CogniteException): ... + + @dataclass class GraphQLErrorSpec: message: str diff --git a/cognite/client/utils/_runtime_type_checking.py b/cognite/client/utils/_runtime_type_checking.py new file mode 100644 index 000000000..6a3748448 --- /dev/null +++ b/cognite/client/utils/_runtime_type_checking.py @@ -0,0 +1,36 @@ +import sys +from inspect import isfunction +from typing import Any, Callable, TypeVar + +from beartype import beartype +from beartype.roar import BeartypeCallHintParamViolation + +from cognite.client.exceptions import CogniteTypeError + +T_Callable = TypeVar("T_Callable", bound=Callable) +T_Class = TypeVar("T_Class", bound=type) + + +class Settings: + enable_runtime_type_checking: bool = False + + +def runtime_type_checked_method(f: T_Callable) -> T_Callable: + if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking: + return f + beartyped_f = beartype(f) + + def f_wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return beartyped_f(*args, **kwargs) + except BeartypeCallHintParamViolation as e: + raise CogniteTypeError(e.args[0]) + + return f_wrapped + + +def runtime_type_checked(c: T_Class) -> T_Class: + for name in dir(c): + if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)): + setattr(c, name, runtime_type_checked_method(getattr(c, name))) + return c diff --git a/tests/tests_unit/test_utils/test_runtime_type_checking.py b/tests/tests_unit/test_utils/test_runtime_type_checking.py new file mode 100644 index 000000000..2b86ea42d --- /dev/null +++ b/tests/tests_unit/test_utils/test_runtime_type_checking.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import re +import sys +from dataclasses import dataclass +from typing import overload + +import pytest + +from cognite.client.exceptions import CogniteTypeError +from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked + +pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")] + + +Settings.enable_runtime_type_checking = True + + +class Foo: ... + + +class TestTypes: + @runtime_type_checked + class Types: + def primitive(self, x: int) -> None: ... + + def list(self, x: list[str]) -> None: ... + + def custom_class(self, x: Foo) -> None: ... + + def test_primitive(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.Types().primitive("1") + + self.Types().primitive(1) + + def test_list(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " + "violates type hint list[str], as str '1' not instance of list." + ), + ): + self.Types().list("1") + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " + "violates type hint list[str], as list index 0 item int 1 not instance of str." + ), + ): + self.Types().list([1]) + + self.Types().list(["ok"]) + + def test_custom_type(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " + "parameter x='1' violates type hint " + ", as str '1' not instance " + 'of ' + ), + ): + self.Types().custom_class("1") + + self.Types().custom_class(Foo()) + + @runtime_type_checked + class ClassWithConstructor: + def __init__(self, x: int, y: str) -> None: + self.x = x + self.y = y + + def test_constructor_for_class(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.ClassWithConstructor("1", "1") + + def test_constructor_for_subclass(self) -> None: + class SubDataClassWithConstructor(self.ClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + @runtime_type_checked + @dataclass + class DataClassWithConstructor: + x: int + y: int + + def test_constructor_for_dataclass(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.DataClassWithConstructor("1", "1") + + def test_constructor_for_dataclass_subclass(self) -> None: + class SubDataClassWithConstructor(self.DataClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + +class TestOverloads: + @runtime_type_checked + class WithOverload: + @overload + def foo(self, x: int, y: int) -> str: ... + + @overload + def foo(self, x: str, y: str) -> str: ... + + def foo(self, x: int | str, y: int | str) -> str: + return f"{x}{y}" + + def test_overloads(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " + "parameter y=1.0 violates type hint int | str, as float 1.0 not str or int." + ), + ): + self.WithOverload().foo(1, 1.0) + + # Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet + self.WithOverload().foo(1, "1")