diff --git a/docs/3.0rc/api-ref/rest-api/server/schema.json b/docs/3.0rc/api-ref/rest-api/server/schema.json
index de44b3ca6d0e..9d65af771d98 100644
--- a/docs/3.0rc/api-ref/rest-api/server/schema.json
+++ b/docs/3.0rc/api-ref/rest-api/server/schema.json
@@ -4323,7 +4323,13 @@
           "description": "Successful Response",
           "content": {
             "application/json": {
-              "schema": {}
+              "schema": {
+                "type": "array",
+                "items": {
+                  "$ref": "#/components/schemas/MinimalConcurrencyLimitResponse"
+                },
+                "title": "Response Increment Concurrency Limits V1 Concurrency Limits Increment Post"
+              }
             }
           }
         },
diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py
index 507b5fe7573f..938de46b6b8a 100644
--- a/src/prefect/client/orchestration.py
+++ b/src/prefect/client/orchestration.py
@@ -939,6 +939,57 @@ async def delete_concurrency_limit_by_tag(
         else:
             raise
 
+    async def increment_v1_concurrency_slots(
+        self,
+        names: List[str],
+        task_run_id: UUID,
+    ) -> httpx.Response:
+        """
+        Increment concurrency limit slots for the specified limits.
+
+        Args:
+            names (List[str]): A list of limit names whose slots should be incremented.
+            task_run_id (UUID): The ID of the task run incrementing the limits.
+
+        Returns:
+            httpx.Response: The HTTP response from the server.
+        """
+        data = {
+            "names": names,
+            "task_run_id": str(task_run_id),
+        }
+
+        return await self._client.post(
+            "/concurrency_limits/increment",
+            json=data,
+        )
+
+    async def decrement_v1_concurrency_slots(
+        self,
+        names: List[str],
+        task_run_id: UUID,
+        occupancy_seconds: float,
+    ) -> httpx.Response:
+        """
+        Decrement concurrency limit slots for the specified limits.
+
+        Args:
+            names (List[str]): A list of limit names to decrement.
+            task_run_id (UUID): The ID of the task run that incremented the limits.
+            occupancy_seconds (float): The duration in seconds that the limits
+                were held.
+
+        Returns:
+            httpx.Response: The HTTP response from the server.
+        """
+        data = {
+            "names": names,
+            "task_run_id": str(task_run_id),
+            "occupancy_seconds": occupancy_seconds,
+        }
+
+        return await self._client.post(
+            "/concurrency_limits/decrement",
+            json=data,
+        )
+
     async def create_work_queue(
         self,
         name: str,
@@ -4116,3 +4167,27 @@ def release_concurrency_slots(
                 "occupancy_seconds": occupancy_seconds,
             },
         )
+
+    def decrement_v1_concurrency_slots(
+        self, names: List[str], occupancy_seconds: float, task_run_id: UUID
+    ) -> httpx.Response:
+        """
+        Decrement the specified concurrency limits.
+
+        Args:
+            names (List[str]): A list of limit names to decrement.
+            occupancy_seconds (float): The duration in seconds that the slots
+                were held.
+            task_run_id (UUID): The task run ID that incremented the limits.
+
+        Returns:
+            httpx.Response: The HTTP response from the server.
+ """ + return self._client.post( + "/concurrency_limits/decrement", + json={ + "names": names, + "occupancy_seconds": occupancy_seconds, + "task_run_id": str(task_run_id), + }, + ) diff --git a/src/prefect/concurrency/v1/__init__.py b/src/prefect/concurrency/v1/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/prefect/concurrency/v1/asyncio.py b/src/prefect/concurrency/v1/asyncio.py new file mode 100644 index 000000000000..d7fda4519051 --- /dev/null +++ b/src/prefect/concurrency/v1/asyncio.py @@ -0,0 +1,143 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncGenerator, List, Optional, Union, cast +from uuid import UUID + +import anyio +import httpx +import pendulum + +from ...client.schemas.responses import MinimalConcurrencyLimitResponse + +try: + from pendulum import Interval +except ImportError: + # pendulum < 3 + from pendulum.period import Period as Interval # type: ignore + +from prefect.client.orchestration import get_client + +from .context import ConcurrencyContext +from .events import ( + _emit_concurrency_acquisition_events, + _emit_concurrency_release_events, +) +from .services import ConcurrencySlotAcquisitionService + + +class ConcurrencySlotAcquisitionError(Exception): + """Raised when an unhandlable occurs while acquiring concurrency slots.""" + + +class AcquireConcurrencySlotTimeoutError(TimeoutError): + """Raised when acquiring a concurrency slot times out.""" + + +@asynccontextmanager +async def concurrency( + names: Union[str, List[str]], + task_run_id: UUID, + timeout_seconds: Optional[float] = None, +) -> AsyncGenerator[None, None]: + """A context manager that acquires and releases concurrency slots from the + given concurrency limits. + + Args: + names: The names of the concurrency limits to acquire slots from. + task_run_id: The name of the task_run_id that is incrementing the slots. + timeout_seconds: The number of seconds to wait for the slots to be acquired before + raising a `TimeoutError`. A timeout of `None` will wait indefinitely. + + Raises: + TimeoutError: If the slots are not acquired within the given timeout. + + Example: + A simple example of using the async `concurrency` context manager: + ```python + from prefect.concurrency.v1.asyncio import concurrency + + async def resource_heavy(): + async with concurrency("test", task_run_id): + print("Resource heavy task") + + async def main(): + await resource_heavy() + ``` + """ + if not names: + yield + return + + names_normalized: List[str] = names if isinstance(names, list) else [names] + + limits = await _acquire_concurrency_slots( + names_normalized, + task_run_id=task_run_id, + timeout_seconds=timeout_seconds, + ) + acquisition_time = pendulum.now("UTC") + emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id) + + try: + yield + finally: + occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time)) + try: + await _release_concurrency_slots( + names_normalized, task_run_id, occupancy_period.total_seconds() + ) + except anyio.get_cancelled_exc_class(): + # The task was cancelled before it could release the slots. Add the + # slots to the cleanup list so they can be released when the + # concurrency context is exited. 
+ if ctx := ConcurrencyContext.get(): + ctx.cleanup_slots.append( + (names_normalized, occupancy_period.total_seconds(), task_run_id) + ) + + _emit_concurrency_release_events(limits, emitted_events, task_run_id) + + +async def _acquire_concurrency_slots( + names: List[str], + task_run_id: UUID, + timeout_seconds: Optional[float] = None, +) -> List[MinimalConcurrencyLimitResponse]: + service = ConcurrencySlotAcquisitionService.instance(frozenset(names)) + future = service.send((task_run_id, timeout_seconds)) + response_or_exception = await asyncio.wrap_future(future) + + if isinstance(response_or_exception, Exception): + if isinstance(response_or_exception, TimeoutError): + raise AcquireConcurrencySlotTimeoutError( + f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)" + ) from response_or_exception + + raise ConcurrencySlotAcquisitionError( + f"Unable to acquire concurrency limits {names!r}" + ) from response_or_exception + + return _response_to_concurrency_limit_response(response_or_exception) + + +async def _release_concurrency_slots( + names: List[str], + task_run_id: UUID, + occupancy_seconds: float, +) -> List[MinimalConcurrencyLimitResponse]: + async with get_client() as client: + response = await client.decrement_v1_concurrency_slots( + names=names, + task_run_id=task_run_id, + occupancy_seconds=occupancy_seconds, + ) + return _response_to_concurrency_limit_response(response) + + +def _response_to_concurrency_limit_response( + response: httpx.Response, +) -> List[MinimalConcurrencyLimitResponse]: + data = response.json() or [] + return [ + MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data + ] diff --git a/src/prefect/concurrency/v1/context.py b/src/prefect/concurrency/v1/context.py new file mode 100644 index 000000000000..f413c84ed1f4 --- /dev/null +++ b/src/prefect/concurrency/v1/context.py @@ -0,0 +1,27 @@ +from contextvars import ContextVar +from typing import List, Tuple +from uuid import UUID + +from prefect.client.orchestration import get_client +from prefect.context import ContextModel, Field + + +class ConcurrencyContext(ContextModel): + __var__: ContextVar = ContextVar("concurrency_v1") + + # Track the limits that have been acquired but were not able to be released + # due to cancellation or some other error. These limits are released when + # the context manager exits. 
+ cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list) + + def __exit__(self, *exc_info): + if self.cleanup_slots: + with get_client(sync_client=True) as client: + for names, occupancy_seconds, task_run_id in self.cleanup_slots: + client.decrement_v1_concurrency_slots( + names=names, + occupancy_seconds=occupancy_seconds, + task_run_id=task_run_id, + ) + + return super().__exit__(*exc_info) diff --git a/src/prefect/concurrency/v1/events.py b/src/prefect/concurrency/v1/events.py new file mode 100644 index 000000000000..3fa5193e6fea --- /dev/null +++ b/src/prefect/concurrency/v1/events.py @@ -0,0 +1,61 @@ +from typing import Dict, List, Literal, Optional, Union +from uuid import UUID + +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.events import Event, RelatedResource, emit_event + + +def _emit_concurrency_event( + phase: Union[Literal["acquired"], Literal["released"]], + primary_limit: MinimalConcurrencyLimitResponse, + related_limits: List[MinimalConcurrencyLimitResponse], + task_run_id: UUID, + follows: Union[Event, None] = None, +) -> Union[Event, None]: + resource: Dict[str, str] = { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}", + "prefect.resource.name": primary_limit.name, + "limit": str(primary_limit.limit), + "task_run_id": str(task_run_id), + } + + related = [ + RelatedResource.model_validate( + { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{limit.id}", + "prefect.resource.role": "concurrency-limit", + } + ) + for limit in related_limits + if limit.id != primary_limit.id + ] + + return emit_event( + f"prefect.concurrency-limit.v1.{phase}", + resource=resource, + related=related, + follows=follows, + ) + + +def _emit_concurrency_acquisition_events( + limits: List[MinimalConcurrencyLimitResponse], + task_run_id: UUID, +) -> Dict[UUID, Optional[Event]]: + events = {} + for limit in limits: + event = _emit_concurrency_event("acquired", limit, limits, task_run_id) + events[limit.id] = event + + return events + + +def _emit_concurrency_release_events( + limits: List[MinimalConcurrencyLimitResponse], + events: Dict[UUID, Optional[Event]], + task_run_id: UUID, +) -> None: + for limit in limits: + _emit_concurrency_event( + "released", limit, limits, task_run_id, events[limit.id] + ) diff --git a/src/prefect/concurrency/v1/services.py b/src/prefect/concurrency/v1/services.py new file mode 100644 index 000000000000..1199c7ef3373 --- /dev/null +++ b/src/prefect/concurrency/v1/services.py @@ -0,0 +1,116 @@ +import asyncio +import concurrent.futures +from contextlib import asynccontextmanager +from json import JSONDecodeError +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + FrozenSet, + Optional, + Tuple, +) +from uuid import UUID + +import httpx +from starlette import status + +from prefect._internal.concurrency import logger +from prefect._internal.concurrency.services import QueueService +from prefect.client.orchestration import get_client +from prefect.utilities.timeout import timeout_async + +if TYPE_CHECKING: + from prefect.client.orchestration import PrefectClient + + +class ConcurrencySlotAcquisitionServiceError(Exception): + """Raised when an error occurs while acquiring concurrency slots.""" + + +class ConcurrencySlotAcquisitionService(QueueService): + def __init__(self, concurrency_limit_names: FrozenSet[str]): + super().__init__(concurrency_limit_names) + self._client: "PrefectClient" + self.concurrency_limit_names = sorted(list(concurrency_limit_names)) + + 
@asynccontextmanager + async def _lifespan(self) -> AsyncGenerator[None, None]: + async with get_client() as client: + self._client = client + yield + + async def _handle( + self, + item: Tuple[ + UUID, + concurrent.futures.Future, + Optional[float], + ], + ) -> None: + task_run_id, future, timeout_seconds = item + try: + response = await self.acquire_slots(task_run_id, timeout_seconds) + except Exception as exc: + # If the request to the increment endpoint fails in a non-standard + # way, we need to set the future's result so that the caller can + # handle the exception and then re-raise. + future.set_result(exc) + raise exc + else: + future.set_result(response) + + async def acquire_slots( + self, + task_run_id: UUID, + timeout_seconds: Optional[float] = None, + ) -> httpx.Response: + with timeout_async(seconds=timeout_seconds): + while True: + try: + response = await self._client.increment_v1_concurrency_slots( + task_run_id=task_run_id, + names=self.concurrency_limit_names, + ) + except Exception as exc: + if ( + isinstance(exc, httpx.HTTPStatusError) + and exc.response.status_code == status.HTTP_423_LOCKED + ): + retry_after = exc.response.headers.get("Retry-After") + if retry_after: + retry_after = float(retry_after) + await asyncio.sleep(retry_after) + else: + # We received a 423 but no Retry-After header. This + # should indicate that the server told us to abort + # because the concurrency limit is set to 0, i.e. + # effectively disabled. + try: + reason = exc.response.json()["detail"] + except (JSONDecodeError, KeyError): + logger.error( + "Failed to parse response from concurrency limit 423 Locked response: %s", + exc.response.content, + ) + reason = "Concurrency limit is locked (server did not specify the reason)" + raise ConcurrencySlotAcquisitionServiceError( + reason + ) from exc + + else: + raise exc # type: ignore + else: + return response + + def send(self, item: Tuple[UUID, Optional[float]]) -> concurrent.futures.Future: + with self._lock: + if self._stopped: + raise RuntimeError("Cannot put items in a stopped service instance.") + + logger.debug("Service %r enqueuing item %r", self, item) + future: concurrent.futures.Future = concurrent.futures.Future() + + task_run_id, timeout_seconds = item + self._queue.put_nowait((task_run_id, future, timeout_seconds)) + + return future diff --git a/src/prefect/concurrency/v1/sync.py b/src/prefect/concurrency/v1/sync.py new file mode 100644 index 000000000000..9da49a87bb90 --- /dev/null +++ b/src/prefect/concurrency/v1/sync.py @@ -0,0 +1,92 @@ +from contextlib import contextmanager +from typing import ( + Generator, + List, + Optional, + TypeVar, + Union, + cast, +) +from uuid import UUID + +import pendulum + +from ...client.schemas.responses import MinimalConcurrencyLimitResponse +from ..sync import _call_async_function_from_sync + +try: + from pendulum import Interval +except ImportError: + # pendulum < 3 + from pendulum.period import Period as Interval # type: ignore + +from .asyncio import ( + _acquire_concurrency_slots, + _release_concurrency_slots, +) +from .events import ( + _emit_concurrency_acquisition_events, + _emit_concurrency_release_events, +) + +T = TypeVar("T") + + +@contextmanager +def concurrency( + names: Union[str, List[str]], + task_run_id: UUID, + timeout_seconds: Optional[float] = None, +) -> Generator[None, None, None]: + """ + A context manager that acquires and releases concurrency slots from the + given concurrency limits. + + Args: + names: The names of the concurrency limits to acquire. 
+        task_run_id: The ID of the task run acquiring the limits.
+        timeout_seconds: The number of seconds to wait to acquire the limits before
+            raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
+
+    Raises:
+        TimeoutError: If the limits are not acquired within the given timeout.
+
+    Example:
+    A simple example of using the sync `concurrency` context manager:
+    ```python
+    from uuid import UUID
+
+    from prefect.concurrency.v1.sync import concurrency
+
+    def resource_heavy(task_run_id: UUID):
+        with concurrency("test", task_run_id):
+            print("Resource heavy task")
+
+    def main(task_run_id: UUID):
+        resource_heavy(task_run_id)
+    ```
+    """
+    if not names:
+        yield
+        return
+
+    names = names if isinstance(names, list) else [names]
+
+    limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
+        _acquire_concurrency_slots,
+        names,
+        timeout_seconds=timeout_seconds,
+        task_run_id=task_run_id,
+    )
+    acquisition_time = pendulum.now("UTC")
+    emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)
+
+    try:
+        yield
+    finally:
+        occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time)
+        _call_async_function_from_sync(
+            _release_concurrency_slots,
+            names,
+            task_run_id,
+            occupancy_period.total_seconds(),
+        )
+        _emit_concurrency_release_events(limits, emitted_events, task_run_id)
diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py
index 7b72fb20b511..5e11f25b1c4f 100644
--- a/src/prefect/flow_engine.py
+++ b/src/prefect/flow_engine.py
@@ -30,6 +30,7 @@
 from prefect.client.schemas.filters import FlowRunFilter
 from prefect.client.schemas.sorting import FlowRunSort
 from prefect.concurrency.context import ConcurrencyContext
+from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1
 from prefect.context import FlowRunContext, SyncClientContext, TagsContext
 from prefect.exceptions import (
     Abort,
@@ -506,6 +507,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
                 task_runner=task_runner,
             )
         )
+        stack.enter_context(ConcurrencyContextV1())
         stack.enter_context(ConcurrencyContext())
 
         # set the logger to the flow run logger
diff --git a/src/prefect/server/api/concurrency_limits.py b/src/prefect/server/api/concurrency_limits.py
index 3fe71fc71935..9414ff018a15 100644
--- a/src/prefect/server/api/concurrency_limits.py
+++ b/src/prefect/server/api/concurrency_limits.py
@@ -2,7 +2,7 @@
 """
 Routes for interacting with concurrency limit objects.
 """
 
-from typing import List, Optional
+from typing import List, Optional, Sequence
 from uuid import UUID
 
 import pendulum
@@ -11,6 +11,7 @@
 import prefect.server.api.dependencies as dependencies
 import prefect.server.models as models
 import prefect.server.schemas as schemas
+from prefect.server.api.concurrency_limits_v2 import MinimalConcurrencyLimitResponse
 from prefect.server.database.dependencies import provide_database_interface
 from prefect.server.database.interface import PrefectDBInterface
 from prefect.server.models import concurrency_limits
@@ -95,7 +96,7 @@ async def read_concurrency_limits(
     limit: int = dependencies.LimitBody(),
     offset: int = Body(0, ge=0),
     db: PrefectDBInterface = Depends(provide_database_interface),
-) -> List[schemas.core.ConcurrencyLimit]:
+) -> Sequence[schemas.core.ConcurrencyLimit]:
     """
     Query for concurrency limits.
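Reviewer note: the next hunk changes `increment_concurrency_limits_v1` to report which limits it actually applied. A minimal sketch (not part of the patch) of the round trip this enables, exercising the client methods added above — it assumes a reachable Prefect API, and the `test` tag and the printed payload are illustrative:

```python
import asyncio
from uuid import uuid4

from prefect.client.orchestration import get_client


async def main() -> None:
    task_run_id = uuid4()  # stands in for a real task run ID
    async with get_client() as client:
        # Occupy one slot on each named limit. The response body is now a
        # list of MinimalConcurrencyLimitResponse objects (see schema.json).
        response = await client.increment_v1_concurrency_slots(
            names=["test"], task_run_id=task_run_id
        )
        print(response.json())  # e.g. [{"id": "...", "name": "test", "limit": 1}]

        # ... do the rate-limited work here ...

        # Give the slots back, reporting how long they were held.
        await client.decrement_v1_concurrency_slots(
            names=["test"], task_run_id=task_run_id, occupancy_seconds=1.0
        )


asyncio.run(main())
```

Names without a matching limit are simply ignored by both endpoints (the tests below assert this), which is what lets the engine pass every task tag through unconditionally.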
@@ -180,12 +181,12 @@ async def increment_concurrency_limits_v1( ..., description="The ID of the task run acquiring the slot" ), db: PrefectDBInterface = Depends(provide_database_interface), -): - applied_limits = [] +) -> List[MinimalConcurrencyLimitResponse]: + applied_limits = {} async with db.session_context(begin_transaction=True) as session: try: - applied_limits = [] + applied_limits = {} filtered_limits = ( await concurrency_limits.filter_concurrency_limits_for_orchestration( session, tags=names @@ -196,7 +197,7 @@ async def increment_concurrency_limits_v1( limit = cl.concurrency_limit if limit == 0: # limits of 0 will deadlock, and the transition needs to abort - for stale_tag in applied_limits: + for stale_tag in applied_limits.keys(): stale_limit = run_limits.get(stale_tag, None) active_slots = set(stale_limit.active_slots) active_slots.discard(str(task_run_id)) @@ -210,7 +211,7 @@ async def increment_concurrency_limits_v1( ) elif len(cl.active_slots) >= limit: # if the limit has already been reached, delay the transition - for stale_tag in applied_limits: + for stale_tag in applied_limits.keys(): stale_limit = run_limits.get(stale_tag, None) active_slots = set(stale_limit.active_slots) active_slots.discard(str(task_run_id)) @@ -222,12 +223,12 @@ async def increment_concurrency_limits_v1( ) else: # log the TaskRun ID to active_slots - applied_limits.append(tag) + applied_limits[tag] = cl active_slots = set(cl.active_slots) active_slots.add(str(task_run_id)) cl.active_slots = list(active_slots) except Exception as e: - for tag in applied_limits: + for tag in applied_limits.keys(): cl = await concurrency_limits.read_concurrency_limit_by_tag( session, tag ) @@ -248,6 +249,12 @@ async def increment_concurrency_limits_v1( ) else: raise + return [ + MinimalConcurrencyLimitResponse( + name=limit.tag, limit=limit.concurrency_limit, id=limit.id + ) + for limit in applied_limits.values() + ] @router.post("/decrement") @@ -269,3 +276,10 @@ async def decrement_concurrency_limits_v1( active_slots = set(cl.active_slots) active_slots.discard(str(task_run_id)) cl.active_slots = list(active_slots) + + return [ + MinimalConcurrencyLimitResponse( + name=limit.tag, limit=limit.concurrency_limit, id=limit.id + ) + for limit in run_limits.values() + ] diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 771d72d612b7..65f1d19f1587 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -33,9 +33,10 @@ from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput -from prefect.concurrency.asyncio import concurrency as aconcurrency from prefect.concurrency.context import ConcurrencyContext -from prefect.concurrency.sync import concurrency +from prefect.concurrency.v1.asyncio import concurrency as aconcurrency +from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1 +from prefect.concurrency.v1.sync import concurrency from prefect.context import ( AsyncClientContext, FlowRunContext, @@ -589,6 +590,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): client=client, ) ) + stack.enter_context(ConcurrencyContextV1()) stack.enter_context(ConcurrencyContext()) self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore @@ -754,9 +756,7 @@ def call_task_fn( if self.task.tags: # Acquire a concurrency slot for each tag, but only if a limit # matching the tag already 
exists. - with concurrency( - list(self.task.tags), occupy=1, create_if_missing=False - ): + with concurrency(list(self.task.tags), self.task_run.id): result = call_with_parameters(self.task.fn, parameters) else: result = call_with_parameters(self.task.fn, parameters) @@ -1250,9 +1250,7 @@ async def call_task_fn( if self.task.tags: # Acquire a concurrency slot for each tag, but only if a limit # matching the tag already exists. - async with aconcurrency( - list(self.task.tags), occupy=1, create_if_missing=False - ): + async with aconcurrency(list(self.task.tags), self.task_run.id): result = await call_with_parameters(self.task.fn, parameters) else: result = await call_with_parameters(self.task.fn, parameters) diff --git a/tests/concurrency/v1/__init__.py b/tests/concurrency/v1/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/concurrency/v1/conftest.py b/tests/concurrency/v1/conftest.py new file mode 100644 index 000000000000..438790747ef8 --- /dev/null +++ b/tests/concurrency/v1/conftest.py @@ -0,0 +1,29 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from prefect.server.models.concurrency_limits import create_concurrency_limit +from prefect.server.schemas.core import ConcurrencyLimit + + +@pytest.fixture +async def v1_concurrency_limit(session: AsyncSession) -> ConcurrencyLimit: + concurrency_limit = await create_concurrency_limit( + session=session, + concurrency_limit=ConcurrencyLimit(tag="test", concurrency_limit=1), + ) + + await session.commit() + + return ConcurrencyLimit.model_validate(concurrency_limit, from_attributes=True) + + +@pytest.fixture +async def other_v1_concurrency_limit(session: AsyncSession) -> ConcurrencyLimit: + concurrency_limit = await create_concurrency_limit( + session=session, + concurrency_limit=ConcurrencyLimit(tag="other", concurrency_limit=1), + ) + + await session.commit() + + return ConcurrencyLimit.model_validate(concurrency_limit, from_attributes=True) diff --git a/tests/concurrency/v1/test_concurrency_asyncio.py b/tests/concurrency/v1/test_concurrency_asyncio.py new file mode 100644 index 000000000000..2a3941522622 --- /dev/null +++ b/tests/concurrency/v1/test_concurrency_asyncio.py @@ -0,0 +1,277 @@ +from unittest import mock +from uuid import UUID + +import pytest +from httpx import HTTPStatusError, Request, Response +from starlette import status + +from prefect import flow, task +from prefect.concurrency.v1.asyncio import ( + ConcurrencySlotAcquisitionError, + _acquire_concurrency_slots, + _release_concurrency_slots, + concurrency, +) +from prefect.events.clients import AssertingEventsClient +from prefect.events.worker import EventsWorker +from prefect.server.schemas.core import ConcurrencyLimit + + +async def test_concurrency_orchestrates_api(v1_concurrency_limit: ConcurrencyLimit): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + async def resource_heavy(): + nonlocal executed + async with concurrency("test", task_run_id): + executed = True + + assert not executed + + with mock.patch( + "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", + wraps=_acquire_concurrency_slots, + ) as acquire_spy: + with mock.patch( + "prefect.concurrency.v1.asyncio._release_concurrency_slots", + wraps=_release_concurrency_slots, + ) as release_spy: + await resource_heavy() + + acquire_spy.assert_called_once_with( + ["test"], task_run_id=task_run_id, timeout_seconds=None + ) + + # On release, we calculate how many seconds the slots were occupied + # for, so here we 
really just want to make sure that the value + # passed as `occupy_seconds` is > 0. + + ( + names, + _task_run_id, + occupy_seconds, + ) = release_spy.call_args[0] + assert names == ["test"] + assert _task_run_id == task_run_id + assert occupy_seconds > 0 + + assert executed + + +async def test_concurrency_can_be_used_within_a_flow( + concurrency_limit: ConcurrencyLimit, +): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + @task + async def resource_heavy(): + nonlocal executed + async with concurrency("test", task_run_id): + executed = True + + @flow + async def my_flow(): + await resource_heavy() + + assert not executed + + await my_flow() + + assert executed + + +async def test_concurrency_emits_events( + v1_concurrency_limit: ConcurrencyLimit, + other_v1_concurrency_limit: ConcurrencyLimit, + asserting_events_worker: EventsWorker, + mock_should_emit_events, + reset_worker_events, +): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + async def resource_heavy(): + nonlocal executed + async with concurrency(["test", "other"], task_run_id): + executed = True + + await resource_heavy() + assert executed + + await asserting_events_worker.drain() + assert isinstance(asserting_events_worker._client, AssertingEventsClient) + assert len(asserting_events_worker._client.events) == 4 # 2 acquire, 2 release + + # Check the events for the `test` concurrency_limit. + for phase in ["acquired", "released"]: + event = next( + filter( + lambda e: e.event == f"prefect.concurrency-limit.v1.{phase}" + and e.resource.id + == f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + asserting_events_worker._client.events, + ) + ) + + assert dict(event.resource) == { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + "prefect.resource.name": v1_concurrency_limit.tag, + "limit": str(v1_concurrency_limit.concurrency_limit), + "task_run_id": "00000000-0000-0000-0000-000000000000", + } + + # Since they were used together we expect that the `test` limit events + # should also include the `other` limit as a related resource. + + assert len(event.related) == 1 + assert dict(event.related[0]) == { + "prefect.resource.id": ( + f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}" + ), + "prefect.resource.role": "concurrency-limit", + } + + # Check the events for the `other` concurrency_limit. + for phase in ["acquired", "released"]: + event = next( + filter( + lambda e: e.event == f"prefect.concurrency-limit.v1.{phase}" + and e.resource.id + == f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}", + asserting_events_worker._client.events, + ) + ) + + assert dict(event.resource) == { + "prefect.resource.id": ( + f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}" + ), + "prefect.resource.name": other_v1_concurrency_limit.tag, + "limit": str(other_v1_concurrency_limit.concurrency_limit), + "task_run_id": "00000000-0000-0000-0000-000000000000", + } + + # Since they were used together we expect that the `other` limit events + # should also include the `test` limit as a related resource. 
+ + assert len(event.related) == 1 + assert dict(event.related[0]) == { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + "prefect.resource.role": "concurrency-limit", + } + + +@pytest.fixture +def mock_increment_concurrency_slots(monkeypatch): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + headers={"Retry-After": "0.01"}, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) + + monkeypatch.setattr( + "prefect.client.orchestration.PrefectClient.increment_v1_concurrency_slots", + mocked_increment_concurrency_slots, + ) + + +@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots") +async def test_concurrency_respects_timeout(): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*"): + async with concurrency("test", task_run_id, timeout_seconds=0.01): + print("should not be executed") + + +@pytest.fixture +def mock_increment_concurrency_locked_with_no_retry_after(monkeypatch): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) + + monkeypatch.setattr( + "prefect.client.orchestration.PrefectClient.increment_v1_concurrency_slots", + mocked_increment_concurrency_slots, + ) + + +@pytest.mark.usefixtures( + "concurrency_limit", "mock_increment_concurrency_locked_with_no_retry_after" +) +async def test_concurrency_raises_when_locked_with_no_retry_after(): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + with pytest.raises(ConcurrencySlotAcquisitionError): + async with concurrency("test", task_run_id): + print("should not be executed") + + +@pytest.fixture +def mock_increment_concurrency_locked_with_details_and_no_retry_after( + monkeypatch, +): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + json={"details": "It's broken"}, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) + + monkeypatch.setattr( + "prefect.client.orchestration.PrefectClient.increment_v1_concurrency_slots", + mocked_increment_concurrency_slots, + ) + + +@pytest.mark.usefixtures( + "concurrency_limit", + "mock_increment_concurrency_locked_with_details_and_no_retry_after", +) +async def test_concurrency_raises_when_locked_with_details_and_no_retry_after(): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + with pytest.raises(ConcurrencySlotAcquisitionError): + async with concurrency("test", task_run_id): + print("should not be executed") + + +@pytest.mark.parametrize("names", [[], None]) +async def test_concurrency_without_limit_names(names): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + async def resource_heavy(): + nonlocal executed + async with concurrency(names, task_run_id): + executed = True + + assert not executed + + with mock.patch( + "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", + wraps=lambda *args, **kwargs: None, + ) as acquire_spy: + with mock.patch( + "prefect.concurrency.v1.asyncio._release_concurrency_slots", + wraps=lambda *args, **kwargs: None, + ) as release_spy: + await resource_heavy() + + 
acquire_spy.assert_not_called() + release_spy.assert_not_called() + + assert executed diff --git a/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py new file mode 100644 index 000000000000..bcadee6bdd30 --- /dev/null +++ b/tests/concurrency/v1/test_concurrency_limit_acquisition_service.py @@ -0,0 +1,118 @@ +import asyncio +from unittest import mock +from uuid import UUID + +import pytest +from httpx import HTTPStatusError, Request, Response + +from prefect.client.orchestration import get_client +from prefect.concurrency.v1.services import ConcurrencySlotAcquisitionService + + +@pytest.fixture +async def mocked_client(test_database_connection_url): + async with get_client() as client: + with mock.patch.object(client, "increment_v1_concurrency_slots", autospec=True): + + class ClientWrapper: + def __init__(self, client): + self.client = client + + async def __aenter__(self): + return self.client + + async def __aexit__(self, *args): + pass + + wrapped_client = ClientWrapper(client) + with mock.patch( + "prefect.concurrency.v1.services.get_client", lambda: wrapped_client + ): + yield wrapped_client + + +async def test_returns_successful_response(mocked_client): + response = Response(200) + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + mocked_method = mocked_client.client.increment_v1_concurrency_slots + mocked_method.return_value = response + + expected_names = sorted(["api", "database"]) + + service = ConcurrencySlotAcquisitionService.instance(frozenset(expected_names)) + future = service.send((task_run_id, None)) + await service.drain() + returned_response = await asyncio.wrap_future(future) + assert returned_response == response + + mocked_method.assert_called_once_with( + task_run_id=task_run_id, + names=expected_names, + ) + + +async def test_retries_failed_call_respects_retry_after_header(mocked_client): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + responses = [ + HTTPStatusError( + "Limit is locked", + request=Request("get", "/"), + response=Response(423, headers={"Retry-After": "2"}), + ), + Response(200), + ] + + mocked_client.client.increment_v1_concurrency_slots.side_effect = responses + + limit_names = sorted(["api", "database"]) + service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) + + with mock.patch("prefect.concurrency.v1.asyncio.asyncio.sleep") as sleep: + future = service.send((task_run_id, None)) + service.drain() + returned_response = await asyncio.wrap_future(future) + + assert returned_response == responses[1] + + sleep.assert_called_once_with( + float(responses[0].response.headers["Retry-After"]) + ) + assert mocked_client.client.increment_v1_concurrency_slots.call_count == 2 + + +async def test_failed_call_status_code_not_retryable_returns_exception(mocked_client): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + response = HTTPStatusError( + "Too many requests", + request=Request("get", "/"), + response=Response(500, headers={"Retry-After": "2"}), + ) + + mocked_client.client.increment_v1_concurrency_slots.return_value = response + + limit_names = sorted(["api", "database"]) + service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) + + future = service.send((task_run_id, None)) + await service.drain() + exception = await asyncio.wrap_future(future) + + assert isinstance(exception, HTTPStatusError) + assert exception == response + + +async def 
test_basic_exception_returns_exception(mocked_client): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + exc = Exception("Something went wrong") + mocked_client.client.increment_v1_concurrency_slots.side_effect = exc + + limit_names = sorted(["api", "database"]) + service = ConcurrencySlotAcquisitionService.instance(frozenset(limit_names)) + + future = service.send((task_run_id, None)) + await service.drain() + exception = await asyncio.wrap_future(future) + + assert isinstance(exception, Exception) + assert exception == exc diff --git a/tests/concurrency/v1/test_concurrency_sync.py b/tests/concurrency/v1/test_concurrency_sync.py new file mode 100644 index 000000000000..e56f21afd494 --- /dev/null +++ b/tests/concurrency/v1/test_concurrency_sync.py @@ -0,0 +1,216 @@ +from unittest import mock +from uuid import UUID + +import pytest +from httpx import HTTPStatusError, Request, Response +from starlette import status + +from prefect import flow, task +from prefect.concurrency.v1.asyncio import ( + _acquire_concurrency_slots, + _release_concurrency_slots, +) +from prefect.concurrency.v1.sync import concurrency +from prefect.events.clients import AssertingEventsClient +from prefect.events.worker import EventsWorker +from prefect.server.schemas.core import ConcurrencyLimit + + +def test_concurrency_orchestrates_api(concurrency_limit: ConcurrencyLimit): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + def resource_heavy(): + nonlocal executed + with concurrency("test", task_run_id): + executed = True + + assert not executed + + with mock.patch( + "prefect.concurrency.v1.sync._acquire_concurrency_slots", + wraps=_acquire_concurrency_slots, + ) as acquire_spy: + with mock.patch( + "prefect.concurrency.v1.sync._release_concurrency_slots", + wraps=_release_concurrency_slots, + ) as release_spy: + resource_heavy() + + acquire_spy.assert_called_once_with( + ["test"], timeout_seconds=None, task_run_id=task_run_id + ) + + names, _task_run_id, occupy_seconds = release_spy.call_args[0] + assert names == ["test"] + assert _task_run_id == task_run_id + assert occupy_seconds > 0 + + assert executed + + +def test_concurrency_emits_events( + v1_concurrency_limit: ConcurrencyLimit, + other_v1_concurrency_limit: ConcurrencyLimit, + asserting_events_worker: EventsWorker, + mock_should_emit_events, + reset_worker_events, +): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + def resource_heavy(): + with concurrency(["test", "other"], task_run_id): + pass + + resource_heavy() + + asserting_events_worker.drain() + assert isinstance(asserting_events_worker._client, AssertingEventsClient) + assert len(asserting_events_worker._client.events) == 4 # 2 acquire, 2 release + + for phase in ["acquired", "released"]: + event = next( + filter( + lambda e: e.event == f"prefect.concurrency-limit.v1.{phase}" + and e.resource.id + == f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + asserting_events_worker._client.events, + ) + ) + + assert dict(event.resource) == { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + "prefect.resource.name": v1_concurrency_limit.tag, + "task_run_id": str(task_run_id), + "limit": str(v1_concurrency_limit.concurrency_limit), + } + + assert len(event.related) == 1 + assert dict(event.related[0]) == { + "prefect.resource.id": ( + f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}" + ), + "prefect.resource.role": "concurrency-limit", + } + + for phase in ["acquired", "released"]: + 
event = next( + filter( + lambda e: e.event == f"prefect.concurrency-limit.v1.{phase}" + and e.resource.id + == f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}", + asserting_events_worker._client.events, + ) + ) + + assert dict(event.resource) == { + "prefect.resource.id": ( + f"prefect.concurrency-limit.v1.{other_v1_concurrency_limit.id}" + ), + "prefect.resource.name": other_v1_concurrency_limit.tag, + "task_run_id": str(task_run_id), + "limit": str(other_v1_concurrency_limit.concurrency_limit), + } + + assert len(event.related) == 1 + assert dict(event.related[0]) == { + "prefect.resource.id": f"prefect.concurrency-limit.v1.{v1_concurrency_limit.id}", + "prefect.resource.role": "concurrency-limit", + } + + +def test_concurrency_can_be_used_within_a_flow( + concurrency_limit: ConcurrencyLimit, +): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + @task + def resource_heavy(): + nonlocal executed + with concurrency("test", task_run_id): + executed = True + + @flow + def my_flow(): + resource_heavy() + + assert not executed + + my_flow() + + assert executed + + +async def test_concurrency_can_be_used_while_event_loop_is_running( + concurrency_limit: ConcurrencyLimit, +): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + def resource_heavy(): + nonlocal executed + with concurrency("test", task_run_id): + executed = True + + assert not executed + + resource_heavy() + + assert executed + + +@pytest.fixture +def mock_increment_concurrency_slots(monkeypatch): + async def mocked_increment_concurrency_slots(*args, **kwargs): + response = Response( + status_code=status.HTTP_423_LOCKED, + headers={"Retry-After": "0.01"}, + ) + raise HTTPStatusError( + message="Locked", + request=Request("GET", "http://test.com"), + response=response, + ) + + monkeypatch.setattr( + "prefect.client.orchestration.PrefectClient.increment_v1_concurrency_slots", + mocked_increment_concurrency_slots, + ) + + +@pytest.mark.usefixtures("concurrency_limit", "mock_increment_concurrency_slots") +def test_concurrency_respects_timeout(): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + with pytest.raises(TimeoutError, match=".*timed out after 0.01 second(s)*."): + with concurrency("test", task_run_id=task_run_id, timeout_seconds=0.01): + print("should not be executed") + + +@pytest.mark.parametrize("names", [[], None]) +def test_concurrency_without_limit_names_sync(names): + executed = False + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + def resource_heavy(): + nonlocal executed + with concurrency(names=names, task_run_id=task_run_id): + executed = True + + assert not executed + + with mock.patch( + "prefect.concurrency.v1.sync._acquire_concurrency_slots", + wraps=lambda *args, **kwargs: None, + ) as acquire_spy: + with mock.patch( + "prefect.concurrency.v1.sync._release_concurrency_slots", + wraps=lambda *args, **kwargs: None, + ) as release_spy: + resource_heavy() + + acquire_spy.assert_not_called() + release_spy.assert_not_called() + + assert executed diff --git a/tests/concurrency/v1/test_context.py b/tests/concurrency/v1/test_context.py new file mode 100644 index 000000000000..50a013a4680d --- /dev/null +++ b/tests/concurrency/v1/test_context.py @@ -0,0 +1,66 @@ +import asyncio +import time +from uuid import UUID + +import pytest + +from prefect.client.orchestration import PrefectClient, get_client +from prefect.concurrency.v1.asyncio import concurrency as aconcurrency +from prefect.concurrency.v1.context 
import ConcurrencyContext +from prefect.concurrency.v1.sync import concurrency +from prefect.server.schemas.core import ConcurrencyLimit +from prefect.utilities.asyncutils import run_coro_as_sync +from prefect.utilities.timeout import timeout, timeout_async + + +async def test_concurrency_context_releases_slots_async( + v1_concurrency_limit: ConcurrencyLimit, prefect_client: PrefectClient +): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + async def expensive_task(): + async with aconcurrency(v1_concurrency_limit.tag, task_run_id): + response = await prefect_client.read_concurrency_limit_by_tag( + v1_concurrency_limit.tag + ) + assert response.active_slots == [task_run_id] + + # Occupy the slot for longer than the timeout + await asyncio.sleep(1) + + with pytest.raises(TimeoutError): + with timeout_async(seconds=0.5): + with ConcurrencyContext(): + await expensive_task() + + response = await prefect_client.read_concurrency_limit_by_tag( + v1_concurrency_limit.tag + ) + assert response.active_slots == [] + + +async def test_concurrency_context_releases_slots_sync( + v1_concurrency_limit: ConcurrencyLimit, prefect_client: PrefectClient +): + task_run_id = UUID("00000000-0000-0000-0000-000000000000") + + def expensive_task(): + with concurrency(v1_concurrency_limit.tag, task_run_id): + client = get_client() + response = run_coro_as_sync( + client.read_concurrency_limit_by_tag(v1_concurrency_limit.tag) + ) + assert response and response.active_slots == [task_run_id] + + # Occupy the slot for longer than the timeout + time.sleep(1) + + with pytest.raises(TimeoutError): + with timeout(seconds=0.5): + with ConcurrencyContext(): + expensive_task() + + response = await prefect_client.read_concurrency_limit_by_tag( + v1_concurrency_limit.tag + ) + assert response.active_slots == [] diff --git a/tests/concurrency/v1/test_decrement_concurrency_slots.py b/tests/concurrency/v1/test_decrement_concurrency_slots.py new file mode 100644 index 000000000000..697c12b081a2 --- /dev/null +++ b/tests/concurrency/v1/test_decrement_concurrency_slots.py @@ -0,0 +1,57 @@ +import uuid +from unittest import mock + +from httpx import Response + +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.concurrency.v1.asyncio import _release_concurrency_slots + + +async def test_calls_release_client_method(): + task_run_id = uuid.UUID("00000000-0000-0000-0000-000000000000") + + limits = [ + MinimalConcurrencyLimitResponse(id=uuid.uuid4(), name=f"test-{i}", limit=i) + for i in range(1, 3) + ] + + with mock.patch( + "prefect.client.orchestration.PrefectClient.decrement_v1_concurrency_slots" + ) as client_decrement_v1_concurrency_slots: + response = Response( + 200, json=[limit.model_dump(mode="json") for limit in limits] + ) + client_decrement_v1_concurrency_slots.return_value = response + + await _release_concurrency_slots( + names=["test-1", "test-2"], task_run_id=task_run_id, occupancy_seconds=1.0 + ) + client_decrement_v1_concurrency_slots.assert_called_once_with( + names=["test-1", "test-2"], + task_run_id=task_run_id, + occupancy_seconds=1.0, + ) + + +async def test_returns_minimal_concurrency_limit(): + task_run_id = uuid.UUID("00000000-0000-0000-0000-000000000000") + + limits = [ + MinimalConcurrencyLimitResponse(id=uuid.uuid4(), name=f"test-{i}", limit=i) + for i in range(1, 3) + ] + + with mock.patch( + "prefect.client.orchestration.PrefectClient.decrement_v1_concurrency_slots" + ) as client_decrement_v1_concurrency_slots: + response = Response( + 200, 
json=[limit.model_dump(mode="json") for limit in limits] + ) + client_decrement_v1_concurrency_slots.return_value = response + + result = await _release_concurrency_slots( + ["test-1", "test-2"], + task_run_id, + 1.0, + ) + assert result == limits diff --git a/tests/concurrency/v1/test_increment_concurrency_limits.py b/tests/concurrency/v1/test_increment_concurrency_limits.py new file mode 100644 index 000000000000..874c93612bad --- /dev/null +++ b/tests/concurrency/v1/test_increment_concurrency_limits.py @@ -0,0 +1,58 @@ +import uuid +from unittest import mock + +from httpx import Response + +from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse +from prefect.concurrency.asyncio import _acquire_concurrency_slots + + +async def test_calls_increment_client_method(): + limits = [ + MinimalConcurrencyLimitResponse( + id=uuid.uuid4(), + name=f"test-{i}", + limit=i, + ) + for i in range(1, 3) + ] + + with mock.patch( + "prefect.client.orchestration.PrefectClient.increment_concurrency_slots" + ) as increment_concurrency_slots: + response = Response( + 200, json=[limit.model_dump(mode="json") for limit in limits] + ) + increment_concurrency_slots.return_value = response + + await _acquire_concurrency_slots( + names=["test-1", "test-2"], slots=1, mode="concurrency" + ) + increment_concurrency_slots.assert_called_once_with( + names=["test-1", "test-2"], + slots=1, + mode="concurrency", + create_if_missing=True, + ) + + +async def test_returns_minimal_concurrency_limit(): + limits = [ + MinimalConcurrencyLimitResponse( + id=uuid.uuid4(), + name=f"test-{i}", + limit=i, + ) + for i in range(1, 3) + ] + + with mock.patch( + "prefect.client.orchestration.PrefectClient.increment_concurrency_slots" + ) as increment_concurrency_slots: + response = Response( + 200, json=[limit.model_dump(mode="json") for limit in limits] + ) + increment_concurrency_slots.return_value = response + + result = await _acquire_concurrency_slots(["test-1", "test-2"], 1) + assert result == limits diff --git a/tests/server/orchestration/api/test_concurrency_limits.py b/tests/server/orchestration/api/test_concurrency_limits.py index 5c8ee16dcd15..54170bfc4fb6 100644 --- a/tests/server/orchestration/api/test_concurrency_limits.py +++ b/tests/server/orchestration/api/test_concurrency_limits.py @@ -6,6 +6,7 @@ from starlette import status from prefect.server import schemas +from prefect.server.api.concurrency_limits_v2 import MinimalConcurrencyLimitResponse from prefect.server.schemas.actions import ConcurrencyLimitCreate from prefect.settings import PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS @@ -270,3 +271,79 @@ async def test_setting_tag_to_zero_concurrency( read_response.json() ) assert concurrency_limit.active_slots == [] + + async def test_acquiring_returns_limits( + self, + client: AsyncClient, + tags_with_limits: List[str], + ): + task_run_id = uuid4() + tags = tags_with_limits + ["does-not-exist"] + + response = await client.post( + "/concurrency_limits/increment", + json={"names": tags, "task_run_id": str(task_run_id)}, + ) + assert response.status_code == status.HTTP_200_OK + + limits = [ + MinimalConcurrencyLimitResponse.model_validate(limit) + for limit in response.json() + ] + assert len(limits) == 2 # ignores tags that don't exist + + async def test_releasing_returns_limits( + self, + client: AsyncClient, + tags_with_limits: List[str], + ): + task_run_id = uuid4() + tags = tags_with_limits + ["does-not-exist"] + + response = await client.post( + "/concurrency_limits/increment", + 
json={"names": tags, "task_run_id": str(task_run_id)}, + ) + assert response.status_code == status.HTTP_200_OK + + response = await client.post( + "/concurrency_limits/decrement", + json={"names": tags, "task_run_id": str(task_run_id)}, + ) + assert response.status_code == status.HTTP_200_OK + + limits = [ + MinimalConcurrencyLimitResponse.model_validate(limit) + for limit in response.json() + ] + assert len(limits) == 2 # ignores tags that don't exist + + async def test_acquiring_returns_empty_list_if_no_limits( + self, + client: AsyncClient, + tags_with_limits: List[str], + ): + task_run_id = uuid4() + tags = ["does-not-exist"] + + response = await client.post( + "/concurrency_limits/increment", + json={"names": tags, "task_run_id": str(task_run_id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + async def test_releasing_returns_empty_list_if_no_limits( + self, + client: AsyncClient, + tags_with_limits: List[str], + ): + task_run_id = uuid4() + tags = ["does-not-exist"] + + response = await client.post( + "/concurrency_limits/decrement", + json={"names": tags, "task_run_id": str(task_run_id)}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 3256da19fc75..3b327fbaaa50 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -17,12 +17,12 @@ from prefect.cache_policies import FLOW_PARAMETERS from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas.objects import StateType -from prefect.concurrency.asyncio import ( +from prefect.concurrency.asyncio import concurrency as aconcurrency +from prefect.concurrency.sync import concurrency +from prefect.concurrency.v1.asyncio import ( _acquire_concurrency_slots, _release_concurrency_slots, ) -from prefect.concurrency.asyncio import concurrency as aconcurrency -from prefect.concurrency.sync import concurrency from prefect.context import ( EngineContext, FlowRunContext, @@ -2271,51 +2271,59 @@ async def g(): class TestTaskConcurrencyLimits: async def test_tag_concurrency(self): + task_run_id = None + @task(tags=["limit-tag"]) async def bar(): + nonlocal task_run_id + task_run_id = TaskRunContext.get().task_run.id return 42 with mock.patch( - "prefect.concurrency.asyncio._acquire_concurrency_slots", + "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", wraps=_acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.asyncio._release_concurrency_slots", + "prefect.concurrency.v1.asyncio._release_concurrency_slots", wraps=_release_concurrency_slots, ) as release_spy: await bar() acquire_spy.assert_called_once_with( - ["limit-tag"], 1, timeout_seconds=None, create_if_missing=False + ["limit-tag"], task_run_id=task_run_id, timeout_seconds=None ) - names, occupy, occupy_seconds = release_spy.call_args[0] + names, _task_run_id, occupy_seconds = release_spy.call_args[0] assert names == ["limit-tag"] - assert occupy == 1 + assert _task_run_id == task_run_id assert occupy_seconds > 0 def test_tag_concurrency_sync(self): + task_run_id = None + @task(tags=["limit-tag"]) def bar(): + nonlocal task_run_id + task_run_id = TaskRunContext.get().task_run.id return 42 with mock.patch( - "prefect.concurrency.sync._acquire_concurrency_slots", + "prefect.concurrency.v1.sync._acquire_concurrency_slots", wraps=_acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - 
"prefect.concurrency.sync._release_concurrency_slots", + "prefect.concurrency.v1.sync._release_concurrency_slots", wraps=_release_concurrency_slots, ) as release_spy: bar() acquire_spy.assert_called_once_with( - ["limit-tag"], 1, timeout_seconds=None, create_if_missing=False + ["limit-tag"], task_run_id=task_run_id, timeout_seconds=None ) - names, occupy, occupy_seconds = release_spy.call_args[0] + names, _task_run_id, occupy_seconds = release_spy.call_args[0] assert names == ["limit-tag"] - assert occupy == 1 + assert _task_run_id == task_run_id assert occupy_seconds > 0 async def test_no_tags_no_concurrency(self): @@ -2324,11 +2332,11 @@ async def bar(): return 42 with mock.patch( - "prefect.concurrency.asyncio._acquire_concurrency_slots", + "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", wraps=_acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.asyncio._release_concurrency_slots", + "prefect.concurrency.v1.asyncio._release_concurrency_slots", wraps=_release_concurrency_slots, ) as release_spy: await bar() @@ -2342,11 +2350,11 @@ def bar(): return 42 with mock.patch( - "prefect.concurrency.sync._acquire_concurrency_slots", + "prefect.concurrency.v1.sync._acquire_concurrency_slots", wraps=_acquire_concurrency_slots, ) as acquire_spy: with mock.patch( - "prefect.concurrency.sync._release_concurrency_slots", + "prefect.concurrency.v1.sync._release_concurrency_slots", wraps=_release_concurrency_slots, ) as release_spy: bar() @@ -2355,18 +2363,22 @@ def bar(): assert release_spy.call_count == 0 async def test_tag_concurrency_does_not_create_limits(self, prefect_client): + task_run_id = None + @task(tags=["limit-tag"]) async def bar(): + nonlocal task_run_id + task_run_id = TaskRunContext.get().task_run.id return 42 with mock.patch( - "prefect.concurrency.asyncio._acquire_concurrency_slots", + "prefect.concurrency.v1.asyncio._acquire_concurrency_slots", wraps=_acquire_concurrency_slots, ) as acquire_spy: await bar() acquire_spy.assert_called_once_with( - ["limit-tag"], 1, timeout_seconds=None, create_if_missing=False + ["limit-tag"], task_run_id=task_run_id, timeout_seconds=None ) limits = await prefect_client.read_concurrency_limits(10, 0)