Skip to content

Commit

Permalink
Replicate concurrency helper from v2 for v1 limits (#15037)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Streed <[email protected]>
Co-authored-by: Chris Guidry <[email protected]>
  • Loading branch information
3 people authored Aug 22, 2024
1 parent 737f536 commit eac7892
Show file tree
Hide file tree
Showing 21 changed files with 1,481 additions and 37 deletions.
8 changes: 7 additions & 1 deletion docs/3.0rc/api-ref/rest-api/server/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -4323,7 +4323,13 @@
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MinimalConcurrencyLimitResponse"
},
"title": "Response Increment Concurrency Limits V1 Concurrency Limits Increment Post"
}
}
}
},
Expand Down
75 changes: 75 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,57 @@ async def delete_concurrency_limit_by_tag(
else:
raise

async def increment_v1_concurrency_slots(
self,
names: List[str],
task_run_id: UUID,
) -> httpx.Response:
"""
Increment concurrency limit slots for the specified limits.
Args:
names (List[str]): A list of limit names for which to increment limits.
task_run_id (UUID): The task run ID incrementing the limits.
"""
data = {
"names": names,
"task_run_id": str(task_run_id),
}

return await self._client.post(
"/concurrency_limits/increment",
json=data,
)

async def decrement_v1_concurrency_slots(
self,
names: List[str],
task_run_id: UUID,
occupancy_seconds: float,
) -> httpx.Response:
"""
Decrement concurrency limit slots for the specified limits.
Args:
names (List[str]): A list of limit names to decrement.
task_run_id (UUID): The task run ID that incremented the limits.
occupancy_seconds (float): The duration in seconds that the limits
were held.
Returns:
httpx.Response: The HTTP response from the server.
"""
data = {
"names": names,
"task_run_id": str(task_run_id),
"occupancy_seconds": occupancy_seconds,
}

return await self._client.post(
"/concurrency_limits/decrement",
json=data,
)

async def create_work_queue(
self,
name: str,
Expand Down Expand Up @@ -4116,3 +4167,27 @@ def release_concurrency_slots(
"occupancy_seconds": occupancy_seconds,
},
)

def decrement_v1_concurrency_slots(
self, names: List[str], occupancy_seconds: float, task_run_id: UUID
) -> httpx.Response:
"""
Release the specified concurrency limits.
Args:
names (List[str]): A list of limit names to decrement.
occupancy_seconds (float): The duration in seconds that the slots
were held.
task_run_id (UUID): The task run ID that incremented the limits.
Returns:
httpx.Response: The HTTP response from the server.
"""
return self._client.post(
"/concurrency_limits/decrement",
json={
"names": names,
"occupancy_seconds": occupancy_seconds,
"task_run_id": str(task_run_id),
},
)
Empty file.
143 changes: 143 additions & 0 deletions src/prefect/concurrency/v1/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List, Optional, Union, cast
from uuid import UUID

import anyio
import httpx
import pendulum

from ...client.schemas.responses import MinimalConcurrencyLimitResponse

try:
from pendulum import Interval
except ImportError:
# pendulum < 3
from pendulum.period import Period as Interval # type: ignore

from prefect.client.orchestration import get_client

from .context import ConcurrencyContext
from .events import (
_emit_concurrency_acquisition_events,
_emit_concurrency_release_events,
)
from .services import ConcurrencySlotAcquisitionService


class ConcurrencySlotAcquisitionError(Exception):
"""Raised when an unhandlable occurs while acquiring concurrency slots."""


class AcquireConcurrencySlotTimeoutError(TimeoutError):
"""Raised when acquiring a concurrency slot times out."""


@asynccontextmanager
async def concurrency(
names: Union[str, List[str]],
task_run_id: UUID,
timeout_seconds: Optional[float] = None,
) -> AsyncGenerator[None, None]:
"""A context manager that acquires and releases concurrency slots from the
given concurrency limits.
Args:
names: The names of the concurrency limits to acquire slots from.
task_run_id: The name of the task_run_id that is incrementing the slots.
timeout_seconds: The number of seconds to wait for the slots to be acquired before
raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
Raises:
TimeoutError: If the slots are not acquired within the given timeout.
Example:
A simple example of using the async `concurrency` context manager:
```python
from prefect.concurrency.v1.asyncio import concurrency
async def resource_heavy():
async with concurrency("test", task_run_id):
print("Resource heavy task")
async def main():
await resource_heavy()
```
"""
if not names:
yield
return

names_normalized: List[str] = names if isinstance(names, list) else [names]

limits = await _acquire_concurrency_slots(
names_normalized,
task_run_id=task_run_id,
timeout_seconds=timeout_seconds,
)
acquisition_time = pendulum.now("UTC")
emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)

