Skip to content

Commit

Permalink
better decorator typing (#29)
Browse files Browse the repository at this point in the history
* improve typing and add mypy test to API

* ci

* fix lint

* fix lint
  • Loading branch information
Yiling-J authored Oct 11, 2024
1 parent a5d0419 commit c84a81a
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 62 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
branches:
- main
- "[0-9].[0-9]"
- typing
pull_request:
branches: ["main"]

Expand Down Expand Up @@ -41,6 +42,12 @@ jobs:
run: "poetry install --no-interaction --no-root --all-extras"
- name: "Run Lint"
run: "make lint"
- name: "Mypy api success"
run: "make lint-pass"
- name: "Mypy api failed"
run: |
error_count=$(make lint-failed 2>&1 | grep -c 'error:')
[ "$error_count" -eq 4 ]
- name: "Run Tests"
env:
CI: "TRUE"
Expand Down
9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,14 @@ benchmark:
.PHONY: lint
lint:
poetry run mypy .

.PHONY: lint-pass
lint-pass:
poetry run mypy tests/typing/api_pass.py

.PHONY: lint-failed
lint-failed:
poetry run mypy tests/typing/api_failed.py

trace_bench:
poetry run python -m benchmarks.trace_bench
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ theine-core = "^0.4.3"
pytest = "^7.2.1"
pytest-benchmark = "^4.0.0"
typing-extensions = "^4.4.0"
mypy = "1.11.1"
mypy = "^1.11.1"
django = "^3.2"
pytest-django = "^4.5.2"
pytest-asyncio = "^0.20.3"
Expand All @@ -41,7 +41,7 @@ exclude = [
]

[tool.django-stubs]
django_settings_module = 'theine'
django_settings_module = 'tests.adapters.settings.theine'
strict_settings = false

[build-system]
Expand Down
29 changes: 29 additions & 0 deletions tests/typing/api_failed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any, Dict, List
from theine import Memoize, Cache


@Memoize(Cache("tlfu", 1000), None)
def foo(id: int) -> Dict[str, int]:
return {"id": id}


@foo.key
def _(id: int, name: str) -> str:
return f"id-{id}"


class Bar:

@Memoize(Cache("tlfu", 1000), None)
def foo(self, id: int) -> Dict[str, int]:
return {"id": id}

@foo.key
def _(self, id: int, name: str) -> str:
return f"id-{id}"


def run() -> None:
v: Dict[str, int] = foo(12, 13)
bar = Bar()
b: Dict[str, int] = bar.foo(12, 13)
29 changes: 29 additions & 0 deletions tests/typing/api_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any, Dict, List
from theine import Memoize, Cache


@Memoize(Cache("tlfu", 1000), None)
def foo(id: int) -> Dict[str, int]:
return {"id": id}


@foo.key
def _(id: int) -> str:
return f"id-{id}"


class Bar:

@Memoize(Cache("tlfu", 1000), None)
def foo(self, id: int) -> Dict[str, int]:
return {"id": id}

@foo.key
def _(self, id: int) -> str:
return f"id-{id}"


def run() -> None:
v: Dict[str, int] = foo(12)
bar = Bar()
b: Dict[str, int] = bar.foo(13)
129 changes: 69 additions & 60 deletions theine/theine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,31 @@
from threading import Event, Thread
from typing import (
Any,
Awaitable, Callable,
Awaitable,
Callable,
Dict,
Hashable,
List,
Optional,
TYPE_CHECKING, Tuple,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
cast, )
cast,
overload,
no_type_check,
)

from mypy_extensions import KwArg, VarArg
from theine_core import ClockProCore, LruCore, TlfuCore
from typing_extensions import ParamSpec, Protocol
from typing_extensions import ParamSpec, Protocol, Concatenate

from theine.exceptions import InvalidTTL
from theine.models import CacheStats

S = TypeVar("S", contravariant=True)
P = ParamSpec("P")
R = TypeVar("R", covariant=True, bound=Any)
R_A = TypeVar("R_A", covariant=True, bound=Union[Awaitable[Any], Callable[..., Any]])
if TYPE_CHECKING:
from functools import _Wrapped

Expand Down Expand Up @@ -59,51 +63,49 @@ def len(self) -> int:


class Core(Protocol):
def __init__(self, size: int):
...
def __init__(self, size: int): ...

def set(self, key: str, ttl: int) -> Tuple[int, Optional[int], Optional[str]]:
...
def set(self, key: str, ttl: int) -> Tuple[int, Optional[int], Optional[str]]: ...

def remove(self, key: str) -> Optional[int]:
...
def remove(self, key: str) -> Optional[int]: ...

def access(self, key: str) -> Optional[int]:
...
def access(self, key: str) -> Optional[int]: ...

