Skip to content

Commit

Permalink
adds NotResolved type annotations that excludes type from resolving i…
Browse files Browse the repository at this point in the history
…n configspec
  • Loading branch information
rudolfix committed May 25, 2024
1 parent 30ba63e commit c004a1f
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 21 deletions.
9 changes: 8 additions & 1 deletion dlt/common/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type
from .specs.base_configuration import (
configspec,
is_valid_hint,
is_secret_hint,
resolve_type,
NotResolved,
)
from .specs import known_sections
from .resolve import resolve_configuration, inject_section
from .inject import with_config, last_config, get_fun_spec, create_resolved_partial
Expand All @@ -15,6 +21,7 @@
"configspec",
"is_valid_hint",
"is_secret_hint",
"NotResolved",
"resolve_type",
"known_sections",
"resolve_configuration",
Expand Down
6 changes: 3 additions & 3 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
StrAny,
TSecretValue,
get_all_types_of_class_in_union,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -21,6 +20,7 @@
is_context_inner_hint,
is_base_configuration_inner_hint,
is_valid_hint,
is_hint_not_resolved,
)
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.specs.exceptions import NativeValueError
Expand Down Expand Up @@ -194,7 +194,7 @@ def _resolve_config_fields(
if explicit_values:
explicit_value = explicit_values.get(key)
else:
if is_final_type(hint):
if is_hint_not_resolved(hint):
# for fields that are not resolved (Final or annotated with NotResolved) default value is like explicit
explicit_value = default_value
else:
Expand Down Expand Up @@ -258,7 +258,7 @@ def _resolve_config_fields(
unresolved_fields[key] = traces
# set resolved value in config
if default_value != current_value:
if not is_final_type(hint):
if not is_hint_not_resolved(hint):
# ignore hints that are not resolved (Final or annotated with NotResolved)
setattr(config, key, current_value)

Expand Down
38 changes: 37 additions & 1 deletion dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ClassVar,
TypeVar,
)
from typing_extensions import get_args, get_origin, dataclass_transform
from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias
from functools import wraps

if TYPE_CHECKING:
Expand All @@ -29,8 +29,11 @@
TDtcField = dataclasses.Field

from dlt.common.typing import (
AnyType,
TAnyClass,
extract_inner_type,
is_annotated,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -48,6 +51,34 @@
_C = TypeVar("_C", bound="CredentialsConfiguration")


class NotResolved:
    """Used in type annotations to indicate types that should not be resolved.

    An instance is placed in the metadata of an ``Annotated`` hint, e.g.
    ``Annotated[str, NotResolved()]``. The truth value of the instance tells
    whether resolution is switched off for the annotated field.

    Args:
        not_resolved: When truthy (the default) the annotated type is excluded
            from resolving; when falsy the marker is present but inactive.
    """

    def __init__(self, not_resolved: bool = True) -> None:
        self.not_resolved = not_resolved

    def __bool__(self) -> bool:
        # `__bool__` must return a real `bool`; coercing here keeps `bool(marker)`
        # working even when the flag was passed as another truthy/falsy value.
        return bool(self.not_resolved)

    def __repr__(self) -> str:
        # makes the marker readable when it shows up in the repr of an Annotated hint
        return f"{type(self).__name__}({self.not_resolved!r})"


def is_hint_not_resolved(hint: AnyType) -> bool:
    """Tells whether `hint` marks a field that must NOT be resolved.

    A hint is not resolved when it is a Final type, or when it is an annotated
    type whose metadata carries a `NotResolved` marker, like
    >>> Annotated[str, NotResolved()]
    For annotated hints the truth value of the first `NotResolved` marker found
    in the metadata decides the result; any other hint is resolved normally.
    """
    if is_final_type(hint):
        return True

    if not is_annotated(hint):
        return False

    # first argument is the annotated type itself, the rest is the metadata
    _, *metadata = get_args(hint)
    return next(
        (bool(marker) for marker in metadata if isinstance(marker, NotResolved)),
        False,
    )


def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool:
    """Returns True when `inner_hint` is a class deriving from `BaseConfiguration`."""
    if not inspect.isclass(inner_hint):
        # `issubclass` accepts only classes, so anything else is rejected up front
        return False
    return issubclass(inner_hint, BaseConfiguration)

Expand All @@ -70,6 +101,11 @@ def is_valid_hint(hint: Type[Any]) -> bool:
if get_origin(hint) is ClassVar:
# class vars are skipped by dataclass
return True

if is_hint_not_resolved(hint):
# all hints that are not resolved are valid
return True

hint = extract_inner_type(hint)
hint = get_config_if_union_hint(hint) or hint
hint = get_origin(hint) or hint
Expand Down
16 changes: 8 additions & 8 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
Any,
TypeVar,
Generic,
Final,
)
from typing_extensions import Annotated
import datetime # noqa: 251
from copy import deepcopy
import inspect
Expand All @@ -35,7 +35,7 @@
has_column_with_prop,
get_first_column_name_with_prop,
)
from dlt.common.configuration import configspec, resolve_configuration, known_sections
from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved
from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
from dlt.common.configuration.accessors import config
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
Expand Down Expand Up @@ -78,7 +78,7 @@ class StateInfo(NamedTuple):

