Skip to content

Commit

Permalink
feat: use dynamic batching param
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 23, 2024
1 parent 434d09e commit 515a121
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 17 deletions.
3 changes: 1 addition & 2 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
os.unsetenv('http_proxy')
os.unsetenv('https_proxy')
self._inputs = None
self._inputs_length = None
self._setup_instrumentation(
name=(
self.args.name
Expand Down Expand Up @@ -144,8 +145,6 @@ def _get_requests(
else:
total_docs = None

self._inputs_length = None

if total_docs:
self._inputs_length = max(1, total_docs / _kwargs['request_size'])

Expand Down
3 changes: 3 additions & 0 deletions jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def dynamic_batching(
flush_all: bool = False,
custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
use_custom_metric: bool = False,
use_dynamic_batching: bool = True,
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
Expand All @@ -438,6 +439,7 @@ def dynamic_batching(
If this is true, `preferred_batch_size` is used as a trigger mechanism.
:param custom_metric: Potential lambda function to measure the "weight" of each request.
:param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size.
:param use_dynamic_batching: Determines if we should apply dynamic batching for this method.
:return: decorated function
"""

Expand Down Expand Up @@ -486,6 +488,7 @@ def _inject_owner_attrs(self, owner, name):
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
owner.dynamic_batching[fn_name]['use_dynamic_batching'] = use_dynamic_batching
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
Expand Down
5 changes: 3 additions & 2 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
**kwargs,
) -> None:
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
self.func = func
Expand Down Expand Up @@ -285,7 +286,8 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
self._preferred_batch_size if not self._flush_all else None,
docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
Expand Down Expand Up @@ -360,7 +362,6 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
requests_completed_in_batch,
)


async def close(self):
"""Closes the batch queue by flushing pending requests."""
if not self._is_closed:
Expand Down
9 changes: 5 additions & 4 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,11 @@ def _init_batchqueue_dict(self):
)
raise Exception(error_msg)

if key.startswith('/'):
dbatch_endpoints.append((key, dbatch_config))
else:
dbatch_functions.append((key, dbatch_config))
if dbatch_config.get("use_dynamic_batching", True):
if key.startswith('/'):
dbatch_endpoints.append((key, dbatch_config))
else:
dbatch_functions.append((key, dbatch_config))

# Specific endpoint configs take precedence over function configs
for endpoint, dbatch_config in dbatch_endpoints:
Expand Down
49 changes: 46 additions & 3 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,8 @@ def foo(self, docs, **kwargs):


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True])
@pytest.mark.parametrize('flush_all', [True])
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [True, False])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
class DynCustomBatchProcessor(Executor):

Expand All @@ -719,7 +719,9 @@ def foo(self, docs, **kwargs):
for doc in docs:
doc.text = f"{total_len}"

depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}})
depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={
'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric,
"flush_all": flush_all}})
da = DocumentArray([Document(text='aaaaa') for i in range(50)])
with depl:
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
Expand All @@ -733,3 +735,44 @@ def foo(self, docs, **kwargs):
):
res.extend(r)
assert len(res) == 50 # 1 request per input


@pytest.mark.asyncio
@pytest.mark.parametrize('use_dynamic_batching', [True, False])
async def test_use_dynamic_batching(use_dynamic_batching):
    """Verify the `use_dynamic_batching` flag: when True, requests of size 1 are
    accumulated into preferred-size batches (some docs see len(docs) == 10);
    when False, dynamic batching is disabled and every request is processed
    alone (every doc sees len(docs) == 1)."""

    class UseDynBatchProcessor(Executor):

        @dynamic_batching(preferred_batch_size=10)
        @requests(on='/foo')
        def foo(self, docs, **kwargs):
            # Stamp each doc with the size of the batch it was processed in,
            # so the client can observe whether batching actually happened.
            for doc in docs:
                doc.text = f"{len(docs)}"

    depl = Deployment(uses=UseDynBatchProcessor, uses_dynamic_batching={
        'foo': {"preferred_batch_size": 10, "timeout": 2000,
                "use_dynamic_batching": use_dynamic_batching,
                "flush_all": False}})
    da = DocumentArray([Document(text='aaaaa') for _ in range(50)])
    with depl:
        cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
        res = []
        async for r in cl.post(
            on='/foo',
            inputs=da,
            request_size=1,  # 1 doc per request: batching must come from the queue
            continue_on_error=True,
            results_in_order=True,
        ):
            res.extend(r)
        assert len(res) == 50  # 1 request per input

        # BUGFIX: the counter must be initialized ONCE before the loop;
        # resetting it per-iteration made the final assertions reflect only
        # the last document instead of the whole result set.
        num_10 = 0
        for doc in res:
            if doc.text == "10":
                num_10 += 1
            if not use_dynamic_batching:
                # Batching disabled: each doc was processed in a batch of 1.
                assert doc.text == "1"

        if use_dynamic_batching:
            # At least some docs must have been grouped into a full batch.
            assert num_10 > 0
        else:
            assert num_10 == 0
12 changes: 6 additions & 6 deletions tests/unit/serve/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,15 @@ class C(B):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
],
)
Expand All @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
],
)
Expand Down

0 comments on commit 515a121

Please sign in to comment.