From 2b10d7455129f54f998aeaa00c73fe6915ba2eaa Mon Sep 17 00:00:00 2001 From: Gal Kleinman Date: Wed, 29 Nov 2023 16:10:45 +0200 Subject: [PATCH] fix: lint issues --- .../instrumentation/bedrock/__init__.py | 24 ++++++++++++------- .../bedrock/reusable_streaming_body.py | 4 ++-- .../sample_app/bedrock_example_app.py | 8 ++++--- .../traceloop/sdk/tracing/tracing.py | 4 +++- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py index 5a1cc2c45..0aa2e5c7f 100644 --- a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py +++ b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py @@ -1,6 +1,5 @@ """OpenTelemetry Bedrock instrumentation""" from functools import wraps -from itertools import tee import json import logging import os @@ -10,7 +9,6 @@ from opentelemetry import context as context_api from opentelemetry.trace import get_tracer, SpanKind -from opentelemetry.trace.status import Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import ( @@ -64,7 +62,7 @@ def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs): """Instruments and calls every function defined in TO_WRAP.""" if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): return wrapped(*args, **kwargs) - + if kwargs.get("service_name") == "bedrock-runtime": client = wrapped(*args, **kwargs) client.invoke_model = _instrumented_model_invoke(client.invoke_model, tracer) @@ -88,7 +86,7 @@ def with_instrumentation(*args, **kwargs): if span.is_recording(): (vendor, model) = kwargs.get("modelId").split(".") - + _set_span_attribute(span, SpanAttributes.LLM_VENDOR, vendor) _set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model) @@ -100,11 +98,12 @@ def with_instrumentation(*args, **kwargs): _set_ai21_span_attributes(span, request_body, response_body) elif vendor == "meta": _set_llama_span_attributes(span, request_body, response_body) - + return response - + return with_instrumentation + def _set_cohere_span_attributes(span, request_body, response_body): _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, LLMRequestTypeValues.COMPLETION.value) _set_span_attribute(span, SpanAttributes.LLM_TOP_P, request_body.get("p")) @@ -117,6 +116,7 @@ def _set_cohere_span_attributes(span, request_body, response_body): for i, generation in enumerate(response_body.get("generations")): _set_span_attribute(span, f"{SpanAttributes.LLM_COMPLETIONS}.{i}.content", generation.get("text")) + def _set_anthropic_span_attributes(span, request_body, response_body): _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, LLMRequestTypeValues.COMPLETION.value) _set_span_attribute(span, SpanAttributes.LLM_TOP_P, request_body.get("top_p")) @@ -127,6 +127,7 @@ def _set_anthropic_span_attributes(span, request_body, response_body): _set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.user", request_body.get("prompt")) _set_span_attribute(span, f"{SpanAttributes.LLM_COMPLETIONS}.0.content", response_body.get("completion")) + def _set_ai21_span_attributes(span, request_body, response_body): _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, LLMRequestTypeValues.COMPLETION.value) _set_span_attribute(span, SpanAttributes.LLM_TOP_P, request_body.get("topP")) @@ -134,10 +135,17 @@ def _set_ai21_span_attributes(span, request_body, response_body): _set_span_attribute(span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("maxTokens")) if should_send_prompts(): - _set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.user", request_body.get("prompt")) + _set_span_attribute( + span, + f"{SpanAttributes.LLM_PROMPTS}.0.user", request_body.get("prompt") + ) for i, completion in enumerate(response_body.get("completions")): - _set_span_attribute(span, f"{SpanAttributes.LLM_COMPLETIONS}.{i}.content", completion.get("data").get("text")) + _set_span_attribute( + span, + f"{SpanAttributes.LLM_COMPLETIONS}.{i}.content", completion.get("data").get("text") + ) + def _set_llama_span_attributes(span, request_body, response_body): _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, LLMRequestTypeValues.COMPLETION.value) diff --git a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/reusable_streaming_body.py b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/reusable_streaming_body.py index 3d6393f51..774e3d25d 100644 --- a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/reusable_streaming_body.py +++ b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/reusable_streaming_body.py @@ -14,7 +14,7 @@ def __init__(self, raw_stream, content_length): super().__init__(raw_stream, content_length) self._buffer = None self._buffer_cursor = 0 - + def read(self, amt=None): """Read at most amt bytes from the stream. @@ -28,7 +28,7 @@ def read(self, amt=None): raise ReadTimeoutError(endpoint_url=e.url, error=e) except URLLib3ProtocolError as e: raise ResponseStreamingError(error=e) - + self._amount_read += len(self._buffer) if amt is None or (not self._buffer and amt > 0): # If the server sends empty contents or diff --git a/packages/sample-app/sample_app/bedrock_example_app.py b/packages/sample-app/sample_app/bedrock_example_app.py index ae9b64c0d..32d7cc374 100644 --- a/packages/sample-app/sample_app/bedrock_example_app.py +++ b/packages/sample-app/sample_app/bedrock_example_app.py @@ -19,9 +19,9 @@ def create_joke(): }) response = brt.invoke_model( - body=body, - modelId='cohere.command-text-v14', - accept='application/json', + body=body, + modelId='cohere.command-text-v14', + accept='application/json', contentType='application/json' ) @@ -29,8 +29,10 @@ def create_joke(): return response_body.get('generations')[0].get('text') + @workflow(name="pirate_joke_generator") def joke_workflow(): print(create_joke()) + joke_workflow() diff --git a/packages/traceloop-sdk/traceloop/sdk/tracing/tracing.py b/packages/traceloop-sdk/traceloop/sdk/tracing/tracing.py index 8fae0f5b1..e8df6593a 100644 --- a/packages/traceloop-sdk/traceloop/sdk/tracing/tracing.py +++ b/packages/traceloop-sdk/traceloop/sdk/tracing/tracing.py @@ -30,7 +30,8 @@ from typing import Dict TRACER_NAME = "traceloop.tracer" -EXCLUDED_URLS = "api.openai.com,openai.azure.com,api.anthropic.com,api.cohere.ai,pinecone.io,traceloop.com,posthog.com,bedrock-runtime" +EXCLUDED_URLS = ("api.openai.com,openai.azure.com,api.anthropic.com,api.cohere.ai,pinecone.io,traceloop.com," + "posthog.com,bedrock-runtime") class TracerWrapper(object): @@ -369,6 +370,7 @@ def init_pymysql_instrumentor(): if not instrumentor.is_instrumented_by_opentelemetry: instrumentor.instrument() + def init_bedrock_instrumentor(): if importlib.util.find_spec("boto3") is not None: from opentelemetry.instrumentation.bedrock import BedrockInstrumentor