@configspec
class DestinationClientConfiguration(BaseConfiguration):
destination_type: Final[str] = dataclasses.field(
destination_type: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # which destination to load data to
credentials: Optional[CredentialsConfiguration] = None
Expand All @@ -103,11 +103,11 @@ def on_resolved(self) -> None:
class DestinationClientDwhConfiguration(DestinationClientConfiguration):
"""Configuration of a destination that supports datasets/schemas"""

dataset_name: Final[str] = dataclasses.field(
dataset_name: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # dataset must be final so it is not configurable
) # dataset cannot be resolved
"""dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix"""
default_schema_name: Final[Optional[str]] = dataclasses.field(
default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
"""name of default schema to be used to name effective dataset to load data to"""
Expand All @@ -121,8 +121,8 @@ def _bind_dataset_name(
This method is intended to be used internally.
"""
self.dataset_name = dataset_name # type: ignore[misc]
self.default_schema_name = default_schema_name # type: ignore[misc]
self.dataset_name = dataset_name
self.default_schema_name = default_schema_name
return self

def normalize_dataset_name(self, schema: Schema) -> str:
Expand Down
7 changes: 5 additions & 2 deletions dlt/destinations/impl/qdrant/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Optional, Final
from typing_extensions import Annotated

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import (
BaseConfiguration,
CredentialsConfiguration,
Expand Down Expand Up @@ -55,7 +56,9 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration):
dataset_separator: str = "_"

# make it optional so empty dataset is allowed
dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

# Batch size for generating embeddings
embedding_batch_size: int = 32
Expand Down
5 changes: 3 additions & 2 deletions dlt/destinations/impl/weaviate/configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Dict, Literal, Optional, Final
from typing_extensions import Annotated
from urllib.parse import urlparse

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration
from dlt.common.destination.reference import DestinationClientDwhConfiguration
from dlt.common.utils import digest128
Expand All @@ -26,7 +27,7 @@ def __str__(self) -> str:
class WeaviateClientConfiguration(DestinationClientDwhConfiguration):
destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore
# make it optional so empty dataset is allowed
dataset_name: Optional[str] = None # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = None

batch_size: int = 100
batch_workers: int = 1
Expand Down
55 changes: 54 additions & 1 deletion tests/common/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
Optional,
Type,
Union,
TYPE_CHECKING,
)
from typing_extensions import Annotated

from dlt.common import json, pendulum, Decimal, Wei
from dlt.common.configuration.providers.provider import ConfigProvider
from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved
from dlt.common.configuration.specs.gcp_credentials import (
GcpServiceAccountCredentialsWithoutDefaults,
)
Expand Down Expand Up @@ -917,6 +918,58 @@ def test_is_valid_hint() -> None:
assert is_valid_hint(Wei) is True
# any class type, except deriving from BaseConfiguration is wrong type
assert is_valid_hint(ConfigFieldMissingException) is False
# but Final and NotResolved-annotated types are valid because they are not resolved
assert is_valid_hint(Final[ConfigFieldMissingException]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, NotResolved()]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, "REQ"]) is False # type: ignore[arg-type]


def test_is_not_resolved_hint() -> None:
    """Checks which hints `is_hint_not_resolved` reports as excluded from resolving."""
    # (hint, expected result) pairs: Final and active NotResolved markers are not
    # resolved, inactive markers, other metadata and plain types are resolved
    expectations = [
        (Final[ConfigFieldMissingException], True),
        (Annotated[ConfigFieldMissingException, NotResolved()], True),
        (Annotated[ConfigFieldMissingException, NotResolved(True)], True),
        (Annotated[ConfigFieldMissingException, NotResolved(False)], False),
        (Annotated[ConfigFieldMissingException, "REQ"], False),
        (str, False),
    ]
    for hint, expected in expectations:
        assert is_hint_not_resolved(hint) is expected, hint


def test_not_resolved_hint() -> None:
    """Checks resolving of configurations whose fields are Final or marked with NotResolved.

    Optional not-resolved fields may stay `None`; required ones must be passed
    explicitly, otherwise `ConfigFieldMissingException` is raised.
    """

    class SentinelClass:
        pass

    @configspec
    class OptionalNotResolveConfiguration(BaseConfiguration):
        trace: Final[Optional[SentinelClass]] = None
        traces: Annotated[Optional[List[SentinelClass]], NotResolved()] = None

    # optional fields: nothing passed, nothing resolved
    resolved_optional = resolve.resolve_configuration(OptionalNotResolveConfiguration())
    assert resolved_optional.trace is None
    assert resolved_optional.traces is None

    first_sentinel = SentinelClass()
    second_sentinel = SentinelClass()

    # explicitly passed values are kept as they are
    resolved_optional = resolve.resolve_configuration(
        OptionalNotResolveConfiguration(first_sentinel, [second_sentinel])
    )
    assert resolved_optional.trace is first_sentinel
    assert resolved_optional.traces[0] is second_sentinel

    @configspec
    class NotResolveConfiguration(BaseConfiguration):
        trace: Final[SentinelClass] = None
        traces: Annotated[List[SentinelClass], NotResolved()] = None

    # required fields: leaving out any of them fails the resolution
    incomplete_arguments = (
        {},
        {"trace": first_sentinel},
        {"traces": [second_sentinel]},
    )
    for init_kwargs in incomplete_arguments:
        with pytest.raises(ConfigFieldMissingException):
            resolve.resolve_configuration(NotResolveConfiguration(**init_kwargs))

    resolved_required = resolve.resolve_configuration(
        NotResolveConfiguration(first_sentinel, [second_sentinel])
    )
    assert resolved_required.trace is first_sentinel
    assert resolved_required.traces[0] is second_sentinel


def test_configspec_auto_base_config_derivation() -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/load/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,8 @@ def yield_client(
destination = Destination.from_reference(destination_type)
# create initial config
dest_config: DestinationClientDwhConfiguration = None
dest_config = destination.spec() # type: ignore[assignment]
dest_config.dataset_name = dataset_name # type: ignore[misc]
dest_config = destination.spec() # type: ignore
dest_config.dataset_name = dataset_name

if default_config_values is not None:
# apply the values to credentials, if dict is provided it will be used as default
Expand All @@ -597,7 +597,7 @@ def yield_client(
staging_config = DestinationClientStagingConfiguration(
bucket_url=AWS_BUCKET,
)._bind_dataset_name(dataset_name=dest_config.dataset_name)
staging_config.destination_type = "filesystem" # type: ignore[misc]
staging_config.destination_type = "filesystem"
staging_config.resolve()
dest_config.staging_config = staging_config # type: ignore[attr-defined]

Expand Down

0 comments on commit c004a1f

Please sign in to comment.