fix: various fixes to enable LocalClient creation (#1724)
sarahwooders authored Sep 6, 2024
1 parent 9646b53 commit 7eba82a
Showing 4 changed files with 1,521 additions and 1,395 deletions.
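The change set below reworks `memgpt/client/client.py` so that the `create_client` factory can return a working `LocalClient`. As a rough, hedged sketch of the intended usage (not code from this commit; the URL and token are placeholders), calling the factory without a `base_url` yields a `LocalClient`, while supplying a `base_url` and `token` yields a `RESTClient`:

```python
# Hedged usage sketch for the create_client factory shown in the diff below.
# Assumes a local MemGPT config already exists; URL/token values are placeholders.
from memgpt.client.client import create_client

# No base_url -> LocalClient backed by an in-process SyncServer.
local_client = create_client()

# base_url + token -> RESTClient that talks to a running MemGPT server over HTTP.
rest_client = create_client(base_url="http://localhost:8283", token="<your-token>")
```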
80 changes: 43 additions & 37 deletions memgpt/client/client.py
@@ -1,24 +1,18 @@
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import httpx

from memgpt.config import MemGPTConfig
from memgpt.log import get_logger
from memgpt.constants import BASE_TOOLS
from memgpt.settings import settings
from memgpt.data_sources.connectors import DataConnector
from memgpt.functions.functions import parse_source_code
from memgpt.log import get_logger
from memgpt.memory import get_memory_functions

# This is a hack for now, should be using new schemas
from memgpt.server.schemas.humans import ListHumansResponse
from memgpt.server.schemas.personas import ListPersonasResponse
from memgpt.server.schemas.config import ConfigResponse
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState

# new schemas
from memgpt.schemas.block import Human, Persona
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.block import (
Block,
CreateBlock,
@@ -40,35 +34,37 @@
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memory import (
ArchivalMemorySummary,
ChatMemory,
BlockChatMemory,
Memory,
RecallMemorySummary,
)
from memgpt.schemas.job import Job
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.message import (
Message,
MessageCreate
)
from memgpt.schemas.message import Message, MessageCreate
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.schemas.user import UserCreate
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.schemas.config import ConfigResponse

# This is a hack for now, should be using new schemas
from memgpt.server.schemas.humans import ListHumansResponse
from memgpt.server.schemas.personas import ListPersonasResponse
from memgpt.server.server import SyncServer
from memgpt.settings import settings
from memgpt.utils import get_human_text, get_persona_text

if TYPE_CHECKING:
from httpx import ASGITransport, WSGITransport

logger = get_logger(__name__)

def create_client(base_url: Optional[str] = None,
token: Optional[str] = None,
config: Optional[MemGPTConfig] = None,
app: Optional[str] = None,
debug: Optional[bool] = False) -> Union["RESTClient", "LocalClient"]:

def create_client(
base_url: Optional[str] = None,
token: Optional[str] = None,
config: Optional[MemGPTConfig] = None,
app: Optional[str] = None,
debug: Optional[bool] = False,
) -> Union["RESTClient", "LocalClient"]:
"""factory method to create either a local or rest api enabled client.
_TODO: link to docs on the difference between the two._
@@ -238,7 +234,7 @@ def __init__(
base_url: str,
token: str,
debug: bool = False,
app: Optional[Union["WSGITransport","ASGITransport"]] = None,
app: Optional[Union["WSGITransport", "ASGITransport"]] = None,
):
super().__init__(debug=debug)
httpx_client_args = {
@@ -266,7 +262,6 @@ async def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optiona
return True
return False


def get_tool(self, tool_name: str):
response = self.httpx_client.get(f"/tools/{tool_name}/")
if response.status_code != 200:
@@ -280,7 +275,12 @@ async def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: Memory = BlockChatMemory(blocks=[Block(name="human block", value=get_human_text(settings.human), label="human"), Block(name="persona block", value=get_persona_text(settings.persona), label="persona")]),
memory: Memory = BlockChatMemory(
blocks=[
Block(name="human block", value=get_human_text(settings.human), label="human"),
Block(name="persona block", value=get_persona_text(settings.persona), label="persona"),
]
),
# tools
tools: Optional[List[str]] = None,
include_base_tools: Optional[bool] = True,
@@ -327,12 +327,8 @@ async def create_agent(

return AgentState(**response.json())


async def rename_agent(self, agent_id: str, new_name: str):
response = await self.httpx_client.patch(f"/agents/{agent_id}/rename/", json={"agent_name": new_name})
assert response.status_code == 200, f"Failed to rename agent: {response.text}"

return AgentState(**response.json())
return await self.update_agent(agent_id, name=new_name)

async def update_agent(
self,
@@ -466,7 +462,9 @@ async def get_messages(
return [Message(**message) for message in response.json()]

async def send_message(self, agent_id: str, message: str, role: str, stream: Optional[bool] = False) -> MemGPTResponse:
request = MemGPTRequest(messages=[MessageCreate(text=message, role=role)], run_async=False, stream_steps=stream, stream_tokens=stream)
request = MemGPTRequest(
messages=[MessageCreate(text=message, role=role)], run_async=False, stream_steps=stream, stream_tokens=stream
)
response = await self.httpx_client.post(f"/agents/{agent_id}/messages", json=request.model_dump(exclude_none=True))
if response.status_code != 200:
raise ValueError(f"Failed to send message: {response.text}")
@@ -741,10 +739,11 @@ async def create_tool(

# make REST request
request = ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags)
response = await self.httpx_client.post("/tools/",
json=request.model_dump(exclude_none=True),
params={"update": update},
)
response = await self.httpx_client.post(
"/tools/",
json=request.model_dump(exclude_none=True),
params={"update": update},
)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
return Tool(**response.json())
@@ -818,6 +817,8 @@ def __init__(

self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface_factory=lambda: self.interface)
self.user_id = self.server.get_current_user()
print(f"User ID: {self.user_id}")

if user_id:
self.user_id = user_id
@@ -847,7 +848,12 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: Memory = BlockChatMemory(blocks=[Block(name="human block", value=get_human_text(settings.human), label="human"), Block(name="persona block", value=get_persona_text(settings.persona), label="persona")]),
memory: Memory = BlockChatMemory(
blocks=[
Block(name="human block", value=get_human_text(settings.human), label="human"),
Block(name="persona block", value=get_persona_text(settings.persona), label="persona"),
]
),
# system
system: Optional[str] = None,
# tools
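Both `create_agent` hunks above now default the `memory` argument to a `BlockChatMemory` built from the configured human and persona texts. A minimal sketch of overriding that default with custom blocks, assuming a local setup (the block values and the bare `create_agent(memory=...)` call are illustrative, not taken from this commit):

```python
# Hedged sketch: overriding the default BlockChatMemory shown in the diff above.
# Block/BlockChatMemory constructor shapes mirror the defaults in the hunks;
# the block values and the resulting agent are illustrative placeholders.
from memgpt.client.client import create_client
from memgpt.schemas.block import Block
from memgpt.schemas.memory import BlockChatMemory

client = create_client()  # LocalClient, per the factory sketched earlier

custom_memory = BlockChatMemory(
    blocks=[
        Block(name="human block", value="Name: Ada. Prefers concise answers.", label="human"),
        Block(name="persona block", value="You are a terse, helpful assistant.", label="persona"),
    ]
)

agent = client.create_agent(memory=custom_memory)
print(agent.id)
```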