Remove UnpersistedResult (#15056)
cicdw authored Aug 23, 2024
1 parent cb2874f commit 444eeb2
Showing 18 changed files with 80 additions and 442 deletions.
7 changes: 6 additions & 1 deletion src/prefect/client/schemas/objects.py
@@ -276,11 +276,16 @@ def to_state_create(self):
         from prefect.client.schemas.actions import StateCreate
         from prefect.results import BaseResult

+        if isinstance(self.data, BaseResult) and self.data.serialize_to_none is False:
+            data = self.data
+        else:
+            data = None
+
         return StateCreate(
             type=self.type,
             name=self.name,
             message=self.message,
-            data=self.data if isinstance(self.data, BaseResult) else None,
+            data=data,
             state_details=self.state_details,
         )

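In effect, `to_state_create` now filters on the new `serialize_to_none` flag: a result that would serialize to `None` is dropped from the `StateCreate` payload instead of being modeled by a separate unpersisted class. A minimal runnable sketch of the new branching, with stand-in types (`FakeResult` and `fake_state_create` are illustrative, not Prefect's Pydantic models):

```python
# Stand-ins for Prefect's models, for illustration only.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeResult:
    value: Any
    serialize_to_none: bool = False


def fake_state_create(data: Optional[FakeResult]) -> dict:
    return {"data": data}


def to_state_create(data: Any) -> dict:
    # Mirror the new branching: only forward results that will
    # actually serialize to something.
    if isinstance(data, FakeResult) and data.serialize_to_none is False:
        kept = data
    else:
        kept = None
    return fake_state_create(kept)


print(to_state_create(FakeResult(42)))                          # result kept
print(to_state_create(FakeResult(42, serialize_to_none=True)))  # data: None
```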
2 changes: 2 additions & 0 deletions src/prefect/flow_engine.py
Expand Up @@ -268,6 +268,7 @@ def handle_success(self, result: R) -> R:
return_value_to_state(
resolved_result,
result_factory=result_factory,
write_result=True,
)
)
self.set_state(terminal_state)
@@ -287,6 +288,7 @@ def handle_exception(
                 message=msg or "Flow run encountered an exception:",
                 result_factory=result_factory
                 or getattr(context, "result_factory", None),
+                write_result=True,
             )
         )
         state = self.set_state(terminal_state)
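Both engine hooks now pass `write_result=True`, so terminal-state results are persisted eagerly at the call site instead of during result creation. A hedged sketch of the exception path, using simplified stand-ins rather than Prefect's actual signatures:

```python
import asyncio
from typing import Optional


class SketchResult:
    """Stand-in for a Prefect result: creation no longer implies a write."""

    def __init__(self, value: object) -> None:
        self.value = value
        self.persisted = False

    async def write(self) -> None:
        self.persisted = True


async def exception_to_failed_state(
    exc: Optional[BaseException] = None, write_result: bool = False
) -> dict:
    # Simplified analogue of prefect.states.exception_to_failed_state:
    # wrap the exception in a result, persisting only when asked to.
    data = SketchResult(exc)
    if write_result:
        await data.write()
    return {"type": "FAILED", "data": data}


async def handle_exception(exc: Exception) -> dict:
    # Mirrors the engine hook above: the engine opts into the write.
    return await exception_to_failed_state(exc, write_result=True)


state = asyncio.run(handle_exception(ValueError("boom")))
print(state["data"].persisted)  # True
```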
107 changes: 15 additions & 92 deletions src/prefect/results.py
Expand Up @@ -25,7 +25,6 @@
import prefect
from prefect.blocks.core import Block
from prefect.client.utilities import inject_client
from prefect.exceptions import MissingResult
from prefect.filesystems import (
LocalFileSystem,
WritableFileSystem,
@@ -332,22 +331,15 @@ async def create_result(
         obj: R,
         key: Optional[str] = None,
         expiration: Optional[DateTime] = None,
-        defer_persistence: bool = False,
     ) -> Union[R, "BaseResult[R]"]:
         """
         Create a result type for the given object.

-        If persistence is disabled, the object is wrapped in an `UnpersistedResult` and
-        returned.
-
         If persistence is enabled the object is serialized, persisted to storage, and a reference is returned.
         """
         # Null objects are "cached" in memory at no cost
         should_cache_object = self.cache_result_in_memory or obj is None

-        if not self.persist_result:
-            return await UnpersistedResult.create(obj, cache_object=should_cache_object)
-
         if key:

             def key_fn():
@@ -365,7 +357,7 @@ def key_fn():
             serializer=self.serializer,
             cache_object=should_cache_object,
             expiration=expiration,
-            defer_persistence=defer_persistence,
+            serialize_to_none=not self.persist_result,
         )

     @sync_compatible
@@ -432,34 +424,6 @@ def __dispatch_key__(cls, **kwargs):
         return cls.__name__ if isinstance(default, PydanticUndefinedType) else default


-class UnpersistedResult(BaseResult):
-    """
-    Result type for results that are not persisted outside of local memory.
-    """
-
-    type: str = "unpersisted"
-
-    @sync_compatible
-    async def get(self) -> R:
-        if self.has_cached_object():
-            return self._cache
-
-        raise MissingResult("The result was not persisted and is no longer available.")
-
-    @classmethod
-    @sync_compatible
-    async def create(
-        cls: "Type[UnpersistedResult]",
-        obj: R,
-        cache_object: bool = True,
-    ) -> "UnpersistedResult[R]":
-        result = cls()
-        # Only store the object in local memory, it will not be sent to the API
-        if cache_object:
-            result._cache_object(obj)
-        return result
-
-
 class PersistedResult(BaseResult):
     """
     Result type which stores a reference to a persisted result.
@@ -476,12 +440,19 @@ class PersistedResult(BaseResult):
     storage_key: str
     storage_block_id: Optional[uuid.UUID] = None
     expiration: Optional[DateTime] = None
+    serialize_to_none: bool = False

-    _should_cache_object: bool = PrivateAttr(default=True)
     _persisted: bool = PrivateAttr(default=False)
+    _should_cache_object: bool = PrivateAttr(default=True)
     _storage_block: WritableFileSystem = PrivateAttr(default=None)
     _serializer: Serializer = PrivateAttr(default=None)

+    def model_dump(self, *args, **kwargs):
+        if self.serialize_to_none:
+            return None
+        else:
+            return super().model_dump(*args, **kwargs)
+
     def _cache_object(
         self,
         obj: Any,
@@ -547,7 +518,7 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
         Write the result to the storage block.
         """

-        if self._persisted:
+        if self._persisted or self.serialize_to_none:
             # don't double write or overwrite
             return

@@ -627,7 +598,7 @@ async def create(
         storage_block_id: Optional[uuid.UUID] = None,
         cache_object: bool = True,
         expiration: Optional[DateTime] = None,
-        defer_persistence: bool = False,
+        serialize_to_none: bool = False,
     ) -> "PersistedResult[R]":
         """
         Create a new result reference from a user's object.
@@ -651,24 +622,13 @@ async def create(
             storage_block_id=storage_block_id,
             storage_key=key,
             expiration=expiration,
+            serialize_to_none=serialize_to_none,
         )

-        if cache_object and not defer_persistence:
-            # Attach the object to the result so it's available without deserialization
-            result._cache_object(
-                obj, storage_block=storage_block, serializer=serializer
-            )
-
         object.__setattr__(result, "_should_cache_object", cache_object)

-        if not defer_persistence:
-            await result.write(obj=obj)
-        else:
-            # we must cache temporarily to allow for writing later
-            # the cache will be removed on write
-            result._cache_object(
-                obj, storage_block=storage_block, serializer=serializer
-            )
+        # we must cache temporarily to allow for writing later
+        # the cache will be removed on write
+        result._cache_object(obj, storage_block=storage_block, serializer=serializer)

        return result

@@ -701,40 +661,3 @@ def load(self) -> Any:

     def to_bytes(self) -> bytes:
         return self.model_dump_json(serialize_as_any=True).encode()
-
-
-class UnknownResult(BaseResult):
-    """
-    Result type for unknown results. Typically used to represent the result
-    of tasks that were forced from a failure state into a completed state.
-
-    The value for this result is always None and is not persisted to external
-    result storage, but orchestration treats the result the same as persisted
-    results when determining orchestration rules, such as whether to rerun a
-    completed task.
-    """
-
-    type: str = "unknown"
-    value: None
-
-    def has_cached_object(self) -> bool:
-        # This result type always has the object cached in memory
-        return True
-
-    @sync_compatible
-    async def get(self) -> R:
-        return self.value
-
-    @classmethod
-    @sync_compatible
-    async def create(
-        cls: "Type[UnknownResult]",
-        obj: R = None,
-    ) -> "UnknownResult[R]":
-        if obj is not None:
-            raise TypeError(
-                f"Unsupported type {type(obj).__name__!r} for unknown result. "
-                "Only None is supported."
-            )
-
-        return cls(value=obj)
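Taken together, these hunks fold `UnpersistedResult` into `PersistedResult` itself: with `serialize_to_none=True` the reference serializes to `None` and `write()` becomes a no-op. A self-contained sketch of that pattern, where `SketchPersistedResult` is a toy stand-in rather than the real Pydantic model:

```python
import asyncio
from typing import Any, Optional


class SketchPersistedResult:
    """Toy model of PersistedResult after this commit."""

    def __init__(self, serialize_to_none: bool = False) -> None:
        self.serialize_to_none = serialize_to_none
        self._persisted = False
        self._cache: Any = None

    def _cache_object(self, obj: Any) -> None:
        # Always cache temporarily so a later write() has data to serialize.
        self._cache = obj

    def model_dump(self) -> Optional[dict]:
        # Results that were never persisted vanish from serialized payloads.
        if self.serialize_to_none:
            return None
        return {"type": "reference", "persisted": self._persisted}

    async def write(self) -> None:
        # Don't double write, and never write a serialize-to-none result.
        if self._persisted or self.serialize_to_none:
            return
        # ... the real code serializes self._cache to the storage block here ...
        self._persisted = True


result = SketchPersistedResult(serialize_to_none=True)
result._cache_object({"answer": 42})
asyncio.run(result.write())
print(result.model_dump())  # None: nothing was sent to the API
```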
42 changes: 0 additions & 42 deletions src/prefect/server/orchestration/core_policy.py
Expand Up @@ -13,7 +13,6 @@
from packaging.version import Version
from sqlalchemy import select

from prefect.results import UnknownResult
from prefect.server import models
from prefect.server.database.dependencies import inject_db
from prefect.server.database.interface import PrefectDBInterface
@@ -130,7 +129,6 @@ def priority():
 class MinimalFlowPolicy(BaseOrchestrationPolicy):
     def priority():
         return [
-            AddUnknownResult,  # mark forced completions with an unknown result
             BypassCancellingFlowRunsWithNoInfra,  # cancel scheduled or suspended runs from the UI
             InstrumentFlowRunStateTransitions,
         ]
@@ -148,7 +146,6 @@ class MinimalTaskPolicy(BaseOrchestrationPolicy):
     def priority():
         return [
             ReleaseTaskConcurrencySlots,  # always release concurrency slots
-            AddUnknownResult,  # mark forced completions with a result placeholder
         ]

@@ -260,45 +257,6 @@ async def after_transition(
         cl.active_slots = list(active_slots)


-class AddUnknownResult(BaseOrchestrationRule):
-    """
-    Assign an "unknown" result to runs that are forced to complete from a
-    failed or crashed state, if the previous state used a persisted result.
-
-    When we retry a flow run, we retry any task runs that were in a failed or
-    crashed state, but we also retry completed task runs that didn't use a
-    persisted result. This means that without a sentinel value for unknown
-    results, a task run forced into Completed state will always get rerun if the
-    flow run retries because the task run lacks a persisted result. The
-    "unknown" sentinel ensures that when we see a completed task run with an
-    unknown result, we know that it was forced to complete and we shouldn't
-    rerun it.
-
-    Flow runs forced into a Completed state have a similar problem: without a
-    sentinel value, attempting to refer to the flow run's result will raise an
-    exception because the flow run has no result. The sentinel ensures that we
-    can distinguish between a flow run that has no result and a flow run that
-    has an unknown result.
-    """
-
-    FROM_STATES = [StateType.FAILED, StateType.CRASHED]
-    TO_STATES = [StateType.COMPLETED]
-
-    async def before_transition(
-        self,
-        initial_state: Optional[states.State],
-        proposed_state: Optional[states.State],
-        context: TaskOrchestrationContext,
-    ) -> None:
-        if (
-            initial_state
-            and initial_state.data
-            and initial_state.data.get("type") == "reference"
-        ):
-            unknown_result = await UnknownResult.create()
-            self.context.proposed_state.data = unknown_result.model_dump()
-
-
 class CacheInsertion(BaseOrchestrationRule):
     """
     Caches completed states with cache keys after they are validated.
40 changes: 23 additions & 17 deletions src/prefect/states.py
Expand Up @@ -218,6 +218,7 @@ async def exception_to_crashed_state(
async def exception_to_failed_state(
exc: Optional[BaseException] = None,
result_factory: Optional[ResultFactory] = None,
write_result: bool = False,
**kwargs,
) -> State:
"""
@@ -234,6 +235,8 @@

     if result_factory:
         data = await result_factory.create_result(exc)
+        if write_result:
+            await data.write()
     else:
         # Attach the exception for local usage, will not be available when retrieved
         # from the API
@@ -258,7 +261,7 @@ async def return_value_to_state(
     result_factory: ResultFactory,
     key: Optional[str] = None,
     expiration: Optional[datetime.datetime] = None,
-    defer_persistence: bool = False,
+    write_result: bool = False,
 ) -> State[R]:
     """
     Given a return value from a user's function, create a `State` the run should
@@ -291,13 +294,14 @@ async def return_value_to_state(
         # Unless the user has already constructed a result explicitly, use the factory
         # to update the data to the correct type
         if not isinstance(state.data, BaseResult):
-            state.data = await result_factory.create_result(
+            result = await result_factory.create_result(
                 state.data,
                 key=key,
                 expiration=expiration,
-                defer_persistence=defer_persistence,
             )
+            if write_result:
+                await result.write()
+            state.data = result
         return state

     # Determine a new state from the aggregate of contained states
@@ -333,15 +337,17 @@ async def return_value_to_state(
         # TODO: We may actually want to set the data to a `StateGroup` object and just
         # allow it to be unpacked into a tuple and such so users can interact with
         # it
+        result = await result_factory.create_result(
+            retval,
+            key=key,
+            expiration=expiration,
+        )
+        if write_result:
+            await result.write()
         return State(
             type=new_state_type,
             message=message,
-            data=await result_factory.create_result(
-                retval,
-                key=key,
-                expiration=expiration,
-                defer_persistence=defer_persistence,
-            ),
+            data=result,
         )

     # Generators aren't portable, implicitly convert them to a list.
@@ -354,14 +360,14 @@ async def return_value_to_state(
     if isinstance(data, BaseResult):
         return Completed(data=data)
     else:
-        return Completed(
-            data=await result_factory.create_result(
-                data,
-                key=key,
-                expiration=expiration,
-                defer_persistence=defer_persistence,
-            )
-        )
+        result = await result_factory.create_result(
+            data,
+            key=key,
+            expiration=expiration,
+        )
+        if write_result:
+            await result.write()
+        return Completed(data=result)


 @sync_compatible
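Since `write()` now returns early for `serialize_to_none` results, callers such as the flow engine can pass `write_result=True` unconditionally while persistence stays governed by the factory's `persist_result` setting. A small sketch of that interaction (toy classes with assumed names, not Prefect's API):

```python
import asyncio


class SketchResult:
    def __init__(self, value: object, serialize_to_none: bool) -> None:
        self.value = value
        self.serialize_to_none = serialize_to_none
        self.persisted = False

    async def write(self) -> None:
        # No-op when persistence is disabled for this result.
        if self.persisted or self.serialize_to_none:
            return
        self.persisted = True


class SketchFactory:
    def __init__(self, persist_result: bool) -> None:
        self.persist_result = persist_result

    async def create_result(self, value: object) -> SketchResult:
        # Mirrors create_result: serialize_to_none=not self.persist_result.
        return SketchResult(value, serialize_to_none=not self.persist_result)


async def demo() -> None:
    for persist in (True, False):
        result = await SketchFactory(persist).create_result("hello")
        await result.write()  # safe to request unconditionally
        print(f"persist_result={persist} -> persisted={result.persisted}")


asyncio.run(demo())
# persist_result=True -> persisted=True
# persist_result=False -> persisted=False
```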
(The remaining 13 changed files are not shown here.)
