Skip to content

Commit

Permalink
refactor: simplify Agent.step inputs to Message or `List[Message]…
Browse files Browse the repository at this point in the history
…` only (#1879)
  • Loading branch information
cpacker authored Oct 14, 2024
1 parent 1842fd1 commit cc616ef
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 85 deletions.
116 changes: 51 additions & 65 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_login_event,
package_function_response,
package_summarize_message,
package_user_message,
)
from letta.utils import (
count_tokens,
Expand Down Expand Up @@ -200,16 +201,7 @@ class BaseAgent(ABC):
@abstractmethod
def step(
self,
messages: Union[Message, List[Message], str], # TODO deprecate str inputs
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
messages: Union[Message, List[Message]],
) -> AgentStepResponse:
"""
Top-level event message handler for the agent.
Expand Down Expand Up @@ -730,14 +722,13 @@ def _handle_ai_response(

def step(
self,
user_message: Union[Message, None, str], # NOTE: should be json.dump(dict)
messages: Union[Message, List[Message]],
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True,
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
Expand All @@ -760,50 +751,13 @@ def step(
self.rebuild_memory(force=True, ms=ms)

# Step 1: add user message
if user_message is not None:
if isinstance(user_message, Message):
assert user_message.text is not None

# Validate JSON via save/load
user_message_text = validate_json(user_message.text)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)

if name is not None:
# Update Message object
user_message.text = cleaned_user_message_text
user_message.name = name
if isinstance(messages, Message):
messages = [messages]

# Recreate timestamp
if recreate_message_timestamp:
user_message.created_at = get_utc_time()
if not all(isinstance(m, Message) for m in messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")

elif isinstance(user_message, str):
# Validate JSON via save/load
user_message = validate_json(user_message)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)

# If user_message['name'] is not None, it will be handled properly by dict_to_message
# So no need to run strip_name_field_from_user_message

# Create the associated Message object (in the database)
user_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={"role": "user", "content": cleaned_user_message_text, "name": name},
created_at=timestamp,
)

else:
raise ValueError(f"Bad type for user_message: {type(user_message)}")

self.interface.user_message(user_message.text, msg_obj=user_message)

input_message_sequence = self._messages + [user_message]

# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self._messages
input_message_sequence = self._messages + messages

if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
Expand Down Expand Up @@ -846,11 +800,8 @@ def step(
)

# Step 6: extend the message history
if user_message is not None:
if isinstance(user_message, Message):
all_new_messages = [user_message] + all_response_messages
else:
raise ValueError(type(user_message))
if len(messages) > 0:
all_new_messages = messages + all_response_messages
else:
all_new_messages = all_response_messages

Expand Down Expand Up @@ -897,7 +848,7 @@ def step(
)

except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
printd(f"step() failed\nmessages = {messages}\nerror = {e}")

# If we got a context alert, try trimming the messages length, then try again
if is_context_overflow_error(e):
Expand All @@ -906,14 +857,14 @@ def step(

# Try step again
return self.step(
user_message,
messages=messages,
first_message=first_message,
first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify,
return_dicts=return_dicts,
recreate_message_timestamp=recreate_message_timestamp,
# recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
timestamp=timestamp,
# timestamp=timestamp,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms,
)
Expand All @@ -922,6 +873,40 @@ def step(
printd(f"step() failed with an unrecognized exception: '{str(e)}'")
raise e

def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
    """Wrap a plain user string in a user `Message` and run one agent step on it.

    The raw text is packaged with `package_user_message`, passed through
    `validate_json`, split into content and name by
    `strip_name_field_from_user_message`, and converted into a `Message`
    via `Message.dict_to_message`.

    Example:
        user_message_str = 'hi'
        -> {'message': 'hi', 'type': 'user_message', ...}
        -> json.dumps(...)
        -> agent.step(messages=[Message(role='user', text=...)])

    Args:
        user_message_str: Non-empty plain-text user input.
        **kwargs: Forwarded unchanged to `step`.

    Returns:
        Whatever `step` returns for the single constructed message.
    """
    # Reject empty or non-string input before doing any packaging work
    assert user_message_str and isinstance(user_message_str, str), (
        f"user_message_str should be a non-empty string, got {type(user_message_str)}"
    )

    # Package the raw text, then validate the resulting JSON string
    packaged_json = package_user_message(user_message_str)
    validated_text = validate_json(packaged_json)

    # Separate the optional name field from the message content
    content, sender_name = strip_name_field_from_user_message(validated_text)

    # Build the Message object for this agent / user pair
    # NOTE: no explicit created_at is passed to dict_to_message here
    assert self.agent_state.user_id is not None, "User ID is not set"
    new_message = Message.dict_to_message(
        agent_id=self.agent_state.id,
        user_id=self.agent_state.user_id,
        model=self.model,
        openai_message_dict={"role": "user", "content": content, "name": sender_name},
    )

    return self.step(messages=[new_message], **kwargs)

def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"

Expand Down Expand Up @@ -1340,7 +1325,8 @@ def retry_message(self) -> List[Message]:

self.pop_until_user()
user_message = self.pop_message(count=1)[0]
step_response = self.step(user_message=user_message.text, return_dicts=False)
assert user_message.text is not None, "User message text is None"
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
messages = step_response.messages

assert messages is not None
Expand Down
28 changes: 19 additions & 9 deletions letta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,19 +356,29 @@ def run_agent_loop(
else:
# If message did not begin with command prefix, pass inputs to Letta
# Handle user message and append to messages
user_message = system.package_user_message(user_input)
user_message = str(user_input)

skip_next_user_input = False

def process_agent_step(user_message, no_verify):
step_response = letta_agent.step(
user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
if user_message is None:
step_response = letta_agent.step(
messages=[],
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
else:
step_response = letta_agent.step_user_message(
user_message_str=user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
ms=ms,
)
new_messages = step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
Expand Down
62 changes: 51 additions & 11 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,22 @@ def _get_or_load_agent(self, agent_id: str) -> Agent:
letta_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return letta_agent

def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], timestamp: Optional[datetime]) -> LettaUsageStatistics:
def _step(
self,
user_id: str,
agent_id: str,
input_messages: Union[Message, List[Message]],
# timestamp: Optional[datetime],
) -> LettaUsageStatistics:
"""Send the input message through the agent"""
logger.debug(f"Got input message: {input_message}")

# Input validation
if isinstance(input_messages, Message):
input_messages = [input_messages]
if not all(isinstance(m, Message) for m in input_messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")

logger.debug(f"Got input messages: {input_messages}")
try:

# Get the agent object (loaded in memory)
Expand All @@ -398,18 +411,18 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message],

logger.debug(f"Starting agent step")
no_verify = True
next_input_message = input_message
next_input_message = input_messages
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
step_response = letta_agent.step(
next_input_message,
messages=next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
stream=token_streaming,
timestamp=timestamp,
# timestamp=timestamp,
ms=self.ms,
)
step_response.messages
Expand All @@ -436,13 +449,40 @@ def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message],
break
# Chain handlers
elif token_warning:
next_input_message = system.get_token_limit_warning()
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
Expand Down Expand Up @@ -621,7 +661,7 @@ def user_message(
)

# Run the agent state forward
usage = self._step(user_id=user_id, agent_id=agent_id, input_message=message, timestamp=timestamp)
usage = self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
return usage

def system_message(
Expand Down Expand Up @@ -669,7 +709,7 @@ def system_message(

if isinstance(message, Message):
# Can't have a null text field
if len(message.text) == 0 or message.text is None:
if message.text is None or len(message.text) == 0:
raise ValueError(f"Invalid input: '{message.text}'")
# If the input begins with a command prefix, reject
elif message.text.startswith("/"):
Expand All @@ -683,7 +723,7 @@ def system_message(
message.created_at = timestamp

# Run the agent state forward
return self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message, timestamp=timestamp)
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message)

# @LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
Expand Down

0 comments on commit cc616ef

Please sign in to comment.