try:
yield
finally:
occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time))
try:
await _release_concurrency_slots(
names_normalized, task_run_id, occupancy_period.total_seconds()
)
except anyio.get_cancelled_exc_class():
# The task was cancelled before it could release the slots. Add the
# slots to the cleanup list so they can be released when the
# concurrency context is exited.
if ctx := ConcurrencyContext.get():
ctx.cleanup_slots.append(
(names_normalized, occupancy_period.total_seconds(), task_run_id)
)

_emit_concurrency_release_events(limits, emitted_events, task_run_id)


async def _acquire_concurrency_slots(
names: List[str],
task_run_id: UUID,
timeout_seconds: Optional[float] = None,
) -> List[MinimalConcurrencyLimitResponse]:
service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
future = service.send((task_run_id, timeout_seconds))
response_or_exception = await asyncio.wrap_future(future)

if isinstance(response_or_exception, Exception):
if isinstance(response_or_exception, TimeoutError):
raise AcquireConcurrencySlotTimeoutError(
f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)"
) from response_or_exception

raise ConcurrencySlotAcquisitionError(
f"Unable to acquire concurrency limits {names!r}"
) from response_or_exception

return _response_to_concurrency_limit_response(response_or_exception)


async def _release_concurrency_slots(
names: List[str],
task_run_id: UUID,
occupancy_seconds: float,
) -> List[MinimalConcurrencyLimitResponse]:
async with get_client() as client:
response = await client.decrement_v1_concurrency_slots(
names=names,
task_run_id=task_run_id,
occupancy_seconds=occupancy_seconds,
)
return _response_to_concurrency_limit_response(response)


def _response_to_concurrency_limit_response(
response: httpx.Response,
) -> List[MinimalConcurrencyLimitResponse]:
data = response.json() or []
return [
MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data
]
27 changes: 27 additions & 0 deletions src/prefect/concurrency/v1/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from contextvars import ContextVar
from typing import List, Tuple
from uuid import UUID

from prefect.client.orchestration import get_client
from prefect.context import ContextModel, Field


class ConcurrencyContext(ContextModel):
__var__: ContextVar = ContextVar("concurrency_v1")

# Track the limits that have been acquired but were not able to be released
# due to cancellation or some other error. These limits are released when
# the context manager exits.
cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list)

def __exit__(self, *exc_info):
if self.cleanup_slots:
with get_client(sync_client=True) as client:
for names, occupancy_seconds, task_run_id in self.cleanup_slots:
client.decrement_v1_concurrency_slots(
names=names,
occupancy_seconds=occupancy_seconds,
task_run_id=task_run_id,
)

return super().__exit__(*exc_info)
61 changes: 61 additions & 0 deletions src/prefect/concurrency/v1/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Dict, List, Literal, Optional, Union
from uuid import UUID

from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
from prefect.events import Event, RelatedResource, emit_event


def _emit_concurrency_event(
phase: Union[Literal["acquired"], Literal["released"]],
primary_limit: MinimalConcurrencyLimitResponse,
related_limits: List[MinimalConcurrencyLimitResponse],
task_run_id: UUID,
follows: Union[Event, None] = None,
) -> Union[Event, None]:
resource: Dict[str, str] = {
"prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}",
"prefect.resource.name": primary_limit.name,
"limit": str(primary_limit.limit),
"task_run_id": str(task_run_id),
}

related = [
RelatedResource.model_validate(
{
"prefect.resource.id": f"prefect.concurrency-limit.v1.{limit.id}",
"prefect.resource.role": "concurrency-limit",
}
)
for limit in related_limits
if limit.id != primary_limit.id
]

return emit_event(
f"prefect.concurrency-limit.v1.{phase}",
resource=resource,
related=related,
follows=follows,
)


def _emit_concurrency_acquisition_events(
limits: List[MinimalConcurrencyLimitResponse],
task_run_id: UUID,
) -> Dict[UUID, Optional[Event]]:
events = {}
for limit in limits:
event = _emit_concurrency_event("acquired", limit, limits, task_run_id)
events[limit.id] = event

return events


def _emit_concurrency_release_events(
limits: List[MinimalConcurrencyLimitResponse],
events: Dict[UUID, Optional[Event]],
task_run_id: UUID,
) -> None:
for limit in limits:
_emit_concurrency_event(
"released", limit, limits, task_run_id, events[limit.id]
)
Loading

0 comments on commit eac7892

Please sign in to comment.