From 89d6c7d7be9f25f0336ebfadb03cd104bfd6429e Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 30 Aug 2024 15:15:56 -0700 Subject: [PATCH] fix: various fixes to patch anthropic integration (#1698) --- .github/workflows/test_anthropic.yml | 37 ++++++++++++++++++++++++ configs/llm_model_configs/anthropic.json | 7 +++++ memgpt/cli/cli.py | 21 ++++++++++++++ memgpt/configs/anthropic.json | 13 +++++++++ memgpt/llm_api/anthropic.py | 11 +++++-- memgpt/schemas/message.py | 8 ++--- tests/test_endpoints.py | 5 ++++ 7 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/test_anthropic.yml create mode 100644 configs/llm_model_configs/anthropic.json create mode 100644 memgpt/configs/anthropic.json diff --git a/.github/workflows/test_anthropic.yml b/.github/workflows/test_anthropic.yml new file mode 100644 index 0000000000..ffb22d0653 --- /dev/null +++ b/.github/workflows/test_anthropic.yml @@ -0,0 +1,37 @@ +name: Endpoint (Anthropic) + +env: + OPENAI_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: "Setup Python, Poetry and Dependencies" + uses: packetcoders/action-setup-cache-python-poetry@main + with: + python-version: "3.12" + poetry-version: "1.8.2" + install-args: "-E dev" + + - name: Initialize credentials + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + run: | + poetry run memgpt quickstart --backend anthropic + + - name: Test LLM endpoint + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_anthropic diff --git a/configs/llm_model_configs/anthropic.json b/configs/llm_model_configs/anthropic.json new file mode 100644 index 0000000000..6281aa9644 --- /dev/null +++ b/configs/llm_model_configs/anthropic.json @@ -0,0 +1,7 @@ +{ + "context_window": 200000, + "model": "claude-3-opus-20240229", + "model_endpoint_type": "anthropic", + "model_endpoint": "https://api.anthropic.com/v1", + "model_wrapper": null +} diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 96a72f4fc0..2847aa3b9b 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -40,6 +40,7 @@ class QuickstartChoice(Enum): openai = "openai" # azure = "azure" memgpt_hosted = "memgpt" + anthropic = "anthropic" def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice: @@ -229,6 +230,26 @@ def quickstart( typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED) return + elif backend == QuickstartChoice.anthropic: + # Make sure we have an API key + api_key = os.getenv("ANTHROPIC_API_KEY") + while api_key is None or len(api_key) == 0: + # Ask for API key as input + api_key = questionary.password("Enter your Anthropic API key:").ask() + credentials.anthropic_key = api_key + credentials.save() + + script_dir = os.path.dirname(__file__) # Get the directory where the script is located + backup_config_path = os.path.join(script_dir, "..", "configs", "anthropic.json") + try: + with open(backup_config_path, "r", encoding="utf-8") as file: + backup_config = json.load(file) + printd("Loaded config file successfully.") + new_config, config_was_modified = set_config_with_dict(backup_config) + except FileNotFoundError: + typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED) + return + else: raise NotImplementedError(backend) diff --git a/memgpt/configs/anthropic.json b/memgpt/configs/anthropic.json new file mode 100644 index 0000000000..a1eb92f2de --- /dev/null +++ b/memgpt/configs/anthropic.json @@ -0,0 +1,13 @@ +{ + "context_window": 200000, + "model": "claude-3-opus-20240229", + "model_endpoint_type": "anthropic", + "model_endpoint": "https://api.anthropic.com/v1", + "model_wrapper": null, + "embedding_endpoint_type": "hugging-face", + "embedding_endpoint": "https://embeddings.memgpt.ai", + "embedding_model": "BAAI/bge-large-en-v1.5", + "embedding_dim": 1024, + "embedding_chunk_size": 300 + +} diff --git a/memgpt/llm_api/anthropic.py b/memgpt/llm_api/anthropic.py index cdaba1aff0..bcb1a58a75 100644 --- a/memgpt/llm_api/anthropic.py +++ b/memgpt/llm_api/anthropic.py @@ -1,6 +1,5 @@ import json import re -import uuid from typing import List, Optional, Union import requests @@ -256,7 +255,7 @@ def convert_anthropic_response_to_chatcompletion( type="function", function=FunctionCall( name=response_json["content"][1]["name"], - arguments=json_dumps(response_json["content"][1]["input"]), + arguments=json.dumps(response_json["content"][1]["input"], indent=2), ), ) ] @@ -330,8 +329,14 @@ def anthropic_chat_completions_request( data["system"] = data["messages"][0]["content"] data["messages"] = data["messages"][1:] + # set `content` to None if missing + for message in data["messages"]: + if "content" not in message: + message["content"] = None + # Convert to Anthropic format - msg_objs = [Message.dict_to_message(user_id=uuid.uuid4(), agent_id=uuid.uuid4(), openai_message_dict=m) for m in data["messages"]] + + msg_objs = [Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]] data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs] # Handling Anthropic special requirement for 'user' message in front diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index 222b83bf11..26ae3c0e18 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -62,8 +62,8 @@ class Message(BaseMessage): id: str = BaseMessage.generate_id_field() role: MessageRole = Field(..., description="The role of the participant.") text: Optional[str] = Field(None, description="The text of the message.") - user_id: str = Field(None, description="The unique identifier of the user.") - agent_id: str = Field(None, description="The unique identifier of the agent.") + user_id: Optional[str] = Field(None, description="The unique identifier of the user.") + agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") model: Optional[str] = Field(None, description="The model used to make the function call.") name: Optional[str] = Field(None, description="The name of the participant.") created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.") @@ -367,8 +367,8 @@ def add_xml_tag(string: str, xml_tag: Optional[str]): { "type": "tool_use", "id": tool_call.id, - "name": tool_call.function["name"], - "input": json.loads(tool_call.function["arguments"]), + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), } ) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index a9b7192f9a..424490a083 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -103,3 +103,8 @@ def test_llm_endpoint_ollama(): def test_embedding_endpoint_ollama(): filename = os.path.join(embedding_config_dir, "ollama.json") run_embedding_endpoint(filename) + + +def test_llm_endpoint_anthropic(): + filename = os.path.join(llm_config_dir, "anthropic.json") + run_llm_endpoint(filename)