Pass loading context through _blocking_batch_load #25377

Open · wants to merge 1 commit into base: briantu/make-compute-subset-async
30 changes: 25 additions & 5 deletions python_modules/dagster/dagster/_core/loader.py
@@ -58,15 +58,19 @@ def instance(self) -> "DagsterInstance":
     def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]:
         raise NotImplementedError()
 
+    @staticmethod
+    def ephemeral(instance: "DagsterInstance") -> "LoadingContext":
+        return EphemeralLoadingContext(instance)
+
     def get_loaders_for(
         self, ttype: Type["InstanceLoadableBy"]
     ) -> Tuple[DataLoader, BlockingDataLoader]:
         if ttype not in self.loaders:
             if not issubclass(ttype, InstanceLoadableBy):
                 check.failed(f"{ttype} is not Loadable")
 
-            batch_load_fn = partial(ttype._batch_load, instance=self.instance)  # noqa
-            blocking_batch_load_fn = partial(ttype._blocking_batch_load, instance=self.instance)  # noqa
+            batch_load_fn = partial(ttype._batch_load, context=self)  # noqa
+            blocking_batch_load_fn = partial(ttype._blocking_batch_load, context=self)  # noqa
 
             self.loaders[ttype] = (
                 DataLoader(batch_load_fn=batch_load_fn),
@@ -80,6 +84,22 @@ def clear_loaders(self) -> None:
             del self.loaders[ttype]
 
 
+class EphemeralLoadingContext(LoadingContext):
+    """Loading context that can be constructed for short-lived method resolution."""
+
+    def __init__(self, instance: "DagsterInstance"):
+        self._instance = instance
+        self._loaders = {}
+
+    @property
+    def instance(self) -> "DagsterInstance":
+        return self._instance
+
+    @property
+    def loaders(self) -> Dict[Type, Tuple[DataLoader, BlockingDataLoader]]:
+        return self._loaders
+
+
 # Expected there may be other "Loadable" base classes based on what is needed to load.
@@ -88,14 +108,14 @@ class InstanceLoadableBy(ABC, Generic[TKey]):
 
     @classmethod
     async def _batch_load(
-        cls, keys: Iterable[TKey], instance: "DagsterInstance"
+        cls, keys: Iterable[TKey], context: "LoadingContext"
     ) -> Iterable[Optional[Self]]:
-        return cls._blocking_batch_load(keys, instance)
+        return cls._blocking_batch_load(keys, context)
 
     @classmethod
     @abstractmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[TKey], instance: "DagsterInstance"
+        cls, keys: Iterable[TKey], context: "LoadingContext"
     ) -> Iterable[Optional[Self]]:
         # There is no good way of turning an async function into a sync one that
         # will allow us to execute that sync function inside of a broader async context.
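The heart of the change: both batch-load hooks now receive the LoadingContext itself instead of a bare DagsterInstance, and LoadingContext.ephemeral gives short-lived callers a throwaway context. A minimal sketch of a loadable type under the new contract — ReportRecord and its fake backing dict are hypothetical; only the hook signature and the ephemeral constructor come from this diff:

from typing import Dict, Iterable, Optional

from dagster import DagsterInstance
from dagster._core.loader import InstanceLoadableBy, LoadingContext

# Hypothetical backing data, standing in for real instance-backed storage.
REPORTS: Dict[str, str] = {"r1": "ok", "r2": "failed"}


class ReportRecord(InstanceLoadableBy[str]):
    def __init__(self, report_id: str, status: str):
        self.report_id = report_id
        self.status = status

    @classmethod
    def _blocking_batch_load(
        cls, keys: Iterable[str], context: LoadingContext
    ) -> Iterable[Optional["ReportRecord"]]:
        # The instance now hangs off the context, so implementations can kick
        # off further batched loads through the same loader cache if needed.
        _ = context.instance
        return [ReportRecord(k, REPORTS[k]) if k in REPORTS else None for k in keys]


context = LoadingContext.ephemeral(DagsterInstance.ephemeral())
records = ReportRecord._blocking_batch_load(["r1", "missing"], context)  # [<record>, None]

Because the loaders dict lives on the context, nested loads hit the per-context cache rather than re-querying storage key by key.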
python_modules/dagster/dagster/_core/storage/asset_check_execution_record.py

@@ -5,7 +5,6 @@
 from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
 from dagster._core.definitions.asset_key import AssetCheckKey
 from dagster._core.events.log import DagsterEventType, EventLogEntry
-from dagster._core.instance import DagsterInstance
 from dagster._core.loader import InstanceLoadableBy, LoadingContext
 from dagster._core.storage.dagster_run import DagsterRunStatus, RunRecord
 from dagster._serdes.serdes import deserialize_value
@@ -124,9 +123,9 @@ def from_db_row(cls, row, key: AssetCheckKey) -> "AssetCheckExecutionRecord":
 
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
+        cls, keys: Iterable[AssetCheckKey], context: LoadingContext
     ) -> Iterable[Optional["AssetCheckExecutionRecord"]]:
-        records_by_key = instance.event_log_storage.get_latest_asset_check_execution_by_key(
+        records_by_key = context.instance.event_log_storage.get_latest_asset_check_execution_by_key(
             list(keys)
         )
         return [records_by_key.get(key) for key in keys]
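At a call site the same pattern looks like this — a sketch against an ephemeral instance; the asset and check names are placeholders:

from dagster import AssetKey, DagsterInstance
from dagster._core.definitions.asset_key import AssetCheckKey
from dagster._core.loader import LoadingContext
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord

instance = DagsterInstance.ephemeral()
context = LoadingContext.ephemeral(instance)

# One event-log round trip for the whole key list; unseen keys come back None.
keys = [AssetCheckKey(asset_key=AssetKey("my_asset"), name="my_check")]
records = AssetCheckExecutionRecord._blocking_batch_load(keys, context)  # [None]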
7 changes: 3 additions & 4 deletions python_modules/dagster/dagster/_core/storage/dagster_run.py
@@ -20,7 +20,7 @@
 from dagster._annotations import PublicAttr, experimental_param, public
 from dagster._core.definitions.asset_check_spec import AssetCheckKey
 from dagster._core.definitions.events import AssetKey
-from dagster._core.loader import InstanceLoadableBy
+from dagster._core.loader import InstanceLoadableBy, LoadingContext
 from dagster._core.origin import JobPythonOrigin
 from dagster._core.storage.tags import (
     ASSET_EVALUATION_ID_TAG,
@@ -41,7 +41,6 @@
 if TYPE_CHECKING:
     from dagster._core.definitions.schedule_definition import ScheduleDefinition
     from dagster._core.definitions.sensor_definition import SensorDefinition
-    from dagster._core.instance import DagsterInstance
     from dagster._core.remote_representation.external import RemoteSchedule, RemoteSensor
     from dagster._core.remote_representation.origin import RemoteJobOrigin
     from dagster._core.scheduler.instigation import InstigatorState
@@ -643,12 +642,12 @@ def __new__(
 
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[str], instance: "DagsterInstance"
+        cls, keys: Iterable[str], context: LoadingContext
     ) -> Iterable[Optional["RunRecord"]]:
         result_map: Dict[str, Optional[RunRecord]] = {run_id: None for run_id in keys}
 
         # this should be replaced with an async DB call
-        records = instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))
+        records = context.instance.get_run_records(RunsFilter(run_ids=list(result_map.keys())))
 
         for record in records:
             result_map[record.dagster_run.run_id] = record
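The "should be replaced with an async DB call" note still applies: the async path simply awaits the blocking one, forwarding the same context. A sketch of driving RunRecord loads from async code, with a made-up run id:

import asyncio

from dagster import DagsterInstance
from dagster._core.loader import LoadingContext
from dagster._core.storage.dagster_run import RunRecord


async def load_run_records(run_ids, context):
    # InstanceLoadableBy._batch_load delegates to _blocking_batch_load,
    # passing the context straight through rather than unpacking an instance.
    return await RunRecord._batch_load(run_ids, context)


instance = DagsterInstance.ephemeral()
context = LoadingContext.ephemeral(instance)
records = asyncio.run(load_run_records(["not-a-real-run-id"], context))  # [None]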
19 changes: 12 additions & 7 deletions python_modules/dagster/dagster/_core/storage/event_log/base.py
@@ -32,8 +32,8 @@
     build_run_stats_from_events,
     build_run_step_stats_from_events,
 )
-from dagster._core.instance import DagsterInstance, MayHaveInstanceWeakref, T_DagsterInstance
-from dagster._core.loader import InstanceLoadableBy
+from dagster._core.instance import MayHaveInstanceWeakref, T_DagsterInstance
+from dagster._core.loader import InstanceLoadableBy, LoadingContext
 from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecord
 from dagster._core.storage.dagster_run import DagsterRunStatsSnapshot
 from dagster._core.storage.partition_status_cache import get_and_update_asset_status_cache_value
@@ -138,11 +138,11 @@ class AssetRecord(
 
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[AssetKey], instance: DagsterInstance
+        cls, keys: Iterable[AssetKey], context: LoadingContext
     ) -> Iterable[Optional["AssetRecord"]]:
         records_by_key = {
             record.asset_entry.asset_key: record
-            for record in instance.get_asset_records(list(keys))
+            for record in context.instance.get_asset_records(list(keys))
         }
         return [records_by_key.get(key) for key in keys]
@@ -160,9 +160,11 @@ class AssetCheckSummaryRecord(
 ):
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[AssetCheckKey], instance: DagsterInstance
+        cls, keys: Iterable[AssetCheckKey], context: LoadingContext
     ) -> Iterable[Optional["AssetCheckSummaryRecord"]]:
-        records_by_key = instance.event_log_storage.get_asset_check_summary_records(list(keys))
+        records_by_key = context.instance.event_log_storage.get_asset_check_summary_records(
+            list(keys)
+        )
         return [records_by_key[key] for key in keys]
@@ -653,11 +655,14 @@ def default_run_scoped_event_tailer_offset(self) -> int:
     def get_asset_status_cache_values(
         self,
         partitions_defs_by_key: Mapping[AssetKey, Optional[PartitionsDefinition]],
+        context: LoadingContext,
     ) -> Sequence[Optional["AssetStatusCacheValue"]]:
         """Get the cached status information for each asset."""
         values = []
         for asset_key, partitions_def in partitions_defs_by_key.items():
             values.append(
-                get_and_update_asset_status_cache_value(self._instance, asset_key, partitions_def)
+                get_and_update_asset_status_cache_value(
+                    self._instance, asset_key, partitions_def, loading_context=context
+                )
             )
         return values
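Storage callers now pass the context explicitly, as the updated test further down does. A sketch against an ephemeral instance, where every value resolves to None before any materializations:

from dagster import AssetKey, DagsterInstance, StaticPartitionsDefinition
from dagster._core.loader import LoadingContext

instance = DagsterInstance.ephemeral()
partition_defs_by_key = {
    AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]),
}

# The storage API now requires a context so the cache update underneath can
# participate in batched loading.
values = instance.event_log_storage.get_asset_status_cache_values(
    partition_defs_by_key, LoadingContext.ephemeral(instance)
)  # [None]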
python_modules/dagster/dagster/_core/storage/partition_status_cache.py

@@ -148,9 +148,9 @@ def from_db_string(db_string: str) -> Optional["AssetStatusCacheValue"]:
 
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], instance: "DagsterInstance"
+        cls, keys: Iterable[Tuple[AssetKey, PartitionsDefinition]], context: LoadingContext
     ) -> Iterable[Optional["AssetStatusCacheValue"]]:
-        return instance.event_log_storage.get_asset_status_cache_values(dict(keys))
+        return context.instance.event_log_storage.get_asset_status_cache_values(dict(keys), context)
 
     def deserialize_materialized_partition_subsets(
         self, partitions_def: PartitionsDefinition
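Keys for this loadable are (AssetKey, PartitionsDefinition) tuples, and forwarding the context into get_asset_status_cache_values lets the cache refresh reuse the same loader cache. A sketch, assuming AssetStatusCacheValue is importable from the module above:

from dagster import AssetKey, DagsterInstance, StaticPartitionsDefinition
from dagster._core.loader import LoadingContext
from dagster._core.storage.partition_status_cache import AssetStatusCacheValue

instance = DagsterInstance.ephemeral()
context = LoadingContext.ephemeral(instance)

# Tuple keys pair each asset with its partitions definition; the context is
# threaded through so the storage call above sees the same loader cache.
values = AssetStatusCacheValue._blocking_batch_load(
    [(AssetKey("static"), StaticPartitionsDefinition(["a", "b"]))], context
)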
@@ -83,6 +83,7 @@
 from dagster._core.execution.plan.objects import StepFailureData, StepSuccessData
 from dagster._core.execution.stats import StepEventStatus
 from dagster._core.instance import RUNLESS_JOB_NAME, RUNLESS_RUN_ID
+from dagster._core.loader import LoadingContext
 from dagster._core.remote_representation.external_data import PartitionsSnap
 from dagster._core.remote_representation.origin import (
     InProcessCodeLocationOrigin,
@@ -6023,7 +6024,9 @@ def test_get_updated_asset_status_cache_values(
         AssetKey("static"): StaticPartitionsDefinition(["a", "b", "c"]),
     }
 
-    assert storage.get_asset_status_cache_values(partition_defs_by_key) == [
+    assert storage.get_asset_status_cache_values(
+        partition_defs_by_key, LoadingContext.ephemeral(instance)
+    ) == [
         None,
         None,
         None,
@@ -6038,6 +6041,10 @@
     instance.report_runless_asset_event(AssetMaterialization(asset_key="static", partition="a"))
 
     partition_defs = list(partition_defs_by_key.values())
-    for i, value in enumerate(storage.get_asset_status_cache_values(partition_defs_by_key)):
+    for i, value in enumerate(
+        storage.get_asset_status_cache_values(
+            partition_defs_by_key, LoadingContext.ephemeral(instance)
+        ),
+    ):
         assert value is not None
         assert len(value.deserialize_materialized_partition_subsets(partition_defs[i])) == 1
@@ -144,9 +144,9 @@ class LoadableThing(
 ):
     @classmethod
     def _blocking_batch_load(
-        cls, keys: Iterable[str], instance: mock.MagicMock
+        cls, keys: Iterable[str], context: mock.MagicMock
    ) -> List["LoadableThing"]:
-        instance.query(keys)
+        context.query(keys)
         return [LoadableThing(key, random.randint(0, 100000)) for key in keys]
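Because the test leaves the parameter as a mock.MagicMock, batching stays easy to observe. A hypothetical follow-on assertion using the LoadableThing defined above:

from unittest import mock

context = mock.MagicMock()
things = LoadableThing._blocking_batch_load(["a", "b", "c"], context)

assert len(things) == 3
# A single batched query carrying every key, not one query per key.
context.query.assert_called_once_with(["a", "b", "c"])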