diff --git a/src/cdf/injector/registry.py b/src/cdf/injector/registry.py index b72b83a..b5ca2e4 100644 --- a/src/cdf/injector/registry.py +++ b/src/cdf/injector/registry.py @@ -24,10 +24,12 @@ class DependencyCycleError(Exception): class DependencyRegistry: + lc = Lifecycle + def __init__(self) -> None: - self._dependencies = {} + """Initialize the registry.""" self._singletons = {} - self._wired = set() + self._dependencies = {} self._resolving = set() def add( @@ -35,7 +37,6 @@ def add( name: str, dependency: t.Any, lifecycle=Lifecycle.INSTANCE, - recursive: bool = True, **lazy_kwargs: t.Any, ) -> None: """Register a dependency with the container.""" @@ -44,15 +45,31 @@ def add( "Cannot pass kwargs for instance dependencies. " "Please use prototype or singleton." ) - if recursive: - dependency = self.inject_defaults(dependency) + if name in self._dependencies: + raise ValueError(f'Dependency "{name}" is already registered') self._dependencies[name] = (dependency, lifecycle, lazy_kwargs) + def remove(self, name: str) -> None: + """Remove a dependency from the container.""" + if name in self._dependencies: + del self._dependencies[name] + if name in self._singletons: + del self._singletons[name] + + def clear(self) -> None: + """Clear all dependencies and singletons.""" + self._dependencies.clear() + self._singletons.clear() + + def has(self, name: str) -> bool: + """Check if a dependency is registered.""" + return name in self._dependencies + def get(self, name: str, must_exist: bool = False) -> t.Any: """Get a dependency""" if name not in self._dependencies: if must_exist: - raise ValueError(f'Dependency "{name}" is not registered') + raise KeyError(f'Dependency "{name}" is not registered') return None if name in self._resolving: @@ -84,10 +101,30 @@ def get(self, name: str, must_exist: bool = False) -> t.Any: else: raise ValueError(f"Unknown lifecycle: {lifecycle}") + def __contains__(self, name: str) -> bool: + """Check if a dependency is registered.""" + return self.has(name) + + def __getitem__(self, name: str) -> t.Any: + """Get a dependency. Raises KeyError if not found.""" + return self.get(name, must_exist=True) + + def __setitem__(self, name: str, dependency: t.Any) -> None: + """Add a dependency. Defaults to singleton lifecycle if callable, else instance.""" + self.add( + name, + dependency, + Lifecycle.SINGLETON if callable(dependency) else Lifecycle.INSTANCE, + ) + + def __delitem__(self, name: str) -> None: + """Remove a dependency.""" + self.remove(name) + def inject_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]: """Inject dependencies into a function.""" _instance = unwrap(func_or_cls) - if id(_instance) in self._wired or not callable(func_or_cls): + if not callable(func_or_cls): return func_or_cls sig = signature(func_or_cls) @@ -102,7 +139,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: bound_args.arguments[name] = dependency return func_or_cls(*bound_args.args, **bound_args.kwargs) - self._wired.add(id(_instance)) return wrapper def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: @@ -123,6 +159,26 @@ def recursive_inject(func: t.Callable[P, T]) -> t.Callable[P, T]: return recursive_inject(func_or_cls) + def __call__( + self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any + ) -> T: + """Invoke a callable with dependencies injected from the registry.""" + return self.wire(func_or_cls)(*args, **kwargs) + + def __iter__(self) -> t.Iterator[str]: + """Iterate over dependency names.""" + return iter(self._dependencies) + + def __len__(self) -> int: + """Return the number of dependencies.""" + return len(self._dependencies) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return repr(self) + GLOBAL_REGISTRY = DependencyRegistry() diff --git a/src/cdf/v2/__init__.py b/src/cdf/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cdf/v2/workspace.py b/src/cdf/v2/workspace.py new file mode 100644 index 0000000..de05058 --- /dev/null +++ b/src/cdf/v2/workspace.py @@ -0,0 +1,24 @@ +from cdf.injector.registry import DependencyRegistry + +r = DependencyRegistry() +r.add("a", lambda: 1, r.lc.SINGLETON) +r.add("b", lambda a: a + 1, r.lc.SINGLETON) +r.add("obj_proto", object, r.lc.PROTOTYPE) +r.add("obj_singleton", object, r.lc.SINGLETON) + + +def foo(a: int, b: int, c: int = 0) -> int: + return a + b + + +foo_wired = r.wire(foo) + +assert foo_wired() == 3 +assert foo_wired(1) == 3 +assert foo_wired(2) == 4 +assert foo_wired(3, 3) == 6 + +assert r.get("obj_proto") is not r.get("obj_proto") +assert r.get("obj_singleton") is r.get("obj_singleton") + +assert r(foo) == 3