Skip to content

Commit

Permalink
feat: Add MistralProvider (#1883)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Zhou <[email protected]>
  • Loading branch information
mattzh72 and Matt Zhou authored Oct 14, 2024
1 parent 6fc2fee commit 1b4773a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 3 deletions.
47 changes: 47 additions & 0 deletions letta/llm_api/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import requests

from letta.utils import printd, smart_urljoin


def mistral_get_model_list(url: str, api_key: str) -> dict:
    """Query the Mistral API for the list of available models.

    Args:
        url: Base URL of the Mistral API (e.g. "https://api.mistral.ai/v1").
        api_key: Bearer token for the Mistral API; the Authorization header
            is omitted when this is None.

    Returns:
        The parsed JSON response (a dict; callers expect a "data" list).

    Raises:
        requests.exceptions.HTTPError: on 4XX/5XX responses.
        requests.exceptions.RequestException: on connection-level failures.
        Exception: anything else unexpected (e.g. invalid JSON body) is re-raised.
    """
    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    def _body_for_logging(resp):
        # Best-effort extraction of the response JSON for error messages only.
        # NOTE: must compare against None, not truthiness — requests.Response
        # is falsy for 4XX/5XX status codes, which is exactly when we are here.
        if resp is None:
            return None
        try:
            return resp.json()
        except Exception:
            return None

    printd(f"Sending request to {url}")
    response = None
    try:
        # TODO add query param "tool" to be true
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        return response.json()  # convert to dict from string
    except requests.exceptions.HTTPError as http_err:
        # HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_body_for_logging(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Other requests-related errors (e.g., connection error, timeout)
        printd(f"Got RequestException, exception={req_err}, response={_body_for_logging(response)}")
        raise req_err
    except Exception as e:
        # Any other unexpected error
        printd(f"Got unknown Exception, exception={e}, response={_body_for_logging(response)}")
        raise e
44 changes: 44 additions & 0 deletions letta/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,50 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
return []


class MistralProvider(Provider):
    """Provider for the Mistral API (https://api.mistral.ai).

    Lists chat models that support tool calling; embedding models are not
    exposed through this provider.
    """

    name: str = "mistral"
    api_key: str = Field(..., description="API key for the Mistral API.")
    base_url: str = "https://api.mistral.ai/v1"

    def list_llm_models(self) -> List[LLMConfig]:
        """Return LLMConfigs for Mistral models usable by Letta.

        Only models advertising both chat completion and function calling
        in their capabilities are included, since Letta requires
        tool-calling support.
        """
        # Local import to avoid importing the HTTP helper at module load time.
        from letta.llm_api.mistral import mistral_get_model_list

        response = mistral_get_model_list(self.base_url, api_key=self.api_key)

        assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

        configs = []
        for model in response["data"]:
            # Keep only models with chat completions and function calling enabled
            if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
                configs.append(
                    LLMConfig(
                        model=model["id"],
                        # Mistral's API is OpenAI-compatible, so reuse the openai endpoint type
                        model_endpoint_type="openai",
                        model_endpoint=self.base_url,
                        context_window=model["max_context_length"],
                    )
                )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # Embedding models are not supported for Mistral through this provider.
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        """Return the context window for `model_name`, or None if unknown."""
        # Re-listing models each call is fine because it's a pretty lightweight call.
        for m in self.list_llm_models():
            # BUG FIX: list_llm_models returns LLMConfig objects, not dicts —
            # the original `m["id"]` / `m["max_context_length"]` would raise
            # TypeError. Use attribute access on the config instead.
            if model_name in m.model:
                return int(m.context_window)

        return None


class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint
Expand Down
17 changes: 14 additions & 3 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from letta.providers import (
AnthropicProvider,
AzureProvider,
GoogleAIProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
)
Expand Down Expand Up @@ -33,10 +35,13 @@ def test_anthropic():
#


# Reference ticket for Azure test coverage:
# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
def test_azure():
    """Smoke-test the Azure provider: list LLM and embedding models.

    Requires AZURE_API_KEY and AZURE_BASE_URL to be set in the environment.
    """
    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
    models = provider.list_llm_models()
    print([m.model for m in models])

    embed_models = provider.list_embedding_models()
    print([m.embedding_model for m in embed_models])


def test_ollama():
Expand All @@ -60,6 +65,12 @@ def test_googleai():
provider.list_embedding_models()


def test_mistral():
    """Smoke-test the Mistral provider: list available LLM models.

    Requires MISTRAL_API_KEY to be set in the environment.
    """
    provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
    model_names = [config.model for config in provider.list_llm_models()]
    print(model_names)


# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
Expand Down

0 comments on commit 1b4773a

Please sign in to comment.