feat: add custom_metric for dynamic batching #6189

Merged
merged 7 commits into from Sep 18, 2024
Changes from 1 commit
8 changes: 7 additions & 1 deletion jina/serve/executors/decorators.py
@@ -416,7 +416,9 @@ def dynamic_batching(
    *,
    preferred_batch_size: Optional[int] = None,
    timeout: Optional[float] = 10_000,
    flush_all: bool = False,
    custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
    use_custom_metric: bool = False,
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -434,6 +436,8 @@
Default is 10_000ms (10 seconds).
:param flush_all: Determines whether, once a flush is triggered by timeout or `preferred_batch_size`, the function receives everything the batcher has accumulated.
If this is true, `preferred_batch_size` is used as a trigger mechanism.
:param custom_metric: Optional function (e.g. a lambda) that measures the "weight" of the accumulated Documents.
:param use_custom_metric: Determines whether `custom_metric`, instead of the number of Documents, is compared against `preferred_batch_size`.
:return: decorated function
"""

@@ -480,6 +484,8 @@ def _inject_owner_attrs(self, owner, name):
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
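For context, a minimal usage sketch of the new parameters (the Executor, endpoint, and metric below are illustrative only, and `dynamic_batching` is assumed importable from the top-level `jina` package):

from jina import Executor, dynamic_batching, requests

class MyEmbedder(Executor):  # hypothetical Executor, for illustration
    @requests
    @dynamic_batching(
        preferred_batch_size=1_000,  # threshold on the metric value, not on the doc count
        timeout=5_000,
        use_custom_metric=True,
        # weigh the accumulated batch by total text length instead of number of docs
        custom_metric=lambda docs: sum(len(d.text) for d in docs),
    )
    def embed(self, docs, **kwargs):
        ...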
94 changes: 51 additions & 43 deletions jina/serve/runtimes/worker/batch_queue.py
@@ -1,6 +1,6 @@
import asyncio
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from jina._docarray import docarray_v2

if not docarray_v2:
@@ -17,15 +17,17 @@ class BatchQueue:
"""A batch queue that holds the data request and the callable to batch requests to."""

def __init__(
    self,
    func: Callable,
    request_docarray_cls,
    response_docarray_cls,
    output_array_type: Optional[str] = None,
    params: Optional[Dict] = None,
    flush_all: bool = False,
    preferred_batch_size: int = 4,
    timeout: int = 10_000,
    custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
    use_custom_metric: bool = False,
) -> None:
self._data_lock = asyncio.Lock()
self.func = func
@@ -38,6 +40,7 @@ def __init__(
self._response_docarray_cls = response_docarray_cls
self._flush_all = flush_all
self._preferred_batch_size: int = preferred_batch_size
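# Only honor `custom_metric` when the caller explicitly opted in via `use_custom_metric`: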
self._custom_metric = None if not use_custom_metric else custom_metric
self._timeout: int = timeout
self._reset()
self._flush_trigger: Event = Event()
@@ -66,9 +69,9 @@ def _reset(self) -> None:

def _cancel_timer_if_pending(self):
if (
    self._timer_task
    and not self._timer_task.done()
    and not self._timer_task.cancelled()
):
self._timer_finished = False
self._timer_task.cancel()
@@ -84,7 +87,7 @@ async def _sleep_then_set(self):
self._flush_trigger.set()
self._timer_finished = True

async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
"""Append request to the the list of requests to be processed.

This method creates an asyncio Queue for that request and keeps track of it. It returns
@@ -114,8 +117,12 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if self._custom_metric is None:
    if len(self._big_doc) >= self._preferred_batch_size:
        self._flush_trigger.set()
else:
    if self._custom_metric(self._big_doc) >= self._preferred_batch_size:
        self._flush_trigger.set()

return queue

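To make the new trigger rule concrete, a standalone sketch (the Doc class and numbers are hypothetical, not part of the diff):

from dataclasses import dataclass

@dataclass
class Doc:
    text: str

preferred_batch_size = 100  # now a threshold on the metric value, not a doc count
custom_metric = lambda docs: sum(len(d.text) for d in docs)

big_doc = [Doc('x' * 30), Doc('y' * 40), Doc('z' * 40)]
print(custom_metric(big_doc) >= preferred_batch_size)  # True: weight 110 >= 100 after only 3 docs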
@@ -125,9 +132,9 @@ async def _await_then_flush(self, http=False) -> None:
"""

def _get_docs_groups_completed_request_indexes(
    non_assigned_docs,
    non_assigned_docs_reqs_idx,
    sum_from_previous_mini_batch_in_first_req_idx,
):
"""
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -151,9 +158,9 @@ def _get_docs_groups_completed_request_indexes(
)
if req_idx > min_involved_req_idx:
request_bucket = non_assigned_docs[
    num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
]
num_distributed_docs += num_docs_in_req_idx
completed_req_idx.append(min_involved_req_idx)
min_involved_req_idx = req_idx
@@ -162,22 +169,22 @@ def _get_docs_groups_completed_request_indexes(
num_docs_in_req_idx += 1

if (
    req_idx not in completed_req_idx
    and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
    == self._request_lens[req_idx]
):
completed_req_idx.append(req_idx)
request_bucket = non_assigned_docs[
    num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
]
distributed_requests.append(request_bucket)

return distributed_requests, completed_req_idx
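# Hypothetical example of the grouping performed above: with request lengths
# [2, 3], five non-assigned docs whose req idxs are [0, 0, 1, 1, 1], and no
# carry-over from a previous mini-batch, the docs split into buckets
# [d0, d1] and [d2, d3, d4], and completed_req_idx == [0, 1].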

async def _assign_results(
    non_assigned_docs,
    non_assigned_docs_reqs_idx,
    sum_from_previous_mini_batch_in_first_req_idx,
):
"""
This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -204,7 +211,7 @@ async def _assign_results(
request = self._requests[request_idx]
request_completed = self._requests_completed[request_idx]
if http is False or self._output_array_type is not None:
request.direct_docs = None  # batch queue works in place, so the result needs to be read from request.data
request.data.set_docs_convert_arrays(
docs_group, ndarray_type=self._output_array_type
)
@@ -214,15 +221,15 @@

return num_assigned_docs

def batch(iterable_1, iterable_2, n: Optional[int] = 1):
if n is None:
yield iterable_1, iterable_2
return
items = len(iterable_1)
for ndx in range(0, items, n):
yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
    ndx : min(ndx + n, items)
]
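# Behavior of the inner `batch` helper (illustrative values):
#   list(batch([1, 2, 3, 4, 5], ['a', 'b', 'c', 'd', 'e'], n=2))
#   -> ([1, 2], ['a', 'b']), ([3, 4], ['c', 'd']), ([5], ['e'])
# With n=None (the flush_all case) it yields both iterables whole, as a single batch.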

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
@@ -240,8 +247,9 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1):

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
# TODO: Change batch to consider potential custom metric
for docs_inner_batch, req_idxs in batch(
    self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
@@ -256,8 +264,8 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1):
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
    not docarray_v2
    and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
@@ -279,8 +287,8 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1):
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in self._requests_completed[
    involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
Expand All @@ -295,11 +303,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
)

sum_from_previous_first_req_idx = (
    len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
    num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)