Skip to content

Commit

Permalink
adds before_add, after_remove and improves add_extra when adding to c…
Browse files Browse the repository at this point in the history
…ontainer, tracks reference to container in context
  • Loading branch information
rudolfix committed Oct 12, 2024
1 parent a35caa5 commit 0912dbc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
40 changes: 26 additions & 14 deletions dlt/common/configuration/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import threading
from typing import ClassVar, Dict, Iterator, Optional, Tuple, Type, TypeVar, Any

from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext
from dlt.common.configuration.specs.base_configuration import (
ContainerInjectableContext,
TInjectableContext,
)
from dlt.common.configuration.exceptions import (
ContainerInjectableContextMangled,
ContextDefaultCannotBeCreated,
)
from dlt.common.typing import is_subclass

TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext)


class Container:
"""A singleton injection container holding several injection contexts. Implements basic dictionary interface.
Expand Down Expand Up @@ -55,7 +56,7 @@ def __new__(cls: Type["Container"]) -> "Container":
def __init__(self) -> None:
pass

def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration:
def __getitem__(self, spec: Type[TInjectableContext]) -> TInjectableContext:
# return existing config object or create it from spec
if not is_subclass(spec, ContainerInjectableContext):
raise KeyError(f"{spec.__name__} is not a context")
Expand All @@ -65,28 +66,27 @@ def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration:
if spec.can_create_default:
item = spec()
self._thread_setitem(context, spec, item)
item.add_extras()
else:
raise ContextDefaultCannotBeCreated(spec)

return item # type: ignore[return-value]

def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None:
def __setitem__(self, spec: Type[TInjectableContext], value: TInjectableContext) -> None:
# value passed to container must be final
value.resolve()
# put it into context
self._thread_setitem(self._thread_context(spec), spec, value)

def __delitem__(self, spec: Type[TConfiguration]) -> None:
def __delitem__(self, spec: Type[TInjectableContext]) -> None:
context = self._thread_context(spec)
self._thread_delitem(context, spec)

def __contains__(self, spec: Type[TConfiguration]) -> bool:
def __contains__(self, spec: Type[TInjectableContext]) -> bool:
context = self._thread_context(spec)
return spec in context

def _thread_context(
self, spec: Type[TConfiguration]
self, spec: Type[TInjectableContext]
) -> Dict[Type[ContainerInjectableContext], ContainerInjectableContext]:
if spec.global_affinity:
return self.main_context
Expand All @@ -107,7 +107,7 @@ def _thread_context(
return context

def _thread_getitem(
self, spec: Type[TConfiguration]
self, spec: Type[TInjectableContext]
) -> Tuple[
Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
ContainerInjectableContext,
Expand All @@ -120,21 +120,33 @@ def _thread_setitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
value: TConfiguration,
value: TInjectableContext,
) -> None:
old_ctx = context.get(spec)
if old_ctx:
old_ctx.before_remove()
old_ctx.in_container = False
context[spec] = value
value.in_container = True
value.after_add()
if not value.extras_added:
value.add_extras()
value.extras_added = True

def _thread_delitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
) -> None:
old_ctx = context[spec]
old_ctx.before_remove()
del context[spec]
old_ctx.in_container = False

@contextmanager
def injectable_context(
self, config: TConfiguration, lock_context: bool = False
) -> Iterator[TConfiguration]:
self, config: TInjectableContext, lock_context: bool = False
) -> Iterator[TInjectableContext]:
"""A context manager that will insert `config` into the container and restore the previous value when it gets out of scope."""

config.resolve()
Expand Down Expand Up @@ -171,7 +183,7 @@ def injectable_context(
# value was modified in the meantime and not restored
raise ContainerInjectableContextMangled(spec, context[spec], config)

def get(self, spec: Type[TConfiguration]) -> Optional[TConfiguration]:
def get(self, spec: Type[TInjectableContext]) -> Optional[TInjectableContext]:
try:
return self[spec]
except KeyError:
Expand Down
19 changes: 18 additions & 1 deletion dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ def default_credentials(self) -> Any:
return None


TInjectableContext = TypeVar("TInjectableContext", bound="ContainerInjectableContext")


@configspec
class ContainerInjectableContext(BaseConfiguration):
"""Base class for all configurations that may be injected from a Container. Injectable configuration is called a context"""
Expand All @@ -494,11 +497,25 @@ class ContainerInjectableContext(BaseConfiguration):
"""If True, `Container` is allowed to create default context instance, if none exists"""
global_affinity: ClassVar[bool] = False
"""If True, `Container` will create context that will be visible in any thread. If False, per thread context is created"""
in_container: Annotated[bool, NotResolved()] = dataclasses.field(
default=False, init=False, repr=False, compare=False
)
"""Current container, if None then not injected"""
extras_added: Annotated[bool, NotResolved()] = dataclasses.field(
default=False, init=False, repr=False, compare=False
)
"""Tells if extras were already added to this context"""

def add_extras(self) -> None:
"""Called right after context was added to the container. Benefits mostly the config provider injection context which adds extra providers using the initial ones."""
"""Called once after default context was created and added to the container. Benefits mostly the config provider injection context which adds extra providers using the initial ones."""
pass

def after_add(self) -> None:
"""Called each time after context is added to container"""

def before_remove(self) -> None:
"""Called each time before context is removed from container"""


_F_ContainerInjectableContext = ContainerInjectableContext

Expand Down

0 comments on commit 0912dbc

Please sign in to comment.