Skip to content

Commit

Permalink
disable task run recorder in tests (#14969)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan authored Aug 16, 2024
1 parent 56d6fb0 commit 2a797a7
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 9 deletions.
41 changes: 41 additions & 0 deletions src/prefect/events/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,47 @@ async def _emit(self, event: Event) -> None:
await asyncio.sleep(1)


class AssertingPassthroughEventsClient(PrefectEventsClient):
"""A Prefect Events client that BOTH records all events sent to it for inspection
during tests AND sends them to a Prefect server."""

last: ClassVar["Optional[AssertingPassthroughEventsClient]"] = None
all: ClassVar[List["AssertingPassthroughEventsClient"]] = []

args: Tuple
kwargs: Dict[str, Any]
events: List[Event]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
AssertingPassthroughEventsClient.last = self
AssertingPassthroughEventsClient.all.append(self)
self.args = args
self.kwargs = kwargs

@classmethod
def reset(cls) -> None:
cls.last = None
cls.all = []

def pop_events(self) -> List[Event]:
events = self.events
self.events = []
return events

async def _emit(self, event: Event) -> None:
# actually send the event to the server
await super()._emit(event)

# record the event for inspection
self.events.append(event)

async def __aenter__(self) -> Self:
await super().__aenter__()
self.events = []
return self


class PrefectCloudEventsClient(PrefectEventsClient):
"""A Prefect Events client that streams events to a Prefect Cloud Workspace"""

Expand Down
2 changes: 2 additions & 0 deletions src/prefect/events/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .clients import (
AssertingEventsClient,
AssertingPassthroughEventsClient,
PrefectCloudEventsClient,
PrefectEventsClient,
)
Expand Down Expand Up @@ -49,6 +50,7 @@ def emit_event(
return None

operational_clients = [
AssertingPassthroughEventsClient,
AssertingEventsClient,
PrefectCloudEventsClient,
PrefectEventsClient,
Expand Down
6 changes: 3 additions & 3 deletions src/prefect/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,15 +564,15 @@ async def start_services():
service_instances.append(ProactiveTriggers())
service_instances.append(Actions())

if prefect.settings.PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED:
service_instances.append(TaskRunRecorder())

if prefect.settings.PREFECT_API_SERVICES_EVENT_PERSISTER_ENABLED:
service_instances.append(EventPersister())

if prefect.settings.PREFECT_API_EVENTS_STREAM_OUT_ENABLED:
service_instances.append(stream.Distributor())

if prefect.settings.PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED:
service_instances.append(TaskRunRecorder())

loop = asyncio.get_running_loop()

app.state.services = {
Expand Down
32 changes: 31 additions & 1 deletion src/prefect/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from websockets.legacy.server import WebSocketServer, WebSocketServerProtocol, serve

from prefect.events import Event
from prefect.events.clients import AssertingEventsClient
from prefect.events.clients import (
AssertingEventsClient,
AssertingPassthroughEventsClient,
)
from prefect.events.filters import EventFilter
from prefect.events.worker import EventsWorker
from prefect.server.api.server import SubprocessASGIServer
Expand Down Expand Up @@ -380,6 +383,19 @@ def asserting_events_worker(monkeypatch) -> Generator[EventsWorker, None, None]:
worker.drain()


@pytest.fixture
def asserting_and_emitting_events_worker(
monkeypatch,
) -> Generator[EventsWorker, None, None]:
worker = EventsWorker.instance(AssertingPassthroughEventsClient)
# Always yield the asserting worker when new instances are retrieved
monkeypatch.setattr(EventsWorker, "instance", lambda *_: worker)
try:
yield worker
finally:
worker.drain()


@pytest.fixture
async def events_pipeline(asserting_events_worker: EventsWorker):
class AssertingEventsPipeline(EventsPipeline):
Expand Down Expand Up @@ -415,6 +431,20 @@ async def wait_for_min_events():
yield AssertingEventsPipeline()


@pytest.fixture
async def emitting_events_pipeline(asserting_and_emitting_events_worker: EventsWorker):
class AssertingAndEmittingEventsPipeline(EventsPipeline):
@sync_compatible
async def process_events(self):
asserting_and_emitting_events_worker.wait_until_empty()
events = asserting_and_emitting_events_worker._client.pop_events()

messages = self.events_to_messages(events)
await self.process_messages(messages)

yield AssertingAndEmittingEventsPipeline()


@pytest.fixture
def reset_worker_events(asserting_events_worker: EventsWorker):
yield
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
PREFECT_API_SERVICES_LATE_RUNS_ENABLED,
PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_ENABLED,
PREFECT_API_SERVICES_SCHEDULER_ENABLED,
PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED,
PREFECT_API_SERVICES_TRIGGERS_ENABLED,
PREFECT_API_URL,
PREFECT_ASYNC_FETCH_STATE_RESULT,
Expand Down Expand Up @@ -335,6 +336,8 @@ def pytest_sessionstart(session):
# lock the DB during tests while writing events
PREFECT_API_SERVICES_EVENT_PERSISTER_ENABLED: False,
PREFECT_API_SERVICES_TRIGGERS_ENABLED: False,
# Disable the task run recorder service
PREFECT_API_SERVICES_TASK_RUN_RECORDER_ENABLED: False,
},
source=__file__,
)
Expand Down
6 changes: 5 additions & 1 deletion tests/events/client/test_events_related_from_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def test_flow():
assert f"prefect.flow-run.{flow_run.id}" not in related


async def test_gets_related_from_task_run_context(prefect_client):
async def test_gets_related_from_task_run_context(prefect_client, events_pipeline):
@task
async def test_task():
# Clear the FlowRunContext to simulated a task run in a remote worker.
Expand All @@ -149,10 +149,14 @@ async def test_flow():
return await test_task(return_state=True)

state = await test_flow(return_state=True)

await events_pipeline.process_events()

task_state = await state.result()

flow_run = await prefect_client.read_flow_run(state.state_details.flow_run_id)
db_flow = await prefect_client.read_flow(flow_run.flow_id)

task_run = await prefect_client.read_task_run(task_state.state_details.task_run_id)

related = await task_state.result()
Expand Down
5 changes: 4 additions & 1 deletion tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ def test_flow():
}


async def test_run_logger_in_task(prefect_client):
async def test_run_logger_in_task(prefect_client, events_pipeline):
@task
def test_task():
return get_run_logger()
Expand All @@ -1013,6 +1013,9 @@ def test_flow():
flow_state = test_flow(return_state=True)
flow_run = await prefect_client.read_flow_run(flow_state.state_details.flow_run_id)
task_state = await flow_state.result()

await events_pipeline.process_events()

task_run = await prefect_client.read_task_run(task_state.state_details.task_run_id)
logger = await task_state.result()
assert logger.name == "prefect.task_runs"
Expand Down
16 changes: 13 additions & 3 deletions tests/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_instance_returns_instance_after_stop(self):

@pytest.mark.timeout(20)
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_wait_for_task_run(self, prefect_client):
async def test_wait_for_task_run(self, prefect_client, emitting_events_pipeline):
"""This test will fail with a timeout error if waiting is not working correctly."""

@task
Expand All @@ -37,6 +37,8 @@ async def test_task():

await TaskRunWaiter.wait_for_task_run(task_run_id)

await emitting_events_pipeline.process_events()

task_run = await prefect_client.read_task_run(task_run_id)
assert task_run.state.is_completed()

Expand All @@ -59,7 +61,7 @@ async def test_task():

@pytest.mark.timeout(20)
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_non_singleton_mode(self, prefect_client):
async def test_non_singleton_mode(self, prefect_client, emitting_events_pipeline):
waiter = TaskRunWaiter()
assert waiter is not TaskRunWaiter.instance()

Expand All @@ -72,14 +74,18 @@ async def test_task():

await waiter.wait_for_task_run(task_run_id)

await emitting_events_pipeline.process_events()

task_run = await prefect_client.read_task_run(task_run_id)
assert task_run.state.is_completed()

waiter.stop()

@pytest.mark.timeout(20)
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_handles_concurrent_task_runs(self, prefect_client):
async def test_handles_concurrent_task_runs(
self, prefect_client, emitting_events_pipeline
):
@task
async def fast_task():
await asyncio.sleep(1)
Expand All @@ -96,6 +102,8 @@ async def slow_task():

await TaskRunWaiter.wait_for_task_run(task_run_id_1)

await emitting_events_pipeline.process_events()

task_run_1 = await prefect_client.read_task_run(task_run_id_1)
task_run_2 = await prefect_client.read_task_run(task_run_id_2)

Expand All @@ -104,6 +112,8 @@ async def slow_task():

await TaskRunWaiter.wait_for_task_run(task_run_id_2)

await emitting_events_pipeline.process_events()

task_run_1 = await prefect_client.read_task_run(task_run_id_1)
task_run_2 = await prefect_client.read_task_run(task_run_id_2)

Expand Down

0 comments on commit 2a797a7

Please sign in to comment.