feat: add custom_metric for dynamic batching #6189

Merged · 7 commits · Sep 18, 2024
Changes from all commits
15 changes: 14 additions & 1 deletion jina/serve/executors/__init__.py
@@ -655,9 +655,22 @@ def _validate_sagemaker(self):
return

def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
import collections.abc

def deep_update(source, overrides):
for key, value in overrides.items():
if isinstance(value, collections.abc.Mapping) and value:
returned = deep_update(source.get(key, {}), value)
source[key] = returned
else:
source[key] = overrides[key]
return source

if _dynamic_batching:
self.dynamic_batching = getattr(self, 'dynamic_batching', {})
self.dynamic_batching.update(_dynamic_batching)
self.dynamic_batching = deep_update(
self.dynamic_batching, _dynamic_batching
)

def _add_metas(self, _metas: Optional[Dict]):
from jina.serve.executors.metas import get_default_metas
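To make the change above concrete, here is a self-contained sketch (illustrative keys and values, not from the PR) of what the recursive merge buys over the previous shallow `dict.update`:

```python
from collections.abc import Mapping


def deep_update(source: dict, overrides: dict) -> dict:
    """Recursively merge `overrides` into `source`, keeping keys that are not overridden."""
    for key, value in overrides.items():
        if isinstance(value, Mapping) and value:
            source[key] = deep_update(source.get(key, {}), value)
        else:
            source[key] = value
    return source


# Per-endpoint settings injected by the @dynamic_batching decorator (illustrative):
existing = {'foo': {'preferred_batch_size': 10, 'custom_metric': len}}
# Partial override coming from `uses_dynamic_batching` at Deployment level:
override = {'foo': {'preferred_batch_size': 16, 'use_custom_metric': True}}

merged = deep_update(existing, override)
# A shallow dict.update() would have replaced the whole 'foo' entry and dropped
# 'custom_metric'; the recursive merge keeps it:
assert merged == {
    'foo': {'preferred_batch_size': 16, 'custom_metric': len, 'use_custom_metric': True}
}
```

This is why the PR swaps the shallow update for `deep_update`: a deployment-level override that only sets `use_custom_metric` no longer wipes out a `custom_metric` declared on the decorator.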
8 changes: 7 additions & 1 deletion jina/serve/executors/decorators.py
@@ -416,7 +416,9 @@ def dynamic_batching(
*,
preferred_batch_size: Optional[int] = None,
timeout: Optional[float] = 10_000,
flush_all: bool = False
flush_all: bool = False,
custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
use_custom_metric: bool = False,
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -434,6 +436,8 @@ def dynamic_batching(
Default is 10_000ms (10 seconds).
:param flush_all: Determines whether, once a batch is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated.
If this is true, `preferred_batch_size` is used as a trigger mechanism.
:param custom_metric: Optional callable used to measure the "weight" of each Document; the weights are summed to decide when a batch is ready.
:param use_custom_metric: Determines whether `custom_metric` is used when accumulating towards `preferred_batch_size`.
:return: decorated function
"""

@@ -480,6 +484,8 @@ def _inject_owner_attrs(self, owner, name):
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
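For orientation, a hedged usage sketch of the two new knobs. The Executor name, endpoint, and sizes are illustrative, and it assumes docarray v1-style Documents with a `text` field, as in the PR's test; the metric is applied per Document, so `preferred_batch_size` acts as a budget on the summed metric rather than on the Document count:

```python
from jina import Deployment, Executor, dynamic_batching, requests


class TextEncoder(Executor):
    @dynamic_batching(
        preferred_batch_size=1024,  # with use_custom_metric, a budget on the summed metric
        timeout=2_000,
        use_custom_metric=True,
        custom_metric=lambda doc: len(doc.text),  # per-Document "weight"
    )
    @requests(on='/encode')
    def encode(self, docs, **kwargs):
        # `docs` holds the dynamically assembled batch
        for doc in docs:
            doc.text = doc.text.upper()


# The same knobs can also be partially overridden at deployment time:
dep = Deployment(
    uses=TextEncoder,
    uses_dynamic_batching={'encode': {'preferred_batch_size': 2048, 'use_custom_metric': True}},
)
```

With `use_custom_metric=True`, the batch queue accumulates `custom_metric(doc)` for the queued Documents and flushes once the sum reaches `preferred_batch_size` (or the timeout fires); thanks to the new `deep_update`, the decorator-level `custom_metric` survives the partial override above.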
131 changes: 81 additions & 50 deletions jina/serve/runtimes/worker/batch_queue.py
@@ -1,9 +1,10 @@
import asyncio
import copy
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from jina._docarray import docarray_v2
import contextlib

if not docarray_v2:
from docarray import DocumentArray
else:
@@ -18,16 +19,18 @@ class BatchQueue:
"""A batch queue that holds the data request and the callable to batch requests to."""

def __init__(
self,
func: Callable,
request_docarray_cls,
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
self,
func: Callable,
request_docarray_cls,
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
) -> None:
# To keep the old user behavior, we use the data lock when flush_all is True and allow_concurrent is not set
if allow_concurrent and flush_all:
@@ -44,6 +47,8 @@ def __init__(
self._response_docarray_cls = response_docarray_cls
self._flush_all = flush_all
self._preferred_batch_size: int = preferred_batch_size
self._custom_metric = None if not use_custom_metric else custom_metric
self._metric_value = 0
self._timeout: int = timeout
self._reset()
self._flush_trigger: Event = Event()
@@ -62,20 +67,22 @@ def _reset(self) -> None:
# a list of every request ID
self._request_idxs: List[int] = []
self._request_lens: List[int] = []
self._docs_metrics: List[int] = []
self._requests_completed: List[asyncio.Queue] = []
if not docarray_v2:
self._big_doc: DocumentArray = DocumentArray.empty()
else:
self._big_doc = self._request_docarray_cls()
self._metric_value = 0

self._flush_task: Optional[Task] = None
self._flush_trigger: Event = Event()

def _cancel_timer_if_pending(self):
if (
self._timer_task
and not self._timer_task.done()
and not self._timer_task.cancelled()
self._timer_task
and not self._timer_task.done()
and not self._timer_task.cancelled()
):
self._timer_finished = False
self._timer_task.cancel()
@@ -91,7 +98,7 @@ async def _sleep_then_set(self):
self._flush_trigger.set()
self._timer_finished = True

async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
"""Append request to the the list of requests to be processed.

This method creates an asyncio Queue for that request and keeps track of it. It returns
@@ -116,12 +123,18 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
metric_value += sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(len(docs))
self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if len(self._big_doc) >= self._preferred_batch_size:
if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()

return queue
@@ -132,10 +145,10 @@ async def _await_then_flush(self, http=False) -> None:
"""

def _get_docs_groups_completed_request_indexes(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
):
"""
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -160,9 +173,9 @@ def _get_docs_groups_completed_request_indexes(
)
if req_idx > min_involved_req_idx:
request_bucket = non_assigned_docs[
num_distributed_docs : num_distributed_docs
+ num_docs_in_req_idx
]
num_distributed_docs: num_distributed_docs
+ num_docs_in_req_idx
]
num_distributed_docs += num_docs_in_req_idx
completed_req_idx.append(min_involved_req_idx)
min_involved_req_idx = req_idx
@@ -171,25 +184,25 @@ def _get_docs_groups_completed_request_indexes(
num_docs_in_req_idx += 1

if (
req_idx not in completed_req_idx
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
== requests_lens_in_batch[req_idx]
req_idx not in completed_req_idx
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
== requests_lens_in_batch[req_idx]
):
completed_req_idx.append(req_idx)
request_bucket = non_assigned_docs[
num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
]
num_distributed_docs: num_distributed_docs + num_docs_in_req_idx
]
distributed_requests.append(request_bucket)

return distributed_requests, completed_req_idx

async def _assign_results(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
):
"""
This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -220,7 +233,7 @@ async def _assign_results(
request = requests_in_batch[request_idx]
request_completed = requests_completed_in_batch[request_idx]
if http is False or self._output_array_type is not None:
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
request.data.set_docs_convert_arrays(
docs_group, ndarray_type=self._output_array_type
)
@@ -230,22 +243,39 @@ async def _assign_results(

return num_assigned_docs

def batch(iterable_1, iterable_2, n:Optional[int] = 1):
def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Optional = None):
if n is None:
yield iterable_1, iterable_2
return
items = len(iterable_1)
for ndx in range(0, items, n):
yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
ndx : min(ndx + n, items)
]
elif iterable_metrics is None:
items = len(iterable_1)
for ndx in range(0, items, n):
yield iterable_1[ndx: min(ndx + n, items)], iterable_2[
ndx: min(ndx + n, items)
]
else:
batch_idx = 0
batch_weight = 0

for i, (item, weight) in enumerate(zip(iterable_1, iterable_metrics)):
batch_weight += weight

if batch_weight >= n:
yield iterable_1[batch_idx: i + 1], iterable_2[batch_idx: i + 1]
batch_idx = i + 1
batch_weight = 0

# Yield any remaining items
if batch_weight > 0:
yield iterable_1[batch_idx: len(iterable_1)], iterable_2[batch_idx: len(iterable_1)]

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
async with self._data_lock:
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

@@ -263,7 +293,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
@@ -278,8 +309,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
@@ -301,8 +332,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in requests_completed_in_batch[
involved_requests_min_indx : involved_requests_max_indx + 1
]:
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
@@ -320,11 +351,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
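Since the weight-aware branch of `batch()` is the heart of the feature, here is a stripped-down, self-contained sketch of the same slicing logic. The names are simplified and illustrative; the real generator also slices the request-index list in lockstep and is fed the per-Document metrics collected in `push()`:

```python
from typing import Iterator, List, Union

Number = Union[int, float]


def batch_by_weight(
    items: List[str], weights: List[Number], budget: Number
) -> Iterator[List[str]]:
    """Yield consecutive slices of `items` whose accumulated weight reaches `budget`.

    Mirrors the custom-metric branch of `batch()`: the running weight resets after
    each yield, and a final partial slice is emitted if anything is left over.
    """
    start = 0
    running = 0
    for i, weight in enumerate(weights):
        running += weight
        if running >= budget:
            yield items[start : i + 1]
            start = i + 1
            running = 0
    if running > 0:
        yield items[start:]


texts = ['aa', 'bbbb', 'c', 'ddddd', 'ee']
weights = [len(t) for t in texts]  # e.g. custom_metric=lambda doc: len(doc.text)
print(list(batch_by_weight(texts, weights, budget=5)))
# -> [['aa', 'bbbb'], ['c', 'ddddd'], ['ee']]  (slice weights 6, 6 and a leftover of 2)
```

In `push()`, the same per-Document metrics also feed `self._metric_value`, which sets the flush trigger once it reaches `preferred_batch_size`.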
64 changes: 64 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -736,3 +736,67 @@ def foo(self, docs, **kwargs):

assert smaller_than_5 == (1 if allow_concurrent else 0)
assert larger_than_5 > 0


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [False, True])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
class DynCustomBatchProcessor(Executor):

@dynamic_batching(preferred_batch_size=10, custom_metric=lambda x: len(x.text))
@requests(on='/foo')
def foo(self, docs, **kwargs):
time.sleep(0.5)
total_len = sum([len(doc.text) for doc in docs])
for doc in docs:
doc.text = f"{total_len}"

depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}})
da = DocumentArray([Document(text='aaaaa') for i in range(50)])
with depl:
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
res = []
async for r in cl.post(
on='/foo',
inputs=da,
request_size=1,
continue_on_error=True,
results_in_order=True,
):
res.extend(r)
assert len(res) == 50 # 1 request per input

# custom_metric enabled, flush_all disabled
if use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "10"

elif not use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "50"

elif use_custom_metric and flush_all:
# There will be 2 "10" and the rest will be "240"
num_10 = 0
num_240 = 0
for doc in res:
if doc.text == "10":
num_10 += 1
elif doc.text == "240":
num_240 += 1

assert num_10 == 2
assert num_240 == 48
elif not use_custom_metric and flush_all:
# There will be 10 "50" and the rest will be "200"
num_50 = 0
num_200 = 0
for doc in res:
if doc.text == "50":
num_50 += 1
elif doc.text == "200":
num_200 += 1

assert num_50 == 10
assert num_200 == 40
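A quick sanity check of the expected values. Every Document carries the text 'aaaaa', so the custom metric is 5 per Document and each request contributes exactly one Document; the arithmetic below assumes the first flush grabs only the initial batch before more requests arrive, which the data lock and the 0.5 s sleep in `foo` make likely:

- use_custom_metric=True, flush_all=False: the queue flushes once the summed metric reaches 10, i.e. every 2 Documents, so every call sees total_len = 2 * 5 = 10.
- use_custom_metric=False, flush_all=False: the queue flushes every 10 Documents, so total_len = 10 * 5 = 50.
- use_custom_metric=True, flush_all=True: the first trigger fires after 2 Documents ("10"); the remaining 48 accumulate during that call and are flushed together, 48 * 5 = 240.
- use_custom_metric=False, flush_all=True: the first trigger fires after 10 Documents ("50"); the remaining 40 are flushed together, 40 * 5 = 200.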