Skip to content

Commit

Permalink
fix: various fixes to patch anthropic integration (#1698)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Aug 30, 2024
1 parent 7e4ca06 commit 89d6c7d
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 7 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/test_anthropic.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Endpoint (Anthropic)

env:
OPENAI_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v4

- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev"

- name: Initialize credentials
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: |
poetry run memgpt quickstart --backend anthropic
- name: Test LLM endpoint
env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_anthropic
7 changes: 7 additions & 0 deletions configs/llm_model_configs/anthropic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"context_window": 200000,
"model": "claude-3-opus-20240229",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null
}
21 changes: 21 additions & 0 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class QuickstartChoice(Enum):
openai = "openai"
# azure = "azure"
memgpt_hosted = "memgpt"
anthropic = "anthropic"


def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice:
Expand Down Expand Up @@ -229,6 +230,26 @@ def quickstart(
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return

elif backend == QuickstartChoice.anthropic:
# Make sure we have an API key
api_key = os.getenv("ANTHROPIC_API_KEY")
while api_key is None or len(api_key) == 0:
# Ask for API key as input
api_key = questionary.password("Enter your Anthropic API key:").ask()
credentials.anthropic_key = api_key
credentials.save()

script_dir = os.path.dirname(__file__) # Get the directory where the script is located
backup_config_path = os.path.join(script_dir, "..", "configs", "anthropic.json")
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded config file successfully.")
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return

else:
raise NotImplementedError(backend)

Expand Down
13 changes: 13 additions & 0 deletions memgpt/configs/anthropic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"context_window": 200000,
"model": "claude-3-opus-20240229",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"embedding_endpoint_type": "hugging-face",
"embedding_endpoint": "https://embeddings.memgpt.ai",
"embedding_model": "BAAI/bge-large-en-v1.5",
"embedding_dim": 1024,
"embedding_chunk_size": 300

}
11 changes: 8 additions & 3 deletions memgpt/llm_api/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import re
import uuid
from typing import List, Optional, Union

import requests
Expand Down Expand Up @@ -256,7 +255,7 @@ def convert_anthropic_response_to_chatcompletion(
type="function",
function=FunctionCall(
name=response_json["content"][1]["name"],
arguments=json_dumps(response_json["content"][1]["input"]),
arguments=json.dumps(response_json["content"][1]["input"], indent=2),
),
)
]
Expand Down Expand Up @@ -330,8 +329,14 @@ def anthropic_chat_completions_request(
data["system"] = data["messages"][0]["content"]
data["messages"] = data["messages"][1:]

# set `content` to None if missing
for message in data["messages"]:
if "content" not in message:
message["content"] = None

# Convert to Anthropic format
msg_objs = [Message.dict_to_message(user_id=uuid.uuid4(), agent_id=uuid.uuid4(), openai_message_dict=m) for m in data["messages"]]

msg_objs = [Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]]
data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs]

# Handling Anthropic special requirement for 'user' message in front
Expand Down
8 changes: 4 additions & 4 deletions memgpt/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class Message(BaseMessage):
id: str = BaseMessage.generate_id_field()
role: MessageRole = Field(..., description="The role of the participant.")
text: Optional[str] = Field(None, description="The text of the message.")
user_id: str = Field(None, description="The unique identifier of the user.")
agent_id: str = Field(None, description="The unique identifier of the agent.")
user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
name: Optional[str] = Field(None, description="The name of the participant.")
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
Expand Down Expand Up @@ -367,8 +367,8 @@ def add_xml_tag(string: str, xml_tag: Optional[str]):
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function["name"],
"input": json.loads(tool_call.function["arguments"]),
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments),
}
)

Expand Down
5 changes: 5 additions & 0 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,8 @@ def test_llm_endpoint_ollama():
def test_embedding_endpoint_ollama():
filename = os.path.join(embedding_config_dir, "ollama.json")
run_embedding_endpoint(filename)


def test_llm_endpoint_anthropic():
filename = os.path.join(llm_config_dir, "anthropic.json")
run_llm_endpoint(filename)

0 comments on commit 89d6c7d

Please sign in to comment.