fix(openai): yield chunks for streaming (#166)
nirga authored Nov 9, 2023
1 parent 77a11b3 commit ae70f55
Showing 3 changed files with 59 additions and 49 deletions.
@@ -1,7 +1,6 @@
 import logging
 import os
 import types
-import itertools
 import pkg_resources
 from typing import Collection
 from wrapt import wrap_function_wrapper
@@ -188,7 +187,7 @@ def _set_response_attributes(span, llm_request_type, response):
     return


-def _build_from_streaming_response(llm_request_type, response):
+def _build_from_streaming_response(span, llm_request_type, response):
     complete_response = {"choices": [], "model": ""}
     for item in response:
         if is_openai_v1():
@@ -219,7 +218,16 @@ def _build_from_streaming_response(llm_request_type, response):
                     complete_choice["message"]["role"] = delta.get("role")
             else:
                 complete_choice["text"] += choice.get("text")
-    return complete_response
+
+        yield item
+
+    _set_response_attributes(
+        span,
+        llm_request_type,
+        complete_response,
+    )
+    span.set_status(Status(StatusCode.OK))
+    span.end()


 def _with_tracer_wrapper(func):
@@ -252,9 +260,9 @@ def _llm_request_type_by_module_object(module_name, object_name):


 def is_streaming_response(response):
-    return isinstance(response, types.GeneratorType) or (is_openai_v1() and isinstance(
-        response, openai.Stream
-    ))
+    return isinstance(response, types.GeneratorType) or (
+        is_openai_v1() and isinstance(response, openai.Stream)
+    )


 @_with_tracer_wrapper
@@ -267,55 +275,53 @@ def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
     llm_request_type = _llm_request_type_by_module_object(
         to_wrap.get("module"), to_wrap.get("object")
     )
-    with tracer.start_as_current_span(
+
+    span = tracer.start_span(
         name,
         kind=SpanKind.CLIENT,
         attributes={
             SpanAttributes.LLM_VENDOR: "OpenAI",
             SpanAttributes.LLM_REQUEST_TYPE: llm_request_type.value,
         },
-    ) as span:
-        if span.is_recording():
-            _set_api_attributes(span)
+    )
+
+    try:
+        if span.is_recording():
+            _set_api_attributes(span)
+            _set_input_attributes(span, llm_request_type, kwargs)
+
+    except Exception as ex:  # pylint: disable=broad-except
+        logger.warning(
+            "Failed to set input attributes for openai span, error: %s", str(ex)
+        )
+
+    response = wrapped(*args, **kwargs)
+
+    if response:
         try:
             if span.is_recording():
-                _set_input_attributes(span, llm_request_type, kwargs)
+                if is_streaming_response(response):
+                    return _build_from_streaming_response(
+                        span, llm_request_type, response
+                    )
+                else:
+                    _set_response_attributes(
+                        span,
+                        llm_request_type,
+                        response.__dict__ if is_openai_v1() else response,
+                    )
+
         except Exception as ex:  # pylint: disable=broad-except
             logger.warning(
-                "Failed to set input attributes for openai span, error: %s", str(ex)
+                "Failed to set response attributes for openai span, error: %s",
+                str(ex),
             )
+        if span.is_recording():
+            span.set_status(Status(StatusCode.OK))

-        response = wrapped(*args, **kwargs)
-
-        if response:
-            try:
-                if span.is_recording():
-                    if is_streaming_response(response):
-                        response, to_extract_spans = itertools.tee(response)
-                        _set_response_attributes(
-                            span,
-                            llm_request_type,
-                            _build_from_streaming_response(
-                                llm_request_type, to_extract_spans
-                            ),
-                        )
-                    else:
-                        _set_response_attributes(
-                            span,
-                            llm_request_type,
-                            response.__dict__ if is_openai_v1() else response,
-                        )
-
-            except Exception as ex:  # pylint: disable=broad-except
-                logger.warning(
-                    "Failed to set response attributes for openai span, error: %s",
-                    str(ex),
-                )
-            if span.is_recording():
-                span.set_status(Status(StatusCode.OK))
-
-        return response
+    span.end()
+    return response


 class OpenAISpanAttributes:
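The core of the change: _build_from_streaming_response is now a generator that receives the span, yields every chunk back to the caller as it arrives, and only records the accumulated response, sets the status, and ends the span once the stream is exhausted; _wrap starts the span with tracer.start_span instead of a with-block so the span can outlive the function return. Below is a minimal, self-contained sketch of that pattern against the OpenTelemetry Python API; fake_stream, traced_stream, and the llm.completion attribute name are illustrative stand-ins, not identifiers from this package.

from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

tracer = trace.get_tracer(__name__)


def fake_stream():
    # Stand-in for the chunk objects returned by a streaming OpenAI call.
    yield {"choices": [{"text": "Hello"}]}
    yield {"choices": [{"text": " world"}]}


def traced_stream(span, response):
    # Pass each chunk straight through to the caller while accumulating it.
    complete_text = ""
    for chunk in response:
        complete_text += chunk["choices"][0]["text"]
        yield chunk
    # Stream exhausted: record the accumulated response, then close the span.
    span.set_attribute("llm.completion", complete_text)
    span.set_status(Status(StatusCode.OK))
    span.end()


# The span is started explicitly (not as a context manager) so it stays open
# until the consumer finishes iterating the generator.
span = tracer.start_span("openai.chat")
for part in traced_stream(span, fake_stream()):
    print(part["choices"][0]["text"], end="")
print()

The previous approach duplicated the stream with itertools.tee and consumed one copy eagerly to build the response attributes, which forces the whole stream to be read up front; with the generator, the caller drives the iteration, each chunk is yielded as it arrives, and the span is closed only after the last chunk has been delivered.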
15 changes: 8 additions & 7 deletions packages/sample-app/sample_app/openai_streaming.py
@@ -1,22 +1,23 @@
-from openai import OpenAI
+import os
+import openai
 
 from traceloop.sdk import Traceloop
 from traceloop.sdk.decorators import workflow
 
-client = OpenAI()
-Traceloop.init(app_name="joke_generation_service")
+openai.api_key = os.getenv("OPENAI_API_KEY")
+Traceloop.init(app_name="story_service")
 
 
-@workflow(name="streaming_joke_creation")
+@workflow(name="streaming_story")
 def joke_workflow():
-    stream = client.chat.completions.create(
+    stream = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": "Tell me a joke about opentelemetry"}],
+        messages=[{"role": "user", "content": "Tell me a story about opentelemetry"}],
         stream=True,
     )
 
     for part in stream:
-        print(part.choices[0].delta.content or "", end="")
+        print(part.choices[0].delta.get("content") or "", end="")
     print()
 
 
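With the generator-based instrumentation, the span for the streaming call is finished from inside the wrapper, so the for part in stream loop in this sample is what ultimately drains the response and closes the span. The sample also switches to the module-level openai.ChatCompletion API and reads deltas with .get("content"), the dict-like style of the pre-1.0 OpenAI client.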
5 changes: 4 additions & 1 deletion packages/traceloop-sdk/tests/test_openai_instrumentation.py
@@ -76,12 +76,15 @@ def test_completion_langchain_style(exporter, openai_client):


 def test_streaming(exporter, openai_client):
-    openai_client.chat.completions.create(
+    response = openai_client.chat.completions.create(
         model="gpt-3.5-turbo",
         messages=[{"role": "user", "content": "Tell me a joke about opentelemetry"}],
         stream=True,
     )
 
+    for part in response:
+        pass
+
     spans = exporter.get_finished_spans()
     assert [span.name for span in spans] == [
         "openai.chat",
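The test now keeps the returned stream and iterates it before asserting: because the instrumentation ends the span from inside the generator, the openai.chat span only appears in exporter.get_finished_spans() once the response has been fully consumed.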