def advance(self, cache: List[Any], sentinel: Any, kh: Dict[int, Hashable], hk: Dict[Hashable, int]) -> None:
...
def advance(
self,
cache: List[Any],
sentinel: Any,
kh: Dict[int, Hashable],
hk: Dict[Hashable, int],
) -> None: ...

def clear(self) -> None:
...
def clear(self) -> None: ...

def len(self) -> int:
...
def len(self) -> int: ...


class ClockProCoreP(Protocol):
def __init__(self, size: int):
...
def __init__(self, size: int): ...

def set(
self, key: str, ttl: int
) -> Tuple[int, Optional[int], Optional[int], Optional[str]]:
...
) -> Tuple[int, Optional[int], Optional[int], Optional[str]]: ...

def remove(self, key: str) -> Optional[int]:
...
def remove(self, key: str) -> Optional[int]: ...

def access(self, key: str) -> Optional[int]:
...
def access(self, key: str) -> Optional[int]: ...

def advance(self, cache: List[Any], sentinel: Any, kh: Dict[int, Hashable], hk: Dict[Hashable, int]) -> None:
...
def advance(
self,
cache: List[Any],
sentinel: Any,
kh: Dict[int, Hashable],
hk: Dict[Hashable, int],
) -> None: ...

def clear(self) -> None:
...
def clear(self) -> None: ...

def len(self) -> int:
...
def len(self) -> int: ...


CORES: Dict[str, Union[Type[Core], Type[ClockProCoreP]]] = {
Expand Down Expand Up @@ -149,54 +151,61 @@ def __init__(self) -> None:
self.event = Event()


class Cached(Protocol[P, R_A]):
class Cached(Protocol[S, P, R]):
_cache: "Cache"

def key(self, fn: Callable[P, Hashable]) -> None:
...
@overload
def key(self, fn: Callable[P, Hashable]) -> None: ...

@overload
def key(self, fn: Callable[Concatenate[S, P], Hashable]) -> None: ...

def __call__(self, *args: Any, **kwargs: Any) -> R_A:
...
@overload
def __call__(self, _arg_first: S, *args: P.args, **kwargs: P.kwargs) -> R: ...

@overload
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...


@no_type_check
def Wrapper(
fn: Callable[P, R_A],
fn: Callable,
timeout: Optional[timedelta],
cache: "Cache",
typed: bool,
lock: bool,
) -> Cached[P, R_A]:
_key_func: Optional[Callable[..., Hashable]] = None
_events: Dict[Hashable, EventData] = {}
_func: Callable[P, R_A] = fn
_cache: "Cache" = cache
_timeout: Optional[timedelta] = timeout
_typed: bool = typed
_auto_key: bool = True
):
_key_func = None
_events = {}
_func = fn
_cache = cache
_timeout = timeout
_typed = typed
_auto_key = True
_lock = lock

def key(fn: Callable[P, Hashable]) -> None:
def key(fn) -> None:
nonlocal _key_func
nonlocal _auto_key
_key_func = fn
_auto_key = False

def fetch(*args: P.args, **kwargs: P.kwargs) -> R_A:
def fetch(*args, **kwargs):
if _auto_key:
key = _make_key(args, kwargs, _typed)
else:
key = _key_func(*args, **kwargs) # type: ignore
key = _key_func(*args, **kwargs)

if inspect.iscoroutinefunction(fn):
result = _cache.get(key, sentinel)
if result is sentinel:
result = CachedAwaitable(cast(Awaitable[Any], _func(*args, **kwargs)))
result = CachedAwaitable(_func(*args, **kwargs))
_cache.set(key, result, _timeout)
return cast(R_A, result)
return result

data = _cache.get(key, sentinel)
if data is not sentinel:
return cast(R_A, data)
return data
if _lock:
event = EventData(Event(), None)
ve = _events.setdefault(key, event)
Expand All @@ -212,11 +221,11 @@ def fetch(*args: P.args, **kwargs: P.kwargs) -> R_A:
else:
result = _func(*args, **kwargs)
_cache.set(key, result, _timeout)
return cast(R_A, result)
return result

fetch._cache = _cache # type: ignore
fetch.key = key # type: ignore
return fetch # type: ignore
fetch._cache = _cache
fetch.key = key
return fetch


class Memoize:
Expand Down Expand Up @@ -247,9 +256,9 @@ def __init__(
self.typed = typed
self.lock = lock

def __call__(self, fn: Callable[P, R_A]) -> '_Wrapped[P, R_A, [VarArg(Any), KwArg(Any)], R_A]':
def __call__(self, fn: Callable[Concatenate[S, P], R]) -> Cached[S, P, R]:
wrapper = Wrapper(fn, self.timeout, self.cache, self.typed, self.lock)
return update_wrapper(wrapper, fn)
return cast(Cached[S, P, R], update_wrapper(wrapper, fn))


class Cache:
Expand Down

0 comments on commit c84a81a

Please sign in to comment.