Skip to content

Commit

Permalink
Add runtime type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
erlendvollset committed Sep 10, 2024
1 parent d587246 commit 853611d
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cognite/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class CogniteException(Exception):
pass


class CogniteTypeError(CogniteException): ...


@dataclass
class GraphQLErrorSpec:
message: str
Expand Down
36 changes: 36 additions & 0 deletions cognite/client/utils/_runtime_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys
from inspect import isfunction
from typing import Any, Callable, TypeVar

from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation

from cognite.client.exceptions import CogniteTypeError

T_Callable = TypeVar("T_Callable", bound=Callable)
T_Class = TypeVar("T_Class", bound=type)


class Settings:
enable_runtime_type_checking: bool = False


def runtime_type_checked_method(f: T_Callable) -> T_Callable:
if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking:
return f
beartyped_f = beartype(f)

def f_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return beartyped_f(*args, **kwargs)
except BeartypeCallHintParamViolation as e:
raise CogniteTypeError(e.args[0])

return f_wrapped


def runtime_type_checked(c: T_Class) -> T_Class:
for name in dir(c):
if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)):
setattr(c, name, runtime_type_checked_method(getattr(c, name)))
return c
160 changes: 160 additions & 0 deletions tests/tests_unit/test_utils/test_runtime_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

import re
import sys
from dataclasses import dataclass
from typing import overload

import pytest

from cognite.client.exceptions import CogniteTypeError
from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked

pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")]


Settings.enable_runtime_type_checking = True


class Foo: ...


class TestTypes:
@runtime_type_checked
class Types:
def primitive(self, x: int) -> None: ...

def list(self, x: list[str]) -> None: ...

def custom_class(self, x: Foo) -> None: ...

def test_primitive(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
self.Types().primitive("1")

self.Types().primitive(1)

def test_list(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' "
"violates type hint list[str], as str '1' not instance of list."
),
):
self.Types().list("1")

with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] "
"violates type hint list[str], as list index 0 item int 1 not instance of str."
),
):
self.Types().list([1])

self.Types().list(["ok"])

def test_custom_type(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() "
"parameter x='1' violates type hint "
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance "
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">'
),
):
self.Types().custom_class("1")

self.Types().custom_class(Foo())

@runtime_type_checked
class ClassWithConstructor:
def __init__(self, x: int, y: str) -> None:
self.x = x
self.y = y

def test_constructor_for_class(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
self.ClassWithConstructor("1", "1")

def test_constructor_for_subclass(self) -> None:
class SubDataClassWithConstructor(self.ClassWithConstructor):
pass

with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
SubDataClassWithConstructor("1", "1")

@runtime_type_checked
@dataclass
class DataClassWithConstructor:
x: int
y: int

def test_constructor_for_dataclass(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
self.DataClassWithConstructor("1", "1")

def test_constructor_for_dataclass_subclass(self) -> None:
class SubDataClassWithConstructor(self.DataClassWithConstructor):
pass

with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() "
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int."
),
):
SubDataClassWithConstructor("1", "1")


class TestOverloads:
@runtime_type_checked
class WithOverload:
@overload
def foo(self, x: int, y: int) -> str: ...

@overload
def foo(self, x: str, y: str) -> str: ...

def foo(self, x: int | str, y: int | str) -> str:
return f"{x}{y}"

def test_overloads(self) -> None:
with pytest.raises(
CogniteTypeError,
match=re.escape(
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() "
"parameter y=1.0 violates type hint int | str, as float 1.0 not str or int."
),
):
self.WithOverload().foo(1, 1.0)

# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet
self.WithOverload().foo(1, "1")

0 comments on commit 853611d

Please sign in to comment.