fix: monitoring validation error #5965

Merged · 3 commits · Jul 13, 2023
109 changes: 55 additions & 54 deletions jina/serve/runtimes/worker/request_handling.py
@@ -47,16 +47,16 @@ class WorkerRequestHandler:
_KEY_RESULT = '__results__'

def __init__(
-            self,
-            args: 'argparse.Namespace',
-            logger: 'JinaLogger',
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            tracer_provider: Optional['trace.TracerProvider'] = None,
-            meter_provider: Optional['metrics.MeterProvider'] = None,
-            meter=None,
-            tracer=None,
-            deployment_name: str = '',
-            **kwargs,
+        self,
+        args: 'argparse.Namespace',
+        logger: 'JinaLogger',
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        tracer_provider: Optional['trace.TracerProvider'] = None,
+        meter_provider: Optional['metrics.MeterProvider'] = None,
+        meter=None,
+        tracer=None,
+        deployment_name: str = '',
+        **kwargs,
):
"""Initialize private parameters and execute private loading functions.

@@ -79,8 +79,8 @@ def __init__(
self._is_closed = False
if self.metrics_registry:
with ImportExtensions(
-                    required=True,
-                    help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
):
from prometheus_client import Counter, Summary

@@ -201,9 +201,9 @@ async def _hot_reload(self):
watched_files.add(extra_python_file)

with ImportExtensions(
-                required=True,
-                logger=self.logger,
-                help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install
+            required=True,
+            logger=self.logger,
+            help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install
watchfiles''',
):
from watchfiles import awatch
@@ -272,16 +272,16 @@ def _init_batchqueue_dict(self):
}

def _init_monitoring(
-            self,
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            meter: Optional['metrics.Meter'] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        meter: Optional['metrics.Meter'] = None,
):

if metrics_registry:

with ImportExtensions(
-                    required=True,
-                    help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
):
from prometheus_client import Counter, Summary

@@ -337,10 +337,10 @@ def _init_monitoring(
self._sent_response_size_histogram = None

def _load_executor(
-            self,
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            tracer_provider: Optional['trace.TracerProvider'] = None,
-            meter_provider: Optional['metrics.MeterProvider'] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        tracer_provider: Optional['trace.TracerProvider'] = None,
+        meter_provider: Optional['metrics.MeterProvider'] = None,
):
"""
Load the executor to this runtime, specified by ``uses`` CLI argument.
@@ -471,14 +471,14 @@ def _record_request_size_monitoring(self, requests):
)
self._request_size_histogram.record(req.nbytes, attributes=attributes)

-    def _record_docs_processed_monitoring(self, requests):
+    def _record_docs_processed_monitoring(self, requests, len_docs: int):
if self._document_processed_metrics:
self._document_processed_metrics.labels(
requests[0].header.exec_endpoint,
self._executor.__class__.__name__,
self.args.name,
).inc(
-                len(requests[0].docs)
+                len_docs
) # TODO we can optimize here and access the
# lenght of the da without loading the da in memory

@@ -489,7 +489,7 @@ def _record_docs_processed_monitoring(self, requests):
self.args.name,
)
self._document_processed_counter.add(
-                len(requests[0].docs), attributes=attributes
+                len_docs, attributes=attributes
) # TODO same as above

def _record_response_size_monitoring(self, requests):
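
A note on this change: `_record_docs_processed_monitoring` no longer computes `len(requests[0].docs)` itself; the caller passes in `len_docs`, counted before the executor ran. As reconstructed from the diff, reading `requests[0].docs` again at metrics time deserializes the DocArray against whatever schema the request carries at that moment, and once the response documents are attached, an endpoint whose output schema differs from its input schema hits exactly the validation error this PR fixes. A minimal sketch of the pattern, with placeholder names that are not the actual jina APIs:

```python
# Sketch only: `request`, `run_executor`, and `record_docs_processed` are
# placeholders for illustration, not jina classes or functions.
def handle(request, run_executor, record_docs_processed):
    len_docs = len(request.docs)               # input schema still applies here
    request.docs = run_executor(request.docs)  # may return a different schema
    record_docs_processed(request, len_docs)   # no second read of request.docs
    return request
```
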
@@ -536,10 +536,10 @@ def _set_result(self, requests, return_data, docs):
return docs

async def _setup_requests(
-            self,
-            requests: List['DataRequest'],
-            exec_endpoint: str,
-            tracing_context: Optional['Context'] = None,
+        self,
+        requests: List['DataRequest'],
+        exec_endpoint: str,
+        tracing_context: Optional['Context'] = None,
):
"""Execute a request using the executor.

@@ -568,7 +568,7 @@ async def _setup_requests(
return requests, params

async def handle_generator(
-            self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
) -> Generator:
"""Prepares and executes a request for generator endpoints.

@@ -606,7 +606,7 @@ async def handle_generator(
)

async def handle(
-            self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
) -> DataRequest:
"""Initialize private parameters and execute private loading functions.

@@ -632,6 +632,7 @@ async def handle(
requests, exec_endpoint, tracing_context=tracing_context
)

+        len_docs = len(requests[0].docs)  # TODO we can optimize here and access the
if exec_endpoint in self._batchqueue_config:
assert len(requests) == 1, 'dynamic batching does not support no_reduce'

@@ -667,20 +668,20 @@
for req in requests:
req.add_executor(self.deployment_name)

-        self._record_docs_processed_monitoring(requests)
-        self._record_response_size_monitoring(requests)
+        self._record_docs_processed_monitoring(requests, len_docs)
try:
requests[0].document_array_cls = self._executor.requests[
exec_endpoint
].response_schema
except AttributeError:
pass
+        self._record_response_size_monitoring(requests)

return requests[0]
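
Two ordering details in `handle` make the fix work, as the hunk above shows: `len_docs` is captured once at the top of `handle`, before dynamic batching or the executor can replace the documents, and `_record_docs_processed_monitoring` now runs before `document_array_cls` is switched to the endpoint's response schema, while `_record_response_size_monitoring`, which only serializes the request to count bytes, runs after the switch.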

@staticmethod
def replace_docs(
-            request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None
+        request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None
) -> None:
"""Replaces the docs in a message with new Documents.

Expand Down Expand Up @@ -728,7 +729,7 @@ async def close(self):

@staticmethod
def _get_docs_matrix_from_request(
-            requests: List['DataRequest'],
+        requests: List['DataRequest'],
) -> Tuple[Optional[List['DocumentArray']], Optional[Dict[str, 'DocumentArray']]]:
"""
Returns a docs matrix from a list of DataRequest objects.
@@ -752,7 +753,7 @@ def _get_docs_matrix_from_request(

@staticmethod
def get_parameters_dict_from_request(
-            requests: List['DataRequest'],
+        requests: List['DataRequest'],
) -> 'Dict':
"""
Returns a parameters dict from a list of DataRequest objects.
@@ -772,7 +773,7 @@ def get_parameters_dict_from_request(

@staticmethod
def get_docs_from_request(
-            requests: List['DataRequest'],
+        requests: List['DataRequest'],
) -> 'DocumentArray':
"""
Gets a field from the message
@@ -852,7 +853,7 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest':

# serving part
async def process_single_data(
-            self, request: DataRequest, context, is_generator: bool = False
+        self, request: DataRequest, context, is_generator: bool = False
) -> DataRequest:
"""
Process the received requests and return the result as a new request
@@ -908,7 +909,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
return endpoints_proto

def _extract_tracing_context(
-            self, metadata: grpc.aio.Metadata
+        self, metadata: grpc.aio.Metadata
) -> Optional['Context']:
if self.tracer:
from opentelemetry.propagate import extract
@@ -924,7 +925,7 @@ def _log_data_request(self, request: DataRequest):
)

async def process_data(
-            self, requests: List[DataRequest], context, is_generator: bool = False
+        self, requests: List[DataRequest], context, is_generator: bool = False
) -> DataRequest:
"""
Process the received requests and return the result as a new request
@@ -935,7 +936,7 @@ async def process_data(
:returns: the response request
"""
with MetricsTimer(
-                self._summary, self._receiving_request_seconds, self._metric_attributes
+            self._summary, self._receiving_request_seconds, self._metric_attributes
):
try:
if self.logger.debug_enabled:
@@ -984,8 +985,8 @@ async def process_data(
)

if (
-                        self.args.exit_on_exceptions
-                        and type(ex).__name__ in self.args.exit_on_exceptions
+                    self.args.exit_on_exceptions
+                    and type(ex).__name__ in self.args.exit_on_exceptions
):
self.logger.info('Exiting because of "--exit-on-exceptions".')
raise RuntimeTerminated
@@ -1010,7 +1011,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
return info_proto

async def stream(
-            self, request_iterator, context=None, *args, **kwargs
+        self, request_iterator, context=None, *args, **kwargs
) -> AsyncIterator['Request']:
"""
stream requests from client iterator and stream responses back.
@@ -1027,8 +1028,8 @@ def _create_snapshot_status(
Call = stream

def _create_snapshot_status(
-            self,
-            snapshot_directory: str,
+        self,
+        snapshot_directory: str,
) -> 'jina_pb2.SnapshotStatusProto':
_id = str(uuid.uuid4())
self.logger.debug(f'Generated snapshot id: {_id}')
@@ -1041,7 +1042,7 @@ def _create_snapshot_status(
)

def _create_restore_status(
-            self,
+        self,
) -> 'jina_pb2.SnapshotStatusProto':
_id = str(uuid.uuid4())
self.logger.debug(f'Generated restore id: {_id}')
@@ -1060,9 +1061,9 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto':
"""
self.logger.debug(f' Calling snapshot')
if (
-                self._snapshot
-                and self._snapshot_thread
-                and self._snapshot_thread.is_alive()
+            self._snapshot
+            and self._snapshot_thread
+            and self._snapshot_thread.is_alive()
):
raise RuntimeError(
f'A snapshot with id {self._snapshot.id.value} is currently in progress. Cannot start another.'
@@ -1080,7 +1081,7 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto':
return self._snapshot

async def snapshot_status(
-            self, request: 'jina_pb2.SnapshotId', context
+        self, request: 'jina_pb2.SnapshotId', context
) -> 'jina_pb2.SnapshotStatusProto':
"""
method to start a snapshot process of the Executor
@@ -1141,7 +1142,7 @@ async def restore(self, request: 'jina_pb2.RestoreSnapshotCommand', context):
return self._restore

async def restore_status(
-            self, request, context
+        self, request, context
) -> 'jina_pb2.RestoreSnapshotStatusProto':
"""
method to start a snapshot process of the Executor
23 changes: 23 additions & 0 deletions tests/integration/docarray_v2/test_v2.py
@@ -1130,3 +1130,26 @@ async def filter(self, docs: DocList[TagsDoc], **kwargs) -> DocList[TagsDoc]:
for doc in res:
assert doc.aux.a == 'b'
assert doc.tags == {'a': {'b': 1}}

+def test_issue_with_monitoring():
+
+    class InputDocMonitor(BaseDoc):
+        text: str
+
+    class OutputDocMonitor(BaseDoc):
+        price: int
+
+    class MonitorExecTest(Executor):
+        @requests
+        def foo(self, docs: DocList[InputDocMonitor], **kwargs) -> DocList[OutputDocMonitor]:
+            ret = DocList[OutputDocMonitor]()
+            for doc in docs:
+                ret.append(OutputDocMonitor(price=2))
+            return ret
+
+    f = Flow(monitoring=True).add(uses=MonitorExecTest, monitoring=True)
+    with f:
+        ret = f.post(on='/', inputs=DocList[InputDocMonitor]([InputDocMonitor(text='2')]), return_type=DocList[OutputDocMonitor])
+        assert len(ret) == 1
+        assert ret[0].price == 2
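
The regression test drives a monitored Flow whose input and output schemas differ, which is the case that used to fail. For anyone who wants to inspect the recorded metrics by hand, a sketch along these lines should work; the fixed `port_monitoring=9091` and the `jina_document_processed` metric name are assumptions to verify against the installed jina version:

```python
# Verification sketch, not part of the PR. Assumes port_monitoring=9091 and
# a Prometheus counter whose name starts with 'jina_document_processed'.
import urllib.request

from docarray import BaseDoc, DocList
from jina import Executor, Flow, requests


class InDoc(BaseDoc):
    text: str


class OutDoc(BaseDoc):
    price: int


class PriceExec(Executor):
    @requests
    def foo(self, docs: DocList[InDoc], **kwargs) -> DocList[OutDoc]:
        # return documents of a *different* schema than the input
        return DocList[OutDoc]([OutDoc(price=2) for _ in docs])


f = Flow(monitoring=True).add(uses=PriceExec, monitoring=True, port_monitoring=9091)
with f:
    f.post(on='/', inputs=DocList[InDoc]([InDoc(text='2')]), return_type=DocList[OutDoc])
    metrics = urllib.request.urlopen('http://localhost:9091/').read().decode()
    # before this fix, the post above raised a validation error instead
    assert 'jina_document_processed' in metrics
```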