Skip to content

Commit

Permalink
feat: polish dep registry interface
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jul 18, 2024
1 parent b9dc92b commit 053e493
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
72 changes: 64 additions & 8 deletions src/cdf/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ class DependencyCycleError(Exception):


class DependencyRegistry:
lc = Lifecycle

def __init__(self) -> None:
self._dependencies = {}
"""Initialize the registry."""
self._singletons = {}
self._wired = set()
self._dependencies = {}
self._resolving = set()

def add(
self,
name: str,
dependency: t.Any,
lifecycle=Lifecycle.INSTANCE,
recursive: bool = True,
**lazy_kwargs: t.Any,
) -> None:
"""Register a dependency with the container."""
Expand All @@ -44,15 +45,31 @@ def add(
"Cannot pass kwargs for instance dependencies. "
"Please use prototype or singleton."
)
if recursive:
dependency = self.inject_defaults(dependency)
if name in self._dependencies:
raise ValueError(f'Dependency "{name}" is already registered')
self._dependencies[name] = (dependency, lifecycle, lazy_kwargs)

def remove(self, name: str) -> None:
"""Remove a dependency from the container."""
if name in self._dependencies:
del self._dependencies[name]
if name in self._singletons:
del self._singletons[name]

def clear(self) -> None:
"""Clear all dependencies and singletons."""
self._dependencies.clear()
self._singletons.clear()

def has(self, name: str) -> bool:
"""Check if a dependency is registered."""
return name in self._dependencies

def get(self, name: str, must_exist: bool = False) -> t.Any:
"""Get a dependency"""
if name not in self._dependencies:
if must_exist:
raise ValueError(f'Dependency "{name}" is not registered')
raise KeyError(f'Dependency "{name}" is not registered')
return None

if name in self._resolving:
Expand Down Expand Up @@ -84,10 +101,30 @@ def get(self, name: str, must_exist: bool = False) -> t.Any:
else:
raise ValueError(f"Unknown lifecycle: {lifecycle}")

def __contains__(self, name: str) -> bool:
"""Check if a dependency is registered."""
return self.has(name)

def __getitem__(self, name: str) -> t.Any:
"""Get a dependency. Raises KeyError if not found."""
return self.get(name, must_exist=True)

def __setitem__(self, name: str, dependency: t.Any) -> None:
"""Add a dependency. Defaults to singleton lifecycle if callable, else instance."""
self.add(
name,
dependency,
Lifecycle.SINGLETON if callable(dependency) else Lifecycle.INSTANCE,
)

def __delitem__(self, name: str) -> None:
"""Remove a dependency."""
self.remove(name)

def inject_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
"""Inject dependencies into a function."""
_instance = unwrap(func_or_cls)
if id(_instance) in self._wired or not callable(func_or_cls):
if not callable(func_or_cls):
return func_or_cls

sig = signature(func_or_cls)
Expand All @@ -102,7 +139,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound_args.arguments[name] = dependency
return func_or_cls(*bound_args.args, **bound_args.kwargs)

self._wired.add(id(_instance))
return wrapper

def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
Expand All @@ -123,6 +159,26 @@ def recursive_inject(func: t.Callable[P, T]) -> t.Callable[P, T]:

return recursive_inject(func_or_cls)

def __call__(
self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
) -> T:
"""Invoke a callable with dependencies injected from the registry."""
return self.wire(func_or_cls)(*args, **kwargs)

def __iter__(self) -> t.Iterator[str]:
"""Iterate over dependency names."""
return iter(self._dependencies)

def __len__(self) -> int:
"""Return the number of dependencies."""
return len(self._dependencies)

def __repr__(self) -> str:
return f"<DependencyRegistry {self._dependencies.keys()}>"

def __str__(self) -> str:
return repr(self)


GLOBAL_REGISTRY = DependencyRegistry()

Expand Down
Empty file added src/cdf/v2/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions src/cdf/v2/workspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from cdf.injector.registry import DependencyRegistry

r = DependencyRegistry()
r.add("a", lambda: 1, r.lc.SINGLETON)
r.add("b", lambda a: a + 1, r.lc.SINGLETON)
r.add("obj_proto", object, r.lc.PROTOTYPE)
r.add("obj_singleton", object, r.lc.SINGLETON)


def foo(a: int, b: int, c: int = 0) -> int:
return a + b


foo_wired = r.wire(foo)

assert foo_wired() == 3
assert foo_wired(1) == 3
assert foo_wired(2) == 4
assert foo_wired(3, 3) == 6

assert r.get("obj_proto") is not r.get("obj_proto")
assert r.get("obj_singleton") is r.get("obj_singleton")

assert r(foo) == 3

0 comments on commit 053e493

Please sign in to comment.