diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 10eeba343e5c..1cc0eb9977b1 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -276,11 +276,16 @@ def to_state_create(self): from prefect.client.schemas.actions import StateCreate from prefect.results import BaseResult + if isinstance(self.data, BaseResult) and self.data.serialize_to_none is False: + data = self.data + else: + data = None + return StateCreate( type=self.type, name=self.name, message=self.message, - data=self.data if isinstance(self.data, BaseResult) else None, + data=data, state_details=self.state_details, ) diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 5e11f25b1c4f..d67d83e00648 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -268,6 +268,7 @@ def handle_success(self, result: R) -> R: return_value_to_state( resolved_result, result_factory=result_factory, + write_result=True, ) ) self.set_state(terminal_state) @@ -287,6 +288,7 @@ def handle_exception( message=msg or "Flow run encountered an exception:", result_factory=result_factory or getattr(context, "result_factory", None), + write_result=True, ) ) state = self.set_state(terminal_state) diff --git a/src/prefect/results.py b/src/prefect/results.py index 233dbc7d7cbb..a45f53780a2d 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -25,7 +25,6 @@ import prefect from prefect.blocks.core import Block from prefect.client.utilities import inject_client -from prefect.exceptions import MissingResult from prefect.filesystems import ( LocalFileSystem, WritableFileSystem, @@ -332,22 +331,15 @@ async def create_result( obj: R, key: Optional[str] = None, expiration: Optional[DateTime] = None, - defer_persistence: bool = False, ) -> Union[R, "BaseResult[R]"]: """ Create a result type for the given object. - If persistence is disabled, the object is wrapped in an `UnpersistedResult` and - returned. - If persistence is enabled the object is serialized, persisted to storage, and a reference is returned. """ # Null objects are "cached" in memory at no cost should_cache_object = self.cache_result_in_memory or obj is None - if not self.persist_result: - return await UnpersistedResult.create(obj, cache_object=should_cache_object) - if key: def key_fn(): @@ -365,7 +357,7 @@ def key_fn(): serializer=self.serializer, cache_object=should_cache_object, expiration=expiration, - defer_persistence=defer_persistence, + serialize_to_none=not self.persist_result, ) @sync_compatible @@ -432,34 +424,6 @@ def __dispatch_key__(cls, **kwargs): return cls.__name__ if isinstance(default, PydanticUndefinedType) else default -class UnpersistedResult(BaseResult): - """ - Result type for results that are not persisted outside of local memory. - """ - - type: str = "unpersisted" - - @sync_compatible - async def get(self) -> R: - if self.has_cached_object(): - return self._cache - - raise MissingResult("The result was not persisted and is no longer available.") - - @classmethod - @sync_compatible - async def create( - cls: "Type[UnpersistedResult]", - obj: R, - cache_object: bool = True, - ) -> "UnpersistedResult[R]": - result = cls() - # Only store the object in local memory, it will not be sent to the API - if cache_object: - result._cache_object(obj) - return result - - class PersistedResult(BaseResult): """ Result type which stores a reference to a persisted result. @@ -476,12 +440,19 @@ class PersistedResult(BaseResult): storage_key: str storage_block_id: Optional[uuid.UUID] = None expiration: Optional[DateTime] = None + serialize_to_none: bool = False - _should_cache_object: bool = PrivateAttr(default=True) _persisted: bool = PrivateAttr(default=False) + _should_cache_object: bool = PrivateAttr(default=True) _storage_block: WritableFileSystem = PrivateAttr(default=None) _serializer: Serializer = PrivateAttr(default=None) + def model_dump(self, *args, **kwargs): + if self.serialize_to_none: + return None + else: + return super().model_dump(*args, **kwargs) + def _cache_object( self, obj: Any, @@ -547,7 +518,7 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None: Write the result to the storage block. """ - if self._persisted: + if self._persisted or self.serialize_to_none: # don't double write or overwrite return @@ -627,7 +598,7 @@ async def create( storage_block_id: Optional[uuid.UUID] = None, cache_object: bool = True, expiration: Optional[DateTime] = None, - defer_persistence: bool = False, + serialize_to_none: bool = False, ) -> "PersistedResult[R]": """ Create a new result reference from a user's object. @@ -651,24 +622,13 @@ async def create( storage_block_id=storage_block_id, storage_key=key, expiration=expiration, + serialize_to_none=serialize_to_none, ) - if cache_object and not defer_persistence: - # Attach the object to the result so it's available without deserialization - result._cache_object( - obj, storage_block=storage_block, serializer=serializer - ) - object.__setattr__(result, "_should_cache_object", cache_object) - - if not defer_persistence: - await result.write(obj=obj) - else: - # we must cache temporarily to allow for writing later - # the cache will be removed on write - result._cache_object( - obj, storage_block=storage_block, serializer=serializer - ) + # we must cache temporarily to allow for writing later + # the cache will be removed on write + result._cache_object(obj, storage_block=storage_block, serializer=serializer) return result @@ -701,40 +661,3 @@ def load(self) -> Any: def to_bytes(self) -> bytes: return self.model_dump_json(serialize_as_any=True).encode() - - -class UnknownResult(BaseResult): - """ - Result type for unknown results. Typically used to represent the result - of tasks that were forced from a failure state into a completed state. - - The value for this result is always None and is not persisted to external - result storage, but orchestration treats the result the same as persisted - results when determining orchestration rules, such as whether to rerun a - completed task. - """ - - type: str = "unknown" - value: None - - def has_cached_object(self) -> bool: - # This result type always has the object cached in memory - return True - - @sync_compatible - async def get(self) -> R: - return self.value - - @classmethod - @sync_compatible - async def create( - cls: "Type[UnknownResult]", - obj: R = None, - ) -> "UnknownResult[R]": - if obj is not None: - raise TypeError( - f"Unsupported type {type(obj).__name__!r} for unknown result. " - "Only None is supported." - ) - - return cls(value=obj) diff --git a/src/prefect/server/orchestration/core_policy.py b/src/prefect/server/orchestration/core_policy.py index 886df783d692..4993a7d64a5b 100644 --- a/src/prefect/server/orchestration/core_policy.py +++ b/src/prefect/server/orchestration/core_policy.py @@ -13,7 +13,6 @@ from packaging.version import Version from sqlalchemy import select -from prefect.results import UnknownResult from prefect.server import models from prefect.server.database.dependencies import inject_db from prefect.server.database.interface import PrefectDBInterface @@ -130,7 +129,6 @@ def priority(): class MinimalFlowPolicy(BaseOrchestrationPolicy): def priority(): return [ - AddUnknownResult, # mark forced completions with an unknown result BypassCancellingFlowRunsWithNoInfra, # cancel scheduled or suspended runs from the UI InstrumentFlowRunStateTransitions, ] @@ -148,7 +146,6 @@ class MinimalTaskPolicy(BaseOrchestrationPolicy): def priority(): return [ ReleaseTaskConcurrencySlots, # always release concurrency slots - AddUnknownResult, # mark forced completions with a result placeholder ] @@ -260,45 +257,6 @@ async def after_transition( cl.active_slots = list(active_slots) -class AddUnknownResult(BaseOrchestrationRule): - """ - Assign an "unknown" result to runs that are forced to complete from a - failed or crashed state, if the previous state used a persisted result. - - When we retry a flow run, we retry any task runs that were in a failed or - crashed state, but we also retry completed task runs that didn't use a - persisted result. This means that without a sentinel value for unknown - results, a task run forced into Completed state will always get rerun if the - flow run retries because the task run lacks a persisted result. The - "unknown" sentinel ensures that when we see a completed task run with an - unknown result, we know that it was forced to complete and we shouldn't - rerun it. - - Flow runs forced into a Completed state have a similar problem: without a - sentinel value, attempting to refer to the flow run's result will raise an - exception because the flow run has no result. The sentinel ensures that we - can distinguish between a flow run that has no result and a flow run that - has an unknown result. - """ - - FROM_STATES = [StateType.FAILED, StateType.CRASHED] - TO_STATES = [StateType.COMPLETED] - - async def before_transition( - self, - initial_state: Optional[states.State], - proposed_state: Optional[states.State], - context: TaskOrchestrationContext, - ) -> None: - if ( - initial_state - and initial_state.data - and initial_state.data.get("type") == "reference" - ): - unknown_result = await UnknownResult.create() - self.context.proposed_state.data = unknown_result.model_dump() - - class CacheInsertion(BaseOrchestrationRule): """ Caches completed states with cache keys after they are validated. diff --git a/src/prefect/states.py b/src/prefect/states.py index 59a7b721ab19..c47358f0c01c 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -218,6 +218,7 @@ async def exception_to_crashed_state( async def exception_to_failed_state( exc: Optional[BaseException] = None, result_factory: Optional[ResultFactory] = None, + write_result: bool = False, **kwargs, ) -> State: """ @@ -234,6 +235,8 @@ async def exception_to_failed_state( if result_factory: data = await result_factory.create_result(exc) + if write_result: + await data.write() else: # Attach the exception for local usage, will not be available when retrieved # from the API @@ -258,7 +261,7 @@ async def return_value_to_state( result_factory: ResultFactory, key: Optional[str] = None, expiration: Optional[datetime.datetime] = None, - defer_persistence: bool = False, + write_result: bool = False, ) -> State[R]: """ Given a return value from a user's function, create a `State` the run should @@ -291,13 +294,14 @@ async def return_value_to_state( # Unless the user has already constructed a result explicitly, use the factory # to update the data to the correct type if not isinstance(state.data, BaseResult): - state.data = await result_factory.create_result( + result = await result_factory.create_result( state.data, key=key, expiration=expiration, - defer_persistence=defer_persistence, ) - + if write_result: + await result.write() + state.data = result return state # Determine a new state from the aggregate of contained states @@ -333,15 +337,17 @@ async def return_value_to_state( # TODO: We may actually want to set the data to a `StateGroup` object and just # allow it to be unpacked into a tuple and such so users can interact with # it + result = await result_factory.create_result( + retval, + key=key, + expiration=expiration, + ) + if write_result: + await result.write() return State( type=new_state_type, message=message, - data=await result_factory.create_result( - retval, - key=key, - expiration=expiration, - defer_persistence=defer_persistence, - ), + data=result, ) # Generators aren't portable, implicitly convert them to a list. @@ -354,14 +360,14 @@ async def return_value_to_state( if isinstance(data, BaseResult): return Completed(data=data) else: - return Completed( - data=await result_factory.create_result( - data, - key=key, - expiration=expiration, - defer_persistence=defer_persistence, - ) + result = await result_factory.create_result( + data, + key=key, + expiration=expiration, ) + if write_result: + await result.write() + return Completed(data=result) @sync_compatible diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index a7869378a830..755ecbfcac71 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -464,8 +464,6 @@ def handle_success(self, result: R, transaction: Transaction) -> R: result_factory=result_factory, key=transaction.key, expiration=expiration, - # defer persistence to transaction commit - defer_persistence=True, ) ) transaction.stage( @@ -536,6 +534,7 @@ def handle_exception(self, exc: Exception) -> None: exc, message="Task run encountered an exception", result_factory=getattr(context, "result_factory", None), + write_result=True, ) ) self.record_terminal_state_timing(state) @@ -969,8 +968,6 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: result_factory=result_factory, key=transaction.key, expiration=expiration, - # defer persistence to transaction commit - defer_persistence=True, ) transaction.stage( terminal_state.data, diff --git a/tests/events/client/instrumentation/test_task_run_state_change_events.py b/tests/events/client/instrumentation/test_task_run_state_change_events.py index db4b641b1ed1..b402b8d5d1f7 100644 --- a/tests/events/client/instrumentation/test_task_run_state_change_events.py +++ b/tests/events/client/instrumentation/test_task_run_state_change_events.py @@ -187,7 +187,7 @@ def happy_path(): "name": "Completed", "message": "", "state_details": {}, - "data": {"type": "unpersisted"}, + "data": None, }, "task_run": { "dynamic_key": task_run.dynamic_key, @@ -387,7 +387,7 @@ def happy_path(): "Here's a happy little accident." ), "state_details": {"retriable": False}, - "data": {"type": "unpersisted"}, + "data": None, }, "task_run": { "dynamic_key": task_run.dynamic_key, diff --git a/tests/results/test_flow_results.py b/tests/results/test_flow_results.py index 1f8d276383a9..4a619d15e102 100644 --- a/tests/results/test_flow_results.py +++ b/tests/results/test_flow_results.py @@ -10,8 +10,8 @@ from prefect.exceptions import MissingResult from prefect.filesystems import LocalFileSystem from prefect.results import ( + PersistedResult, PersistedResultBlob, - UnpersistedResult, ) from prefect.serializers import ( CompressedSerializer, @@ -60,22 +60,6 @@ def foo(): await api_state.result() -async def test_flow_with_uncached_and_unpersisted_result(prefect_client): - @flow(persist_result=False, cache_result_in_memory=False) - def foo(): - return 1 - - state = foo(return_state=True) - with pytest.raises(MissingResult): - await state.result() - - api_state = ( - await prefect_client.read_flow_run(state.state_details.flow_run_id) - ).state - with pytest.raises(MissingResult): - await api_state.result() - - async def test_flow_with_uncached_and_unpersisted_null_result(prefect_client): @flow(persist_result=False, cache_result_in_memory=False) def foo(): @@ -308,7 +292,8 @@ def bar(): parent_state = foo(return_state=True) child_state = await parent_state.result() - assert isinstance(child_state.data, UnpersistedResult) + assert isinstance(child_state.data, PersistedResult) + assert child_state.data._persisted is False assert await child_state.result() is None api_state = ( diff --git a/tests/results/test_persisted_result.py b/tests/results/test_persisted_result.py index 5bd3cf85a36c..cb761f357d6e 100644 --- a/tests/results/test_persisted_result.py +++ b/tests/results/test_persisted_result.py @@ -26,6 +26,7 @@ async def test_result_reference_create_and_get(cache_object, storage_block): serializer=JSONSerializer(), cache_object=cache_object, ) + await result.write() if not cache_object: assert not result.has_cached_object() @@ -43,6 +44,7 @@ async def test_result_reference_create_uses_storage(storage_block): storage_key_fn=DEFAULT_STORAGE_KEY_FN, serializer=JSONSerializer(), ) + await result.write() assert result.storage_block_id == storage_block._block_document_id contents = await storage_block.read_path(result.storage_key) @@ -59,6 +61,7 @@ async def test_result_reference_create_uses_serializer(storage_block): storage_key_fn=DEFAULT_STORAGE_KEY_FN, serializer=serializer, ) + await result.write() assert result.serializer_type == serializer.type contents = await storage_block.read_path(result.storage_key) @@ -79,6 +82,7 @@ async def test_result_reference_file_blob_is_json(storage_block): storage_key_fn=DEFAULT_STORAGE_KEY_FN, serializer=serializer, ) + await result.write() contents = await storage_block.read_path(result.storage_key) @@ -100,6 +104,7 @@ async def test_result_reference_create_uses_storage_key_fn(storage_block): storage_key_fn=lambda: "test", serializer=JSONSerializer(), ) + await result.write() assert result.storage_key == "test" contents = await storage_block.read_path("test") @@ -185,7 +190,6 @@ async def test_write_is_idempotent(storage_block): storage_block=storage_block, storage_key_fn=lambda: "test-defer-path", serializer=JSONSerializer(), - defer_persistence=True, ) with pytest.raises(ValueError, match="does not exist"): @@ -200,14 +204,13 @@ async def test_write_is_idempotent(storage_block): assert blob.load() == "test-defer" -async def test_lifecycle_of_defer_persistence(storage_block): +async def test_lifecycle_of_deferred_persistence(storage_block): result = await PersistedResult.create( "test-defer", storage_block_id=storage_block._block_document_id, storage_block=storage_block, storage_key_fn=lambda: "test-defer-path", serializer=JSONSerializer(), - defer_persistence=True, ) assert await result.get() == "test-defer" diff --git a/tests/results/test_state_result.py b/tests/results/test_state_result.py index eb9c44ae6f76..5b95839f8d43 100644 --- a/tests/results/test_state_result.py +++ b/tests/results/test_state_result.py @@ -10,7 +10,7 @@ import prefect.states from prefect.exceptions import UnfinishedRun from prefect.filesystems import LocalFileSystem, WritableFileSystem -from prefect.results import PersistedResult, PersistedResultBlob, UnpersistedResult +from prefect.results import PersistedResult, PersistedResultBlob, ResultFactory from prefect.serializers import JSONSerializer from prefect.states import State, StateType from prefect.utilities.annotations import NotSet @@ -24,8 +24,7 @@ async def test_unfinished_states_raise_on_result_retrieval( state_type: StateType, raise_on_failure: bool ): - # We'll even attach a result to the state, but it shouldn't matter - state = State(type=state_type, data=await UnpersistedResult.create("test")) + state = State(type=state_type) with pytest.raises(UnfinishedRun): # raise_on_failure should have no effect here @@ -36,8 +35,13 @@ async def test_unfinished_states_raise_on_result_retrieval( "state_type", [StateType.CRASHED, StateType.COMPLETED, StateType.FAILED, StateType.CANCELLED], ) -async def test_finished_states_allow_result_retrieval(state_type: StateType): - state = State(type=state_type, data=await UnpersistedResult.create("test")) +async def test_finished_states_allow_result_retrieval( + prefect_client, state_type: StateType +): + factory = await ResultFactory.default_factory( + client=prefect_client, persist_result=True + ) + state = State(type=state_type, data=await factory.create_result("test")) assert await state.result(raise_on_failure=False) == "test" @@ -65,7 +69,6 @@ async def a_real_result(storage_block: WritableFileSystem) -> PersistedResult: storage_block=storage_block, storage_key_fn=lambda: "test-graceful-retry-path", serializer=JSONSerializer(), - defer_persistence=True, ) diff --git a/tests/results/test_task_results.py b/tests/results/test_task_results.py index c2932b96780a..c18ee56ee4b1 100644 --- a/tests/results/test_task_results.py +++ b/tests/results/test_task_results.py @@ -2,7 +2,6 @@ import pytest -from prefect.exceptions import MissingResult from prefect.filesystems import LocalFileSystem from prefect.flows import flow from prefect.serializers import JSONSerializer, PickleSerializer @@ -86,56 +85,6 @@ def bar(): assert await api_state.result() == 1 -async def test_task_with_uncached_and_unpersisted_result( - prefect_client, events_pipeline -): - @flow - def foo(): - return bar(return_state=True) - - @task(persist_result=False, cache_result_in_memory=False) - def bar(): - return 1 - - flow_state = foo(return_state=True) - task_state = await flow_state.result() - with pytest.raises(MissingResult): - await task_state.result() - - await events_pipeline.process_events() - - api_state = ( - await prefect_client.read_task_run(task_state.state_details.task_run_id) - ).state - with pytest.raises(MissingResult): - await api_state.result() - - -async def test_task_with_uncached_and_unpersisted_null_result( - prefect_client, events_pipeline -): - @flow - def foo(): - return bar(return_state=True) - - @task(persist_result=False, cache_result_in_memory=False) - def bar(): - return None - - flow_state = foo(return_state=True) - task_state = await flow_state.result() - # Nulls do not consume memory and are still available - assert await task_state.result() is None - - await events_pipeline.process_events() - - api_state = ( - await prefect_client.read_task_run(task_state.state_details.task_run_id) - ).state - with pytest.raises(MissingResult): - await api_state.result() - - async def test_task_with_uncached_but_persisted_result(prefect_client, events_pipeline): @flow def foo(): diff --git a/tests/results/test_unknown_results.py b/tests/results/test_unknown_results.py deleted file mode 100644 index e5e8c513bfa5..000000000000 --- a/tests/results/test_unknown_results.py +++ /dev/null @@ -1,65 +0,0 @@ -import json - -import pytest - -from prefect import flow -from prefect.results import BaseResult, UnknownResult - -INVALID_VALUES = [True, False, "hey"] - - -@pytest.mark.parametrize("value", INVALID_VALUES) -async def test_unknown_result_invalid_values(value): - with pytest.raises(TypeError, match="Unsupported type"): - await UnknownResult.create(value) - - -def test_unknown_result_create_and_get_sync(): - @flow - def sync(): - result = UnknownResult.create() - return result.get() - - assert sync() is None - - -async def test_unknown_result_create_and_get_async(): - result = await UnknownResult.create() - assert await result.get() is None - - -def test_unknown_result_create_and_get_with_explicit_value(): - @flow - def sync(): - result = UnknownResult.create(obj=None) - return result.get() - - assert sync() is None - - -async def test_result_unknown_json_roundtrip(): - result = await UnknownResult.create() - serialized = result.model_dump_json() - deserialized = UnknownResult.model_validate_json(serialized) - assert await deserialized.get() is None - - -async def test_unknown_result_json_roundtrip_base_result_parser(): - result = await UnknownResult.create() - serialized = result.model_dump_json() - deserialized = BaseResult.model_validate_json(serialized) - assert await deserialized.get() is None - - -async def test_unknown_result_null_is_distinguishable_from_none(): - """ - This is important for separating cases where _no result_ is stored in the database - because the user disabled persistence (for example) from cases where the result - is stored but is a null value. - """ - result = await UnknownResult.create(None) - assert result is not None - serialized = result.model_dump_json() - assert serialized is not None - assert serialized != "null" - assert json.loads(serialized) is not None diff --git a/tests/results/test_unpersisted_result.py b/tests/results/test_unpersisted_result.py deleted file mode 100644 index 52545492cda7..000000000000 --- a/tests/results/test_unpersisted_result.py +++ /dev/null @@ -1,47 +0,0 @@ -from dataclasses import dataclass - -import pytest - -from prefect import flow -from prefect.results import MissingResult, UnpersistedResult - - -@dataclass -class MyDataClass: - x: int - - -TEST_VALUES = [None, "test", MyDataClass(x=1)] - - -@pytest.mark.parametrize("value", TEST_VALUES) -async def test_unpersisted_result_create_and_get(value): - result = await UnpersistedResult.create(value) - assert await result.get() == value - - -@pytest.mark.parametrize("value", TEST_VALUES) -def test_unpersisted_result_create_and_get_sync(value): - @flow - def sync(): - result = UnpersistedResult.create(value) - return result.get() - - output = sync() - assert output == value - - -@pytest.mark.parametrize("value", TEST_VALUES) -async def test_unpersisted_result_create_and_get_no_cache(value): - result = await UnpersistedResult.create(value, cache_object=False) - with pytest.raises(MissingResult): - await result.get() - - -@pytest.mark.parametrize("value", TEST_VALUES) -async def test_unpersisted_result_missing_after_json_roundtrip(value): - result = await UnpersistedResult.create(value) - serialized = result.model_dump_json() - deserialized = UnpersistedResult.model_validate_json(serialized) - with pytest.raises(MissingResult): - await deserialized.get() diff --git a/tests/server/orchestration/test_core_policy.py b/tests/server/orchestration/test_core_policy.py index 2aa3885aa322..114f46f64673 100644 --- a/tests/server/orchestration/test_core_policy.py +++ b/tests/server/orchestration/test_core_policy.py @@ -10,14 +10,11 @@ from prefect.results import ( PersistedResult, - UnknownResult, - UnpersistedResult, ) from prefect.server import schemas from prefect.server.exceptions import ObjectNotFoundError from prefect.server.models import concurrency_limits, flow_runs from prefect.server.orchestration.core_policy import ( - AddUnknownResult, BypassCancellingFlowRunsWithNoInfra, CacheInsertion, CacheRetrieval, @@ -1285,14 +1282,12 @@ async def test_transitions_from_terminal_states_to_cancelling_are_aborted( ], ids=transition_names, ) - @pytest.mark.parametrize("result_type", [None, UnpersistedResult]) async def test_transitions_from_completed_to_non_final_states_allowed_without_persisted_result( self, session, run_type, initialize_orchestration, intended_transition, - result_type, ): if run_type == "flow" and intended_transition[1] == StateType.SCHEDULED: pytest.skip( @@ -1301,12 +1296,7 @@ async def test_transitions_from_completed_to_non_final_states_allowed_without_pe ) ctx = await initialize_orchestration( - session, - run_type, - *intended_transition, - initial_state_data=result_type.model_construct().model_dump() - if result_type - else None, + session, run_type, *intended_transition, initial_state_data=None ) if run_type == "task": @@ -1336,22 +1326,18 @@ async def test_transitions_from_completed_to_non_final_states_allowed_without_pe ], ids=transition_names, ) - @pytest.mark.parametrize("result_type", [PersistedResult, UnknownResult]) async def test_transitions_from_completed_to_non_final_states_rejected_with_persisted_result( self, session, run_type, initialize_orchestration, intended_transition, - result_type, ): ctx = await initialize_orchestration( session, run_type, *intended_transition, - initial_state_data=result_type.model_construct().model_dump() - if result_type - else None, + initial_state_data=PersistedResult.model_construct().model_dump(), ) if run_type == "task": @@ -3203,75 +3189,6 @@ async def test_allows_all_other_transitions( assert ctx.validated_state_type == states.StateType.CANCELLING -@pytest.mark.parametrize("run_type", ["task", "flow"]) -class TestAddUnknownResultRule: - @pytest.mark.parametrize( - "result_type,initial_state_type", - list( - product( - (PersistedResult,), (states.StateType.FAILED, states.StateType.CRASHED) - ) - ), - ) - async def test_saves_unknown_result_if_last_state_used_persisted_results( - self, - session, - initialize_orchestration, - run_type, - result_type, - initial_state_type, - ): - proposed_state_type = states.StateType.COMPLETED - intended_transition = (initial_state_type, proposed_state_type) - ctx = await initialize_orchestration( - session, - run_type, - *intended_transition, - initial_state_data=result_type.model_construct().model_dump() - if result_type - else None, - ) - - async with AddUnknownResult(ctx, *intended_transition) as ctx: - await ctx.validate_proposed_state() - - assert ctx.proposed_state.data.get("type") == "unknown" - - @pytest.mark.parametrize( - "result_type,initial_state_type", - list( - product( - (UnpersistedResult,), - (states.StateType.FAILED, states.StateType.CRASHED), - ) - ), - ) - async def test_does_not_save_unknown_result_if_last_result_did_not_use_persisted_results( - self, - session, - initialize_orchestration, - run_type, - result_type, - initial_state_type, - ): - proposed_state_type = states.StateType.COMPLETED - intended_transition = (initial_state_type, proposed_state_type) - ctx = await initialize_orchestration( - session, - run_type, - *intended_transition, - initial_state_data=result_type.model_construct().model_dump() - if result_type - else None, - ) - - async with AddUnknownResult(ctx, *intended_transition) as ctx: - await ctx.validate_proposed_state() - - if ctx.proposed_state.data: - assert ctx.proposed_state.data.get("type") != "unknown" - - class TestPreventDuplicateTransitions: async def test_no_transition_ids( self, diff --git a/tests/test_futures.py b/tests/test_futures.py index 48ce9fdadfcf..a4a7cdca93a7 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -408,10 +408,7 @@ def my_task(): await events_pipeline.process_events() - with pytest.raises( - MissingResult, - match="The result was not persisted and is no longer available.", - ): + with pytest.raises(MissingResult): future.result() diff --git a/tests/test_states.py b/tests/test_states.py index 7edcd00db52d..e3f0845c1790 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -7,7 +7,6 @@ from prefect.results import ( PersistedResult, ResultFactory, - UnpersistedResult, ) from prefect.states import ( Cancelled, @@ -149,7 +148,8 @@ async def test_returns_single_state_with_null_data_and_persist_off( state = Completed(data=None) result_state = await return_value_to_state(state, factory) assert result_state is state - assert isinstance(result_state.data, UnpersistedResult) + assert isinstance(result_state.data, PersistedResult) + assert result_state.data._persisted is False assert await result_state.result() is None async def test_returns_single_state_with_data_to_persist(self, prefect_client): diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 3b327fbaaa50..3f3d9e7124f7 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -32,7 +32,7 @@ from prefect.exceptions import CrashedRun, MissingResult from prefect.filesystems import LocalFileSystem from prefect.logging import get_run_logger -from prefect.results import PersistedResult, ResultFactory, UnpersistedResult +from prefect.results import PersistedResult, ResultFactory from prefect.server.schemas.core import ConcurrencyLimitV2 from prefect.settings import ( PREFECT_TASK_DEFAULT_RETRIES, @@ -1653,6 +1653,7 @@ async def async_task(): client=prefect_client, persist_result=True ) result = await factory.create_result(42) + await result.write() return result assert await async_task() == 42 @@ -1665,7 +1666,8 @@ async def test_task_loads_result_if_exists_using_result_storage_key( factory = await ResultFactory.default_factory( client=prefect_client, persist_result=True ) - await factory.create_result(-92, key="foo-bar") + result = await factory.create_result(-92, key="foo-bar") + await result.write() @task(result_storage_key="foo-bar", persist_result=True) async def async_task(): @@ -1776,7 +1778,8 @@ async def async_task(): assert state.is_completed() assert await state.result() == 1800 - assert isinstance(state.data, UnpersistedResult) + assert isinstance(state.data, PersistedResult) + assert state.data._persisted is False async def test_none_return_value_does_persist(self, prefect_client, tmp_path): fs = LocalFileSystem(basepath=tmp_path) diff --git a/tests/test_transactions.py b/tests/test_transactions.py index 3f39e251fc6b..92bf1a9ad48c 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -337,6 +337,7 @@ async def test_task(): result = await txn.store.result_factory.create_result( obj={"foo": "bar"}, key=txn.key ) + await result.write() txn.stage(result) result = txn.read() @@ -360,6 +361,7 @@ async def test_task(): result = await txn.store.result_factory.create_result( obj={"foo": "bar"}, key=txn.key ) + await result.write() txn.stage(result) result = txn.read()