fix: patch streaming API code #1693

Merged · 25 commits · Aug 29, 2024

Commits
4ffee3d
fix: patch stream_steps to properly trigger + return compact JSON ins…
cpacker Aug 28, 2024
0fef357
fix: patch the type hint
cpacker Aug 28, 2024
0e4a861
fix: patch type hints
cpacker Aug 28, 2024
7b219ff
feat: add dummy Message creation at the top of the stream handler so …
cpacker Aug 28, 2024
9983af7
fix: add deprecated function role as an option
cpacker Aug 28, 2024
fcbc428
feat: add FunctionCallDelta to the MemGPT message schema, since this …
cpacker Aug 28, 2024
eef77c0
feat: support passing ID when doing a dict-to-message Message constru…
cpacker Aug 28, 2024
21f501b
fix: fix issues with process_chunk that broke when we added typing to…
cpacker Aug 28, 2024
1fe6b75
fix: allow process_chunk to take extra message IDs and created_ats to…
cpacker Aug 28, 2024
f2e1766
feat: provide option in the openai stream handler determining whether…
cpacker Aug 28, 2024
c9e001a
chore: cleanup
cpacker Aug 28, 2024
524a4d6
fix: fix pretty bad bug where we were always creating new timestamps …
cpacker Aug 28, 2024
9bb3224
chore: cleanup
cpacker Aug 28, 2024
0a42461
fix: fix bug where summarizer was broken when streaming was on
cpacker Aug 28, 2024
7d513d5
fix: this was the final step - we create the Message IDs ahead of tim…
cpacker Aug 28, 2024
6466da1
fix: patch Python REST client
cpacker Aug 28, 2024
af1989a
Merge branch 'main' into fix-streaming
cpacker Aug 28, 2024
4935daa
chore: clean stray prints from older PR
cpacker Aug 28, 2024
fc9b893
fix: make hack less disgusting, should be good to merge at this point
cpacker Aug 28, 2024
ffa1ff0
refactor: placate pylance and potentially fix pytest error by reverti…
cpacker Aug 28, 2024
8e65ba2
chore: cleanup, remove note + move maximum context length to constant…
cpacker Aug 29, 2024
e4f4e8d
add todo
cpacker Aug 29, 2024
a8cb4a1
fix: @4shub make streaming chunks reflect real time (though have same…
cpacker Aug 29, 2024
5755123
fix: fix typing + remove microseconds from responses
cpacker Aug 29, 2024
82ae1d9
fix: attempt to patch test
cpacker Aug 29, 2024
26 changes: 24 additions & 2 deletions memgpt/agent.py
@@ -27,6 +27,9 @@
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
)
from memgpt.schemas.passage import Passage
from memgpt.schemas.tool import Tool
from memgpt.system import (
@@ -441,9 +444,21 @@ def _get_ai_reply(
except Exception as e:
raise e

def _handle_ai_response(self, response_message: Message, override_tool_call_id: bool = True) -> Tuple[List[Message], bool, bool]:
def _handle_ai_response(
self,
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
override_tool_call_id: bool = True,
# If we are streaming, we needed to create a Message ID ahead of time,
# and now we want to use it in the creation of the Message object
# TODO figure out a cleaner way to do this
response_message_id: Optional[str] = None,
cpacker (Owner, Author): @sarahwooders FYI we are now passing response_message_id into _handle_ai_response for the special case where we created the Message object before we started unpacking it / turning it into inner thoughts / actions / etc.

) -> Tuple[List[Message], bool, bool]:
"""Handles parsing and function execution"""

# Hacky failsafe for now to make sure we didn't implement the streaming Message ID creation incorrectly
if response_message_id is not None:
assert response_message_id.startswith("message-"), response_message_id
cpacker (Owner, Author): @sarahwooders We can cut this out later, but I think it's fine to leave in for now while the streaming code is in flux (or at least until we add streaming unit tests that test for persistence of the IDs that are streamed back).


messages = [] # append these to the history when done

# Step 2: check if LLM wanted to call a function
@@ -474,6 +489,7 @@ def _handle_ai_response(self, response_message: Message, override_tool_call_id:
# NOTE: we're recreating the message here
# TODO should probably just overwrite the fields?
Message.dict_to_message(
id=response_message_id,
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
@@ -619,6 +635,7 @@ def _handle_ai_response(self, response_message: Message, override_tool_call_id:
# Standard non-function reply
messages.append(
Message.dict_to_message(
id=response_message_id,
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
@@ -765,7 +782,12 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
# (if yes) Step 5: send the info on the function call and function response to LLM
response_message = response.choices[0].message
response_message.model_copy() # TODO why are we copying here?
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(
response_message,
# TODO this is kind of hacky, find a better way to handle this
# the only time we set up message creation ahead of time is when streaming is on
response_message_id=response.id if stream else None,
cpacker (Owner, Author): @sarahwooders translation:

If we're streaming (tokens), then we want to create a Message.id ahead of time so that the chunks we return via the API (and via the client once we add support) have IDs attached to them.

However, this all happens before the MemGPT agent logic loop that takes a ChatCompletionResponse as input (which is the final result of a stream, not an intermediate result).

So that means we need to modify _handle_ai_response to (in the streaming case) accept a pre-generated Message.id, and use it when we create the Message objects inside of _handle_ai_response. A sketch of this flow follows the agent.py diff below.

)

# Add the extra metadata to the assistant response
# (e.g. enough metadata to enable recreating the API call)
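To make the comments above concrete, here is a tiny self-contained sketch of the invariant this change enforces (hypothetical helper names, not code from this diff): when streaming, the Message ID exists before any chunk is emitted, and the persisted Message reuses that same ID instead of minting a new one.

import uuid

def new_message_id() -> str:
    # mirrors the "message-<uuid>" convention asserted in _handle_ai_response
    return f"message-{uuid.uuid4()}"

def streamed_step() -> str:
    # Streaming case: the ID is pre-generated before any chunk goes out
    response_message_id = new_message_id()
    chunks = [
        {"message_id": response_message_id, "delta": "hel"},
        {"message_id": response_message_id, "delta": "lo"},
    ]
    # Later, _handle_ai_response receives the same ID and forwards it to
    # Message.dict_to_message(id=response_message_id, ...), so the stored
    # Message matches what the client already saw in the chunks.
    persisted_message_id = response_message_id
    assert persisted_message_id.startswith("message-")
    assert all(c["message_id"] == persisted_message_id for c in chunks)
    return persisted_message_id

streamed_step()  # in the non-streaming path, response_message_id is simply None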
4 changes: 2 additions & 2 deletions memgpt/client/client.py
@@ -375,7 +375,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]:

# agent interactions

def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]:
def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
return self.send_message(agent_id, message, role="user")

def save(self):
@@ -423,7 +423,7 @@ def send_message(
) -> MemGPTResponse:
messages = [MessageCreate(role=role, text=message, name=name)]
# TODO: figure out how to handle stream_steps and stream_tokens
request = MemGPTRequest(messages=messages, stream_steps=stream)
request = MemGPTRequest(messages=messages, stream_steps=stream, return_message_object=True)
cpacker (Owner, Author): @sarahwooders return_message_object=True means the type that comes back is Message; return_message_object=False means the types that come back are InternalMonologue / FunctionCall / ... (these are now typed too, vs. previously they were dicts). A request sketch follows this file's diff.

response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
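For reference, a hedged sketch of the HTTP request the patched client now issues (field names are taken from the diff above; the base URL and agent ID are placeholders, and the exact MessageCreate serialization is an assumption):

import requests

base_url = "http://localhost:8283"  # placeholder, assumes a running MemGPT server
agent_id = "agent-00000000"         # placeholder

payload = {
    "messages": [{"role": "user", "text": "hello", "name": None}],  # MessageCreate fields
    "stream_steps": False,
    "return_message_object": True,  # True -> full Message objects come back;
                                    # False -> typed InternalMonologue / FunctionCall / ... chunks
}
response = requests.post(f"{base_url}/api/agents/{agent_id}/messages", json=payload)
if response.status_code != 200:
    raise ValueError(f"Failed to send message: {response.text}")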
5 changes: 5 additions & 0 deletions memgpt/constants.py
@@ -3,6 +3,11 @@

MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")

# String in the error message for when the context window is too large
# Example full message:
# This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING = "maximum context length"

# System prompt templating
IN_CONTEXT_MEMORY_KEYWORD = "CORE_MEMORY"

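A quick, self-contained illustration of how the new constant is meant to be used (the error text is the example from the comment above; the constant is inlined here so the snippet runs standalone):

OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING = "maximum context length"  # same value as in memgpt/constants.py

error_message = (
    "This model's maximum context length is 8192 tokens. However, your messages resulted in "
    "8198 tokens (7450 in the messages, 748 in the functions). "
    "Please reduce the length of the messages or functions."
)

# Substring match -> treat the failure as context overflow (see is_context_overflow_error below)
assert OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message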
11 changes: 5 additions & 6 deletions memgpt/llm_api/llm_api_tools.py
@@ -3,13 +3,12 @@
import os
import random
import time
import uuid
import warnings
from typing import List, Optional, Union

import requests

from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.constants import CLI_WARNING_PREFIX, OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from memgpt.credentials import MemGPTCredentials
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
@@ -134,7 +133,7 @@ def is_context_overflow_error(exception: requests.exceptions.RequestException) -
"""Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
from memgpt.utils import printd

match_string = "maximum context length"
match_string = OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING

# Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
if match_string in str(exception):
@@ -231,9 +230,9 @@ def create(
# agent_state: AgentState,
llm_config: LLMConfig,
messages: List[Message],
user_id: uuid.UUID = None, # option UUID to associate request with
functions: list = None,
functions_python: list = None,
user_id: Optional[str] = None, # option UUID to associate request with
functions: Optional[list] = None,
functions_python: Optional[list] = None,
function_call: str = "auto",
# hint
first_message: bool = False,
48 changes: 38 additions & 10 deletions memgpt/llm_api/openai.py
@@ -6,7 +6,11 @@
from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError

from memgpt.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from memgpt.errors import LLMError
from memgpt.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from memgpt.schemas.message import Message as _Message
from memgpt.schemas.message import MessageRole as _MessageRole
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
ChatCompletionChunkResponse,
@@ -22,7 +26,7 @@
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from memgpt.utils import get_utc_time, smart_urljoin
from memgpt.utils import smart_urljoin

OPENAI_SSE_DONE = "[DONE]"

@@ -82,6 +86,8 @@ def openai_chat_completions_process_stream(
api_key: str,
chat_completion_request: ChatCompletionRequest,
stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.

@@ -114,13 +120,26 @@
model=chat_completion_request.model,
)

# Create a dummy Message object to get an ID and date
# TODO(sarah): add message ID generation function
dummy_message = _Message(
role=_MessageRole.assistant,
text="",
user_id="",
agent_id="",
model="",
name=None,
tool_calls=None,
tool_call_id=None,
)

TEMP_STREAM_RESPONSE_ID = "temp_id"
TEMP_STREAM_FINISH_REASON = "temp_null"
TEMP_STREAM_TOOL_CALL_ID = "temp_id"
chat_completion_response = ChatCompletionResponse(
id=TEMP_STREAM_RESPONSE_ID,
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
choices=[],
created=get_utc_time(),
created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time()
model=chat_completion_request.model,
usage=UsageStatistics(
completion_tokens=0,
@@ -138,11 +157,14 @@
openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request)
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
# print(chat_completion_chunk)

if stream_inferface:
if isinstance(stream_inferface, AgentChunkStreamingInterface):
stream_inferface.process_chunk(chat_completion_chunk)
stream_inferface.process_chunk(
chat_completion_chunk,
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
)
elif isinstance(stream_inferface, AgentRefreshStreamingInterface):
stream_inferface.process_refresh(chat_completion_response)
else:
@@ -209,10 +231,12 @@ def openai_chat_completions_process_stream(
raise NotImplementedError(f"Old function_call style not support with stream=True")

# overwrite response fields based on latest chunk
chat_completion_response.id = chat_completion_chunk.id
chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint
chat_completion_response.created = chat_completion_chunk.created
if not create_message_id:
chat_completion_response.id = chat_completion_chunk.id
if not create_message_datetime:
chat_completion_response.created = chat_completion_chunk.created
chat_completion_response.model = chat_completion_chunk.model
chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint

# increment chunk counter
n_chunks += 1
@@ -234,7 +258,8 @@
for c in chat_completion_response.choices
]
)
assert chat_completion_response.id != TEMP_STREAM_RESPONSE_ID
if not create_message_id:
assert chat_completion_response.id != dummy_message.id

# compute token usage before returning
# TODO try actually computing the #tokens instead of assuming the chunks is the same
@@ -263,7 +288,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionCh
response_dict = json.loads(response_bytes.decode("utf-8"))
error_message = response_dict["error"]["message"]
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
raise Exception(error_message)
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
raise LLMError(error_message)
except LLMError:
raise
except:
print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
event_source.response.raise_for_status()
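The net effect of the new create_message_id / create_message_datetime flags is that every chunk handed to the streaming interface carries one pre-generated ID and timestamp rather than the per-chunk values from OpenAI. A small self-contained sketch of that invariant (hypothetical names, not code from this diff):

import uuid
from datetime import datetime, timezone

# Pre-generate the ID and date once, the way the dummy _Message is used above
message_id = f"message-{uuid.uuid4()}"
message_date = datetime.now(timezone.utc)

emitted = []
for openai_chunk_id in ["chatcmpl-abc", "chatcmpl-abc", "chatcmpl-abc"]:
    # With both flags left at True, process_chunk is called with the pre-generated
    # values; the chunk's own id/created are only used when the flags are False.
    emitted.append({"message_id": message_id, "message_date": message_date})

assert len({c["message_id"] for c in emitted}) == 1
assert len({c["message_date"] for c in emitted}) == 1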
1 change: 1 addition & 0 deletions memgpt/memory.py
@@ -276,6 +276,7 @@ def summarize_messages(
llm_config=agent_state.llm_config,
user_id=agent_state.user_id,
messages=message_sequence,
stream=False,
)

printd(f"summarize_messages gpt reply: {response.choices[0]}")
1 change: 1 addition & 0 deletions memgpt/schemas/enums.py
@@ -5,6 +5,7 @@ class MessageRole(str, Enum):
assistant = "assistant"
user = "user"
tool = "tool"
function = "function"
system = "system"


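Re-adding the deprecated role is presumably for backwards compatibility with messages persisted before the "tool" role existed (per the commit "fix: add deprecated function role as an option"); a one-line check, assuming memgpt is installed:

from memgpt.schemas.enums import MessageRole

# "function" parses into the enum again instead of raising a ValueError
assert MessageRole("function") is MessageRole.function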
39 changes: 36 additions & 3 deletions memgpt/schemas/memgpt_message.py
@@ -1,5 +1,6 @@
import json
from datetime import datetime, timezone
from typing import Literal, Union
from typing import Literal, Optional, Union

from pydantic import BaseModel, field_serializer

@@ -12,7 +13,11 @@ class BaseMemGPTMessage(BaseModel):

@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
return dt.now(timezone.utc).isoformat()
cpacker (Owner, Author): @4shub @goetzrobin FYI this was a pretty bad bug that was previously causing all message streaming response chunks to have newly created timestamps.

if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
dt = dt.replace(tzinfo=timezone.utc)
# Remove microseconds since it seems like we're inconsistent with getting them
# TODO figure out why we don't always get microseconds (get_utc_time() does)
return dt.isoformat(timespec="seconds")


class InternalMonologue(BaseMemGPTMessage):
@@ -32,6 +37,20 @@ class FunctionCall(BaseModel):
arguments: str


class FunctionCallDelta(BaseModel):
cpacker (Owner, Author): @sarahwooders I didn't have to make a Delta/Chunk-specific model for InternalMonologue since it already has chunk support built in - its internal_monologue field can just carry partial pieces:

InternalMonologue.internal_monologue = "hello there"

can be streamed as

InternalMonologue.internal_monologue = "hello "
InternalMonologue.internal_monologue = "there"

However, FunctionCall is problematic since (at least with the OpenAI API) the stream usually starts with just the name, then chunks of the arguments:

FunctionCall.name: "send_message"
FunctionCall.arguments: "{ 'content': ..."

comes back as

FunctionCall.name: "send_message"
FunctionCall.arguments: "{ "
FunctionCall.arguments: "'content':"
...

So we need a new Pydantic model that supports optional attributes for when name is null or arguments is null (technically you should never have the case where both are null, but I'm not sure how to express that in Pydantic + it's probably not worth the hassle). A serialization sketch follows this file's diff.

name: Optional[str]
arguments: Optional[str]

# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
def model_dump(self, *args, **kwargs):
kwargs["exclude_none"] = True
return super().model_dump(*args, **kwargs)

def json(self, *args, **kwargs):
return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)


class FunctionCallMessage(BaseMemGPTMessage):
"""
{
@@ -44,7 +63,21 @@ class FunctionCallMessage(BaseMemGPTMessage):
}
"""

function_call: FunctionCall
function_call: Union[FunctionCall, FunctionCallDelta]

# NOTE: this is required for the FunctionCallDelta exclude_none to work correctly
def model_dump(self, *args, **kwargs):
kwargs["exclude_none"] = True
data = super().model_dump(*args, **kwargs)
if isinstance(data["function_call"], dict):
data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None}
return data

class Config:
json_encoders = {
FunctionCallDelta: lambda v: v.model_dump(exclude_none=True),
FunctionCall: lambda v: v.model_dump(exclude_none=True),
}


class FunctionReturn(BaseMemGPTMessage):
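To show why the exclude_none handling matters, here is a small sketch of how FunctionCallDelta chunks serialize (illustrative values; assumes the patched memgpt package from this PR is importable):

from memgpt.schemas.memgpt_message import FunctionCallDelta

name_chunk = FunctionCallDelta(name="send_message", arguments=None)
args_chunk = FunctionCallDelta(name=None, arguments='{ "content": "hel')

# exclude_none keeps the payload OpenAI-style: no "arguments": null in the first
# chunk, and no "name": null in the later ones
print(name_chunk.json())  # {"name": "send_message"}
print(args_chunk.json())  # {"arguments": "{ \"content\": \"hel"}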
5 changes: 5 additions & 0 deletions memgpt/schemas/memgpt_request.py
@@ -16,3 +16,8 @@ class MemGPTRequest(BaseModel):
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)

return_message_object: bool = Field(
default=False,
description="Set True to return the raw Message object. Set False to return the Message in the format of the MemGPT API.",
)