Skip to content

Commit

Permalink
refactor: make Agent.step() multi-step, and rename Agent.step() to Ag…
Browse files Browse the repository at this point in the history
…ent.inner_step(), i.e. the single-step version
  • Loading branch information
cpacker committed Oct 14, 2024
1 parent 1b4773a commit e49039a
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 85 deletions.
108 changes: 103 additions & 5 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from letta.constants import (
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
IN_CONTEXT_MEMORY_KEYWORD,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import LLMError
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
Expand All @@ -32,11 +36,15 @@
from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
)
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.passage import Passage
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
get_login_event,
get_token_limit_warning,
package_function_response,
package_summarize_message,
package_user_message,
Expand All @@ -56,9 +64,6 @@
verify_first_message_correctness,
)

from .errors import LLMError
from .llm_api.helpers import is_context_overflow_error


def compile_memory_metadata_block(
memory_edit_timestamp: datetime.datetime,
Expand Down Expand Up @@ -202,7 +207,7 @@ class BaseAgent(ABC):
def step(
self,
messages: Union[Message, List[Message]],
) -> AgentStepResponse:
) -> LettaUsageStatistics:
"""
Top-level event message handler for the agent.
"""
Expand Down Expand Up @@ -721,6 +726,99 @@ def _handle_ai_response(
return messages, heartbeat_request, function_failed

def step(
    self,
    messages: Union[Message, List[Message]],
    # additional args
    chaining: bool = True,
    max_chaining_steps: Optional[int] = None,
    stream: bool = False,
    ms: Optional[MetadataStore] = None,
    skip_verify: bool = False,
    **kwargs,
) -> LettaUsageStatistics:
    """Run the agent loop, chaining inner steps via heartbeat requests and function failures.

    Repeatedly calls `inner_step` until the agent yields (no heartbeat request,
    no function failure, no token warning), chaining is disabled, or the
    chaining-step cap is exceeded. Agent state is persisted after every inner step.

    Args:
        messages: The initial input message(s) to process.
        chaining: If False, stop after a single inner step.
        max_chaining_steps: Optional cap on chained steps.
            NOTE(review): the original `>` comparison permits up to
            `max_chaining_steps + 1` total steps — preserved here; confirm intended.
        stream: Whether to stream tokens from the LLM.
        ms: MetadataStore used to save agent state after each step (required).
        skip_verify: Passed through to `inner_step` to skip message verification.
        **kwargs: Additional arguments forwarded to `inner_step`.

    Returns:
        LettaUsageStatistics: Aggregated token usage across all steps, plus `step_count`.

    Raises:
        ValueError: If `ms` is not provided.
    """
    if ms is None:
        # `assert` is stripped under `python -O`; validate explicitly instead.
        raise ValueError("MetadataStore is required")

    def _user_chain_message(content: str) -> Message:
        """Package a synthetic user message used to continue the chain."""
        assert self.agent_state.user_id is not None
        return Message.dict_to_message(
            agent_id=self.agent_state.id,
            user_id=self.agent_state.user_id,
            model=self.model,
            openai_message_dict={
                "role": "user",  # TODO: change to system?
                "content": content,
            },
        )

    next_input_message = messages if isinstance(messages, list) else [messages]
    total_usage = UsageStatistics()
    step_count = 0
    while True:
        step_response = self.inner_step(
            messages=next_input_message,
            first_message=False,
            skip_verify=skip_verify,
            return_dicts=False,
            stream=stream,
            ms=ms,
            **kwargs,
        )
        heartbeat_request = step_response.heartbeat_request
        function_failed = step_response.function_failed
        token_warning = step_response.in_context_memory_warning

        step_count += 1
        total_usage += step_response.usage
        self.interface.step_complete()

        # Persist updated agent state after every inner step.
        save_agent(self, ms)

        # Chain stops
        if not chaining:
            printd("No chaining, stopping after one step")
            break
        elif max_chaining_steps is not None and step_count > max_chaining_steps:
            printd(f"Hit max chaining steps, stopping after {step_count} steps")
            break
        # Chain handlers: feed a synthetic user message back in and loop again.
        elif token_warning:
            next_input_message = _user_chain_message(get_token_limit_warning())
            continue  # always chain
        elif function_failed:
            next_input_message = _user_chain_message(get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE))
            continue  # always chain
        elif heartbeat_request:
            next_input_message = _user_chain_message(get_heartbeat(REQ_HEARTBEAT_MESSAGE))
            continue  # always chain
        # Letta no-op / yield
        else:
            break

    return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

def inner_step(
self,
messages: Union[Message, List[Message]],
first_message: bool = False,
Expand All @@ -732,7 +830,7 @@ def step(
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
"""Top-level event message handler for the Letta agent"""
"""Runs a single step in the agent loop (generates at most one LLM call)"""

try:

Expand Down
92 changes: 12 additions & 80 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
Expand Down Expand Up @@ -404,6 +403,7 @@ def _step(
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")

logger.debug(f"Got input messages: {input_messages}")
letta_agent = None
try:

# Get the agent object (loaded in memory)
Expand All @@ -415,93 +415,25 @@ def _step(
token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False

logger.debug(f"Starting agent step")
no_verify = True
next_input_message = input_messages
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
step_response = letta_agent.step(
messages=next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
stream=token_streaming,
# timestamp=timestamp,
ms=self.ms,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage

step_count += 1
total_usage += usage
counter += 1
letta_agent.interface.step_complete()

logger.debug("Saving agent state")
# save updated state
save_agent(letta_agent, self.ms)

# Chain stops
if not self.chaining:
logger.debug("No chaining, stopping after one step")
break
elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
break
usage_stats = letta_agent.step(
messages=input_messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
stream=token_streaming,
ms=self.ms,
skip_verify=True,
)

except Exception as e:
logger.error(f"Error in server._step: {e}")
print(traceback.print_exc())
raise
finally:
logger.debug("Calling step_yield()")
letta_agent.interface.step_yield()
if letta_agent:
letta_agent.interface.step_yield()

return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
return usage_stats

def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
"""Process a CLI command"""
Expand Down

0 comments on commit e49039a

Please sign in to comment.