fix: errors on logging openai streaming completion calls (#144)
nirga authored Nov 3, 2023
1 parent e3e97ef commit 1a31383
Showing 2 changed files with 19 additions and 13 deletions.
@@ -1,6 +1,7 @@
 import logging
 import os
 import types
+import itertools
 from typing import Collection
 from wrapt import wrap_function_wrapper
 import openai
@@ -149,27 +150,29 @@ def _set_response_attributes(span, llm_request_type, response):
     return
 
 
-def _build_from_streaming_response(response):
+def _build_from_streaming_response(llm_request_type, response):
     complete_response = {"choices": [], "model": ""}
     for item in response:
         for choice in item.get("choices"):
             index = choice.get("index")
             if len(complete_response.get("choices")) <= index:
                 complete_response["choices"].append(
-                    {
-                        "index": index,
-                        "message": {"content": "", "role": ""},
-                    }
+                    {"index": index, "message": {"content": "", "role": ""}}
+                    if llm_request_type == LLMRequestTypeValues.CHAT
+                    else {"index": index, "text": ""}
                 )
             complete_choice = complete_response.get("choices")[index]
             if choice.get("finish_reason"):
                 complete_choice["finish_reason"] = choice.get("finish_reason")
-            if choice.get("delta").get("content"):
-                complete_choice["message"]["content"] += choice.get("delta").get(
-                    "content"
-                )
-            if choice.get("delta").get("role"):
-                complete_choice["message"]["role"] = choice.get("delta").get("role")
+            if llm_request_type == LLMRequestTypeValues.CHAT:
+                if choice.get("delta").get("content"):
+                    complete_choice["message"]["content"] += choice.get("delta").get(
+                        "content"
+                    )
+                if choice.get("delta").get("role"):
+                    complete_choice["message"]["role"] = choice.get("delta").get("role")
+            else:
+                complete_choice["text"] += choice.get("text")
     return complete_response
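
For context: chat and completion streams emit differently shaped chunks, which is why the accumulator above now branches on llm_request_type. A rough sketch of the two shapes (field names follow the pre-1.0 openai SDK; the payloads are illustrative, not exact):

# Chat streaming: each chunk carries an incremental "delta" dict.
chat_chunk = {
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}]
}

# Completion streaming: each chunk carries a plain "text" fragment. The old
# code read choice["delta"] unconditionally, which broke on these chunks.
completion_chunk = {"choices": [{"index": 0, "text": "Hi"}]}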


@@ -231,10 +234,13 @@ def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
     try:
         if span.is_recording():
             if isinstance(response, types.GeneratorType):
+                response, to_extract_spans = itertools.tee(response)
                 _set_response_attributes(
                     span,
                     llm_request_type,
-                    _build_from_streaming_response(
+                    _build_from_streaming_response(
+                        llm_request_type, to_extract_spans
+                    ),
                 )
             else:
                 _set_response_attributes(span, llm_request_type, response)
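The itertools.tee call is the other half of the fix: the streaming response is a generator, so draining it to build span attributes would leave nothing for the caller to iterate. A minimal self-contained sketch of the pattern, where fake_stream is a hypothetical stand-in for the OpenAI stream:

import itertools

def fake_stream():
    # Hypothetical stand-in for the generator returned by a streaming call.
    for token in ["Hello", ",", " world"]:
        yield {"choices": [{"index": 0, "delta": {"content": token}}]}

response = fake_stream()

# Draining the generator for logging would exhaust it for the caller;
# tee yields two independent iterators over the same underlying stream.
response, to_extract_spans = itertools.tee(response)

logged = "".join(c["choices"][0]["delta"]["content"] for c in to_extract_spans)
returned = "".join(c["choices"][0]["delta"]["content"] for c in response)
assert logged == returned == "Hello, world"

One trade-off worth noting: because one copy is drained eagerly, tee buffers every chunk until the other copy catches up, so the full stream is held in memory before the caller sees its first item.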
packages/sample-app/sample_app/langchain_agent.py (1 addition, 1 deletion)
@@ -15,7 +15,7 @@
 
 
 def langchain_app():
-    llm = OpenAI(temperature=0)
+    llm = OpenAI(temperature=0, streaming=True)
     search = DuckDuckGoSearchAPIWrapper()
     llm_math_chain = LLMMathChain.from_llm(llm)
     tools = [
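Setting streaming=True makes langchain's OpenAI wrapper request a streamed response, exercising the generator code path fixed above. For reference, a hypothetical direct repro of the previously failing path against the pre-1.0 openai SDK (model name, prompt, and key are placeholders):

import openai

openai.api_key = "..."
stream = openai.Completion.create(
    model="text-davinci-003",
    prompt="Say hello",
    stream=True,  # returns a generator of chunks, not a single response
)
for chunk in stream:
    # Completion chunks expose "text"; the instrumentation previously
    # assumed a chat-style "delta" here and raised on these chunks.
    print(chunk["choices"][0]["text"], end="")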
