From 32fbd71467ed6dc82647d0045feeb930a636c70a Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Fri, 11 Oct 2024 15:58:12 -0700
Subject: [PATCH] feat: add `VLLMProvider` (#1866)

Co-authored-by: cpacker
---
 letta/cli/cli.py                  | 20 +-------
 letta/llm_api/llm_api_tools.py    |  4 ++
 letta/llm_api/openai.py           | 10 ++--
 letta/local_llm/vllm/api.py       |  2 +-
 letta/providers.py                | 81 +++++++++++++++++++++++++++----
 letta/schemas/embedding_config.py |  7 +++
 letta/schemas/llm_config.py       |  7 +++
 letta/server/server.py            | 20 ++++++--
 letta/settings.py                 | 10 +++-
 tests/test_providers.py           |  8 +++
 10 files changed, 132 insertions(+), 37 deletions(-)

diff --git a/letta/cli/cli.py b/letta/cli/cli.py
index 160615b7b6..04dbf359a4 100644
--- a/letta/cli/cli.py
+++ b/letta/cli/cli.py
@@ -14,9 +14,7 @@
 from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
 from letta.log import get_logger
 from letta.metadata import MetadataStore
-from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import OptionState
-from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ChatMemory, Memory
 from letta.server.server import logger as server_logger
 
@@ -235,12 +233,7 @@ def run(
         # choose from list of llm_configs
         llm_configs = client.list_llm_configs()
         llm_options = [llm_config.model for llm_config in llm_configs]
-
-        # TODO move into LLMConfig as a class method?
-        def prettify_llm_config(llm_config: LLMConfig) -> str:
-            return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""
-
-        llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
+        llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs]
 
         # select model
         if len(llm_options) == 0:
@@ -255,17 +248,8 @@ def prettify_llm_config(llm_config: LLMConfig) -> str:
         embedding_configs = client.list_embedding_configs()
         embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
 
-        # TODO move into EmbeddingConfig as a class method?
-        def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
-            return (
-                f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
-                if embedding_config.embedding_endpoint
-                else ""
-            )
-
         embedding_choices = [
-            questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
-            for embedding_config in embedding_configs
+            questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs
         ]
 
         # select model
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 50166a1ceb..9864fafe18 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -70,6 +70,10 @@ def wrapper(*args, **kwargs):
                 return func(*args, **kwargs)
 
             except requests.exceptions.HTTPError as http_err:
+
+                if not hasattr(http_err, "response") or not http_err.response:
+                    raise
+
                 # Retry on specified errors
                 if http_err.response.status_code in error_codes:
                     # Increment retries
diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py
index 753e7c22aa..3d203fe2c6 100644
--- a/letta/llm_api/openai.py
+++ b/letta/llm_api/openai.py
@@ -61,6 +61,7 @@ def openai_get_model_list(
         headers["Authorization"] = f"Bearer {api_key}"
 
     printd(f"Sending request to {url}")
+    response = None
     try:
         # TODO add query param "tool" to be true
         response = requests.get(url, headers=headers, params=extra_params)
@@ -71,7 +72,8 @@
     except requests.exceptions.HTTPError as http_err:
         # Handle HTTP errors (e.g., response 4XX, 5XX)
         try:
-            response = response.json()
+            if response:
+                response = response.json()
         except:
             pass
         printd(f"Got HTTPError, exception={http_err}, response={response}")
@@ -79,7 +81,8 @@
     except requests.exceptions.RequestException as req_err:
         # Handle other requests-related errors (e.g., connection error)
         try:
-            response = response.json()
+            if response:
+                response = response.json()
         except:
             pass
         printd(f"Got RequestException, exception={req_err}, response={response}")
@@ -87,7 +90,8 @@
     except Exception as e:
         # Handle other potential errors
         try:
-            response = response.json()
+            if response:
+                response = response.json()
         except:
             pass
         printd(f"Got unknown Exception, exception={e}, response={response}")
diff --git a/letta/local_llm/vllm/api.py b/letta/local_llm/vllm/api.py
index 102b9606d1..48c48b3260 100644
--- a/letta/local_llm/vllm/api.py
+++ b/letta/local_llm/vllm/api.py
@@ -3,7 +3,7 @@
 from letta.local_llm.settings.settings import get_completions_settings
 from letta.local_llm.utils import count_tokens, post_json_auth_request
 
-WEBUI_API_SUFFIX = "/v1/completions"
+WEBUI_API_SUFFIX = "/completions"
 
 
 def get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user, grammar=None):
diff --git a/letta/providers.py b/letta/providers.py
index 761fcd7ee7..fa54570846 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -14,14 +14,18 @@
 
 
 class Provider(BaseModel):
-    def list_llm_models(self):
+    def list_llm_models(self) -> List[LLMConfig]:
         return []
 
-    def list_embedding_models(self):
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
         return []
 
-    def get_model_context_window(self, model_name: str):
-        pass
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        raise NotImplementedError
+
+    def provider_tag(self) -> str:
+        """String representation of the provider for display purposes"""
+        raise NotImplementedError
 
 
 class LettaProvider(Provider):
@@ -162,7 +166,7 @@ def list_llm_models(self) -> List[LLMConfig]:
             )
         return configs
 
-    def get_model_context_window(self, model_name: str):
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
 
         import requests
 
@@ -310,7 +314,7 @@ def list_embedding_models(self):
             )
         return configs
 
-    def get_model_context_window(self, model_name: str):
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
         from letta.llm_api.google_ai import google_ai_get_model_context_window
 
         return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
@@ -371,16 +375,75 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
             )
         return configs
 
-    def get_model_context_window(self, model_name: str):
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
         """
         This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
         """
        return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
 
 
-class VLLMProvider(OpenAIProvider):
+class VLLMChatCompletionsProvider(Provider):
+    """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
+
     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    pass
+    name: str = "vllm"
+    base_url: str = Field(..., description="Base URL for the vLLM API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # vLLM exposes an OpenAI-compatible model list endpoint
+        from letta.llm_api.openai import openai_get_model_list
+
+        assert self.base_url, "base_url is required for vLLM provider"
+        response = openai_get_model_list(self.base_url, api_key=None)
+
+        configs = []
+        print(response)
+        for model in response["data"]:
+            configs.append(
+                LLMConfig(
+                    model=model["id"],
+                    model_endpoint_type="openai",
+                    model_endpoint=self.base_url,
+                    context_window=model["max_model_len"],
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # not supported with vLLM
+        return []
+
+
+class VLLMCompletionsProvider(Provider):
+    """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
+
+    # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
+    name: str = "vllm"
+    base_url: str = Field(..., description="Base URL for the vLLM API.")
+    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # vLLM exposes an OpenAI-compatible model list endpoint
+        from letta.llm_api.openai import openai_get_model_list
+
+        response = openai_get_model_list(self.base_url, api_key=None)
+
+        configs = []
+        for model in response["data"]:
+            configs.append(
+                LLMConfig(
+                    model=model["id"],
+                    model_endpoint_type="vllm",
+                    model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
+                    context_window=model["max_model_len"],
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # not supported with vLLM
+        return []
 
 
 class CohereProvider(OpenAIProvider):
diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py
index e56b2f8272..31f7ee8da3 100644
--- a/letta/schemas/embedding_config.py
+++ b/letta/schemas/embedding_config.py
@@ -52,3 +52,10 @@ def default_config(cls, model_name: Optional[str] = None, provider: Optional[str
             )
         else:
             raise ValueError(f"Model {model_name} not supported.")
+
+    def pretty_print(self) -> str:
+        return (
+            f"{self.embedding_model}"
+            + (f" [type={self.embedding_endpoint_type}]" if self.embedding_endpoint_type else "")
+            + (f" [ip={self.embedding_endpoint}]" if self.embedding_endpoint else "")
+        )
diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py
index 412e6483ee..b3d7f02f0a 100644
--- a/letta/schemas/llm_config.py
+++ b/letta/schemas/llm_config.py
@@ -68,3 +68,10 @@ def default_config(cls, model_name: str):
             )
         else:
             raise ValueError(f"Model {model_name} not supported.")
+
+    def pretty_print(self) -> str:
+        return (
+            f"{self.model}"
+            + (f" [type={self.model_endpoint_type}]" if self.model_endpoint_type else "")
+            + (f" [ip={self.model_endpoint}]" if self.model_endpoint else "")
+        )
diff --git a/letta/server/server.py b/letta/server/server.py
index efd16a784b..08050ac080 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -51,7 +51,8 @@
     OllamaProvider,
     OpenAIProvider,
     Provider,
-    VLLMProvider,
+    VLLMChatCompletionsProvider,
+    VLLMCompletionsProvider,
 )
 from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
 from letta.schemas.api_key import APIKey, APIKeyCreate
@@ -244,12 +245,11 @@ def __init__(
         if model_settings.anthropic_api_key:
             self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
         if model_settings.ollama_base_url:
-            self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url))
-        if model_settings.vllm_base_url:
-            self._enabled_providers.append(VLLMProvider(base_url=model_settings.vllm_base_url))
+            self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url, api_key=None))
         if model_settings.gemini_api_key:
             self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
         if model_settings.azure_api_key and model_settings.azure_base_url:
+            assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
             self._enabled_providers.append(
                 AzureProvider(
                     api_key=model_settings.azure_api_key,
@@ -257,6 +257,18 @@ def __init__(
                     api_version=model_settings.azure_api_version,
                 )
             )
+        if model_settings.vllm_api_base:
+            # vLLM exposes both a /chat/completions and a /completions endpoint
+            self._enabled_providers.append(
+                VLLMCompletionsProvider(
+                    base_url=model_settings.vllm_api_base,
+                    default_prompt_formatter=model_settings.default_prompt_formatter,
+                )
+            )
+            # NOTE: to use the /chat/completions endpoint, you need to specify extra flags on vLLM startup
+            # see: https://docs.vllm.ai/en/latest/getting_started/examples/openai_chat_completion_client_with_tools.html
+            # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
+            self._enabled_providers.append(VLLMChatCompletionsProvider(base_url=model_settings.vllm_api_base))
 
     def save_agents(self):
         """Saves all the agents that are in the in-memory object store"""
diff --git a/letta/settings.py b/letta/settings.py
index 75a55bd9ab..91c7add526 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -4,14 +4,20 @@
 from pydantic import Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
+from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
+
 
 class ModelSettings(BaseSettings):
 
     # env_prefix='my_prefix_'
 
+    # when we use /completions APIs (instead of /chat/completions), we need to specify a model wrapper
+    # the "model wrapper" is responsible for prompt formatting and function calling parsing
+    default_prompt_formatter: str = DEFAULT_WRAPPER_NAME
+
     # openai
     openai_api_key: Optional[str] = None
-    openai_api_base: Optional[str] = "https://api.openai.com/v1"
+    openai_api_base: str = "https://api.openai.com/v1"
 
     # groq
     groq_api_key: Optional[str] = None
@@ -34,7 +40,7 @@ class ModelSettings(BaseSettings):
     gemini_api_key: Optional[str] = None
 
     # vLLM
-    vllm_base_url: Optional[str] = None
+    vllm_api_base: Optional[str] = None
 
     # openllm
     openllm_auth_type: Optional[str] = None
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 684fed5fbd..01bb8d4115 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -52,3 +52,11 @@ def test_googleai():
     print(models)
 
     provider.list_embedding_models()
+
+
+# def test_vllm():
+#     provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
+#     models = provider.list_llm_models()
+#     print(models)
+#
+#     provider.list_embedding_models()
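
Below is a minimal usage sketch (not part of the patch) showing how the two providers introduced above could be exercised once the patch is applied. It assumes the VLLM_API_BASE environment variable points at a running vLLM server's OpenAI-compatible API, for example http://localhost:8000/v1; the class and constant names come directly from the patch.

import os

from letta.local_llm.constants import DEFAULT_WRAPPER_NAME
from letta.providers import VLLMChatCompletionsProvider, VLLMCompletionsProvider

# Assumption: VLLM_API_BASE is set to a vLLM OpenAI-compatible endpoint, e.g. http://localhost:8000/v1
vllm_api_base = os.environ["VLLM_API_BASE"]

# /completions backend: a prompt formatter (model wrapper) handles prompt
# formatting and function-call parsing on the client side
completions_provider = VLLMCompletionsProvider(
    base_url=vllm_api_base,
    default_prompt_formatter=DEFAULT_WRAPPER_NAME,
)
print(completions_provider.list_llm_models())

# /chat/completions backend: requires vLLM to be launched with tool-calling
# flags (e.g. --enable-auto-tool-choice --tool-call-parser hermes)
chat_provider = VLLMChatCompletionsProvider(base_url=vllm_api_base)
print(chat_provider.list_llm_models())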