fix: patch streaming API code #1693
`memgpt/agent.py`:

```diff
@@ -27,6 +27,9 @@
 from memgpt.schemas.memory import Memory
 from memgpt.schemas.message import Message
 from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
+from memgpt.schemas.openai.chat_completion_response import (
+    Message as ChatCompletionMessage,
+)
 from memgpt.schemas.passage import Passage
 from memgpt.schemas.tool import Tool
 from memgpt.system import (
```
```diff
@@ -441,9 +444,21 @@ def _get_ai_reply(
         except Exception as e:
             raise e

-    def _handle_ai_response(self, response_message: Message, override_tool_call_id: bool = True) -> Tuple[List[Message], bool, bool]:
+    def _handle_ai_response(
+        self,
+        response_message: ChatCompletionMessage,  # TODO should we eventually move the Message creation outside of this function?
+        override_tool_call_id: bool = True,
+        # If we are streaming, we needed to create a Message ID ahead of time,
+        # and now we want to use it in the creation of the Message object
+        # TODO figure out a cleaner way to do this
+        response_message_id: Optional[str] = None,
+    ) -> Tuple[List[Message], bool, bool]:
         """Handles parsing and function execution"""

+        # Hacky failsafe for now to make sure we didn't implement the streaming Message ID creation incorrectly
+        if response_message_id is not None:
+            assert response_message_id.startswith("message-"), response_message_id
+
         messages = []  # append these to the history when done

         # Step 2: check if LLM wanted to call a function
```

@sarahwooders We can cut this out later, but I think it's fine to leave in for now while the streaming code is in flux (or at least until we add streaming unit tests that test for persistence of IDs that are streamed back).
```diff
@@ -474,6 +489,7 @@ def _handle_ai_response(self, response_message: Message, override_tool_call_id:
                 # NOTE: we're recreating the message here
                 # TODO should probably just overwrite the fields?
                 Message.dict_to_message(
+                    id=response_message_id,
                     agent_id=self.agent_state.id,
                     user_id=self.agent_state.user_id,
                     model=self.model,
```
```diff
@@ -619,6 +635,7 @@ def _handle_ai_response(self, response_message: Message, override_tool_call_id:
             # Standard non-function reply
             messages.append(
                 Message.dict_to_message(
+                    id=response_message_id,
                     agent_id=self.agent_state.id,
                     user_id=self.agent_state.user_id,
                     model=self.model,
```
```diff
@@ -765,7 +782,12 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
             # (if yes) Step 5: send the info on the function call and function response to LLM
             response_message = response.choices[0].message
             response_message.model_copy()  # TODO why are we copying here?
-            all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
+            all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(
+                response_message,
+                # TODO this is kind of hacky, find a better way to handle this
+                # the only time we set up message creation ahead of time is when streaming is on
+                response_message_id=response.id if stream else None,
+            )

             # Add the extra metadata to the assistant response
             # (e.g. enough metadata to enable recreating the API call)
```

@sarahwooders translation: If we're streaming (tokens), then we want to create a `Message` ID ahead of time, so the chunks we stream back can already carry it. However, this all happens before the MemGPT agent logic loop that takes a `ChatCompletionResponse` and unpacks it into `Message` objects. So that means we need to modify `_handle_ai_response` to accept the pre-created ID (`response_message_id`).
`memgpt/client/client.py`:

```diff
@@ -375,7 +375,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]:

     # agent interactions

-    def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]:
+    def user_message(self, agent_id: str, message: str) -> MemGPTResponse:
         return self.send_message(agent_id, message, role="user")

     def save(self):
```
```diff
@@ -423,7 +423,7 @@ def send_message(
     ) -> MemGPTResponse:
         messages = [MessageCreate(role=role, text=message, name=name)]
         # TODO: figure out how to handle stream_steps and stream_tokens
-        request = MemGPTRequest(messages=messages, stream_steps=stream)
+        request = MemGPTRequest(messages=messages, stream_steps=stream, return_message_object=True)

         response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers)
         if response.status_code != 200:
             raise ValueError(f"Failed to send message: {response.text}")
```
`memgpt/schemas/memgpt_message.py`:

```diff
@@ -1,5 +1,6 @@
+import json
 from datetime import datetime, timezone
-from typing import Literal, Union
+from typing import Literal, Optional, Union

 from pydantic import BaseModel, field_serializer
```
```diff
@@ -12,7 +13,11 @@ class BaseMemGPTMessage(BaseModel):

     @field_serializer("date")
     def serialize_datetime(self, dt: datetime, _info):
-        return dt.now(timezone.utc).isoformat()
+        if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
+            dt = dt.replace(tzinfo=timezone.utc)
+        # Remove microseconds since it seems like we're inconsistent with getting them
+        # TODO figure out why we don't always get microseconds (get_utc_time() does)
+        return dt.isoformat(timespec="seconds")


 class InternalMonologue(BaseMemGPTMessage):
```

@4shub @goetzrobin FYI this was a pretty bad bug that was previously causing all message streaming response chunks to have newly created timestamps.
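For the record, the root cause is that `datetime.now()` is a classmethod: calling it through an instance as `dt.now(timezone.utc)` discards `dt` entirely and returns the current time on every serialization. A standalone repro of the old versus new behavior:

```python
from datetime import datetime, timezone

stored = datetime(2024, 1, 1, tzinfo=timezone.utc)

# Old serializer body: now() is a classmethod, so this ignores `stored`
# and yields a fresh wall-clock timestamp on every dump.
print(stored.now(timezone.utc).isoformat())  # current time, not 2024-01-01

# New serializer body: serialize the stored value at seconds precision.
if stored.tzinfo is None or stored.tzinfo.utcoffset(stored) is None:
    stored = stored.replace(tzinfo=timezone.utc)
print(stored.isoformat(timespec="seconds"))  # 2024-01-01T00:00:00+00:00
```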
```diff
@@ -32,6 +37,20 @@ class FunctionCall(BaseModel):
     arguments: str


+class FunctionCallDelta(BaseModel):
+    name: Optional[str]
+    arguments: Optional[str]
+
+    # NOTE: this is a workaround to exclude None values from the JSON dump,
+    # since the OpenAI style of returning chunks doesn't include keys with null values
+    def model_dump(self, *args, **kwargs):
+        kwargs["exclude_none"] = True
+        return super().model_dump(*args, **kwargs)
+
+    def json(self, *args, **kwargs):
+        return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)
+
+
 class FunctionCallMessage(BaseMemGPTMessage):
     """
     {
```

@sarahwooders I didn't have to make a delta model for the other message types, since their streamed chunks have the same shape as the full message. However, a `FunctionCall` chunk may carry only `name` or only `arguments` (the name arrives up front, then the argument string streams in fragments). So we need a new Pydantic model that supports optional attributes when name is null or arguments is null (technically you should never have the case where both are null, but not sure how you set that up in Pydantic + probably not worth the hassle).
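A standalone sanity check of what the `exclude_none` workaround produces (a sketch assuming Pydantic v2; the chunk values are invented for illustration):

```python
import json
from typing import Optional
from pydantic import BaseModel

class FunctionCallDelta(BaseModel):
    name: Optional[str]
    arguments: Optional[str]

    # Mirrors the PR's workaround: OpenAI-style chunks omit keys entirely
    # rather than sending explicit nulls.
    def model_dump(self, *args, **kwargs):
        kwargs["exclude_none"] = True
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)

# The first streamed chunk of a tool call typically carries only the name...
print(FunctionCallDelta(name="send_message", arguments=None).json())
# -> {"name": "send_message"}

# ...while later chunks carry only argument-string fragments.
print(FunctionCallDelta(name=None, arguments='{"message": "He').json())
# -> {"arguments": "{\"message\": \"He"}
```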
```diff
@@ -44,7 +63,21 @@
     }
     """

-    function_call: FunctionCall
+    function_call: Union[FunctionCall, FunctionCallDelta]
+
+    # NOTE: this is required for the FunctionCallDelta exclude_none to work correctly
+    def model_dump(self, *args, **kwargs):
+        kwargs["exclude_none"] = True
+        data = super().model_dump(*args, **kwargs)
+        if isinstance(data["function_call"], dict):
+            data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None}
+        return data
+
+    class Config:
+        json_encoders = {
+            FunctionCallDelta: lambda v: v.model_dump(exclude_none=True),
+            FunctionCall: lambda v: v.model_dump(exclude_none=True),
+        }


 class FunctionReturn(BaseMemGPTMessage):
```
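The outer override exists because Pydantic does not route a parent model's dump through a nested model's overridden `model_dump`, so `FunctionCallMessage` has to force `exclude_none` itself. A minimal hedged sketch (simplified: the real base class also carries `id` and `date` fields):

```python
from typing import Optional, Union
from pydantic import BaseModel

class FunctionCall(BaseModel):
    name: str
    arguments: str

class FunctionCallDelta(BaseModel):
    name: Optional[str]
    arguments: Optional[str]

class FunctionCallMessage(BaseModel):
    id: str
    function_call: Union[FunctionCall, FunctionCallDelta]

    def model_dump(self, *args, **kwargs):
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        # Belt-and-braces: strip any null that still survives in the nested dict.
        if isinstance(data["function_call"], dict):
            data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None}
        return data

msg = FunctionCallMessage(id="message-123", function_call=FunctionCallDelta(name=None, arguments='{"mes'))
print(msg.model_dump())
# -> {'id': 'message-123', 'function_call': {'arguments': '{"mes'}}
```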
@sarahwooders FYI we are now passing `response_message_id` into `_handle_ai_response` for the special case where we created the `Message` object before we started unpacking it / turning it into inner thoughts / actions / etc.