diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 43654d6b..f04a7eaf 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
     steps:
     - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1884fc9a..1a4d3926 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
         pydantic: ["==1.10.2", ">=2.0.0"]
     steps:
     - uses: actions/checkout@v4
diff --git a/Justfile b/Justfile
index 868e7ea7..0363d42f 100644
--- a/Justfile
+++ b/Justfile
@@ -17,7 +17,7 @@
   echo "  cog"
   pipenv run cog --check \
     -p "import sys, os; sys._called_from_test=True; os.environ['LLM_USER_PATH'] = '/tmp'" \
-    README.md docs/*.md
+    README.md docs/**/*.md docs/*.md
  echo "  mypy"
  pipenv run mypy llm
  echo "  ruff"
diff --git a/docs/help.md b/docs/help.md
index 9cd2927c..3352a000 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -74,6 +74,7 @@ Commands:
   plugins      List installed plugins
   similar      Return top N similar IDs from a collection
   templates    Manage stored prompt templates
+  tools        Manage available tools
   uninstall    Uninstall Python packages from the LLM environment
 ```
 
@@ -92,6 +93,7 @@ Options:
   -o, --option <TEXT TEXT>...  key/value options for the model
   -t, --template TEXT          Template to use
   -p, --param <TEXT TEXT>...   Parameters for template
+  --enable-tools               Enable tool usage for supported models
   --no-stream                  Do not stream output
   -n, --no-log                 Don't log to database
   --log                        Log prompt and response to the database
@@ -117,6 +119,7 @@ Options:
   -t, --template TEXT          Template to use
   -p, --param <TEXT TEXT>...   Parameters for template
   -o, --option <TEXT TEXT>...  key/value options for the model
+  --enable-tools               Enable tool usage for supported models
   --no-stream                  Do not stream output
   --key TEXT                   API key to use
   --help                       Show this message and exit.
@@ -298,6 +301,32 @@ Options:
   --help  Show this message and exit.
 ```
 
+(help-tools)=
+### llm tools --help
+```
+Usage: llm tools [OPTIONS] COMMAND [ARGS]...
+
+  Manage available tools
+
+Options:
+  --help  Show this message and exit.
+
+Commands:
+  list*  List available tools
+```
+
+(help-tools-list)=
+#### llm tools list --help
+```
+Usage: llm tools list [OPTIONS]
+
+  List available tools
+
+Options:
+  --schema  Show JSON schema for each tool
+  --help    Show this message and exit.
+```
+
 (help-templates)=
 ### llm templates --help
 ```
diff --git a/docs/openai-models.md b/docs/openai-models.md
index d0fc6fcd..66f23db4 100644
--- a/docs/openai-models.md
+++ b/docs/openai-models.md
@@ -31,18 +31,18 @@ models = [line for line in result.output.split("\n") if line.startswith("OpenAI
 cog.out("```\n{}\n```".format("\n".join(models)))
 ]]] -->
 ```
-OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt)
-OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k)
-OpenAI Chat: gpt-4 (aliases: 4, gpt4)
-OpenAI Chat: gpt-4-32k (aliases: 4-32k)
-OpenAI Chat: gpt-4-1106-preview
-OpenAI Chat: gpt-4-0125-preview
-OpenAI Chat: gpt-4-turbo-2024-04-09
-OpenAI Chat: gpt-4-turbo (aliases: gpt-4-turbo-preview, 4-turbo, 4t)
-OpenAI Chat: gpt-4o (aliases: 4o)
-OpenAI Chat: gpt-4o-mini (aliases: 4o-mini)
-OpenAI Chat: o1-preview
-OpenAI Chat: o1-mini
+OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt) (supports tool calling)
+OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k) (supports tool calling)
+OpenAI Chat: gpt-4 (aliases: 4, gpt4) (supports tool calling)
+OpenAI Chat: gpt-4-32k (aliases: 4-32k) (supports tool calling)
+OpenAI Chat: gpt-4-1106-preview (supports tool calling)
+OpenAI Chat: gpt-4-0125-preview (supports tool calling)
+OpenAI Chat: gpt-4-turbo-2024-04-09 (supports tool calling)
+OpenAI Chat: gpt-4-turbo (aliases: gpt-4-turbo-preview, 4-turbo, 4t) (supports tool calling)
+OpenAI Chat: gpt-4o (aliases: 4o) (supports tool calling)
+OpenAI Chat: gpt-4o-mini (aliases: 4o-mini) (supports tool calling)
+OpenAI Chat: o1-preview (supports tool calling)
+OpenAI Chat: o1-mini (supports tool calling)
 OpenAI Completion: gpt-3.5-turbo-instruct (aliases: 3.5-instruct, chatgpt-instruct)
 ```
diff --git a/docs/plugins/index.md b/docs/plugins/index.md
index 96ae62fd..e4f36782 100644
--- a/docs/plugins/index.md
+++ b/docs/plugins/index.md
@@ -5,6 +5,8 @@ LLM plugins can enhance LLM by making alternative Large Language Models availabl
 
 Plugins can also add new commands to the `llm` CLI tool.
 
+Plugins can also add new Python functions that supporting models can invoke - this is LLM tool calling.
+
 The {ref}`plugin directory <plugin-directory>` lists available plugins that you can install and use.
 
 {ref}`tutorial-model-plugin` describes how to build a new plugin in detail.
@@ -17,5 +19,6 @@ installing-plugins directory plugin-hooks tutorial-model-plugin +tool-calling plugin-utilities ``` diff --git a/docs/plugins/llm-sampletools/llm_sampletools.py b/docs/plugins/llm-sampletools/llm_sampletools.py new file mode 100644 index 00000000..674e3a1f --- /dev/null +++ b/docs/plugins/llm-sampletools/llm_sampletools.py @@ -0,0 +1,108 @@ +import random +import sys +from typing import Annotated +import enum +import datetime + +import pydantic +import llm + + +@llm.hookimpl +def register_tools(register): + # Annotated function, will be introspected + register(llm.Tool(random_number)) + register(llm.Tool(best_restaurant_in)) + + # Generate parameter schema from pydantic model + register(llm.Tool(current_temperature, WeatherInfo.model_json_schema())) + + # No parameters, no parameter schema needed - no doc comment so provide description + register( + llm.Tool(best_restaurant, description="Find the best restaurant in the world.") + ) + + # Manually specify parameter schema + register( + llm.Tool( + current_time, + { + "type": "object", + "properties": { + "time_format": { + "description": "The format to use for the returned datetime, either ISO 8601 or unix ctime format.", + "type": "string", + "enum": ["iso", "ctime"], + }, + }, + "required": ["time_format"], + "additionalProperties": False, + }, + ) + ) + + +########## + + +def random_number( + minimum: Annotated[int, "The minimum value of the random number, default is 0"] = 0, + maximum: Annotated[ + int, f"The maximum value of the random number, default is {sys.maxsize}." + ] = sys.maxsize, +) -> str: + """Generate a random number.""" + return str(random.randrange(minimum, maximum)) # noqa: S311 + + +########## + + +def best_restaurant(): + return "WorldsBestRestaurant" + + +########## + + +def best_restaurant_in( + location: Annotated[str, "The city the restaurant is located in."] +) -> str: + """Find the best restaurant in the given location.""" + return "CitiesBestRestaurant" + + +########## + + +class Degrees(enum.Enum): + CELSIUS = "celsius" + FAHRENHEIT = "fahrenheit" + + +class WeatherInfo(pydantic.BaseModel): + location: str = pydantic.Field( + description="The location to return the current temperature for." + ) + degrees: Degrees = pydantic.Field( + description="The degree scale to return temperature in." + ) + + +def current_temperature(**weather_info) -> str: + """Return the current temperature in the provided location.""" + info = WeatherInfo(**weather_info) + return f"The current temperature in {info.location} is 42° {info.degrees.value}." 
+
+
+##########
+
+
+def current_time(time_format):
+    """Return the current date and time in UTC using the specified format."""
+    time = datetime.datetime.now(datetime.timezone.utc)
+    if time_format == "iso":
+        return time.isoformat()
+    elif time_format == "ctime":
+        return time.ctime()
+    raise ValueError(f"Unsupported time format: {time_format}")
diff --git a/docs/plugins/llm-sampletools/pyproject.toml b/docs/plugins/llm-sampletools/pyproject.toml
new file mode 100644
index 00000000..eebc67c7
--- /dev/null
+++ b/docs/plugins/llm-sampletools/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "llm-sampletools"
+version = "0.1"
+dependencies = [
+    "llm",
+    "pydantic>=2.0",
+]
+
+[project.entry-points.llm]
+sampletools = "llm_sampletools"
\ No newline at end of file
diff --git a/docs/plugins/plugin-hooks.md b/docs/plugins/plugin-hooks.md
index 1d7d58f6..3a5b68d3 100644
--- a/docs/plugins/plugin-hooks.md
+++ b/docs/plugins/plugin-hooks.md
@@ -44,3 +44,10 @@ class HelloWorld(llm.Model):
 ```
 
 {ref}`tutorial-model-plugin` describes how to use this hook in detail.
+
+## register_tools(register)
+
+This hook can be used to register one or more Python callables as `llm.Tool`s.
+Models that support tool calling will be able to use these tools.
+
+{ref}`tool-calling` describes this hook in more detail.
\ No newline at end of file
diff --git a/docs/plugins/tool-calling.md b/docs/plugins/tool-calling.md
new file mode 100644
index 00000000..fa7a329e
--- /dev/null
+++ b/docs/plugins/tool-calling.md
@@ -0,0 +1,91 @@
+(tool-calling)=
+# Tool calling
+
+A plugin can expose additional tools to any supporting model via the `register_tools` plugin hook.
+A plugin that implements a new LLM model can consume installed tools if the model supports tool calling.
+
+## Registering new tools
+
+Tools are `llm.Tool` instances holding a Python callable and a JSON schema describing the function parameters.
+The callable must have a docstring that describes what it does.
+If a parameter JSON schema is not provided, `llm.Tool` will introspect the callable and attempt to generate one.
+For this to work, each parameter needs a `typing.Annotated` annotation that contains the parameter type and a text description.
+The function must return a string. It can raise a descriptive exception to be returned to the LLM if the tool fails.
+If it raises `llm.ModelError`, that exception will be forwarded to the user.
+
+```python
+from typing import Annotated
+import llm
+
+@llm.hookimpl
+def register_tools(register):
+    register(llm.Tool(best_restaurant_in))
+
+def best_restaurant_in(
+    location: Annotated[str, "The city the restaurant is located in."]
+) -> str:
+    """Find the best restaurant in the given location."""
+    return "CitiesBestRestaurant"
+```
+
+Now if the user enables tool calling and the model supports it
+(the default OpenAI chat models do), the model can invoke the tool:
+```shell-session
+$ llm --enable-tools -m 4o-mini 'What is the best restaurant in Asbury Park, NJ?'
+The best restaurant in Asbury Park, NJ, is called "Cities Best Restaurant."
+```
+
+You can generate a parameters JSON schema using [pydantic.BaseModel.model_json_schema()](https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_json_schema), or write one by hand and pass it to the `llm.Tool` initializer.
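+
+For example, a hand-written schema can be passed straight to `llm.Tool` (a minimal sketch - `roll_die` is a hypothetical tool function, not part of the sample plugin below):
+
+```python
+import random
+import llm
+
+def roll_die(sides):
+    """Roll a die with the given number of sides."""
+    return str(random.randint(1, sides))
+
+# A hand-written parameter schema is provided, so no Annotated introspection is needed
+die_tool = llm.Tool(
+    roll_die,
+    {
+        "type": "object",
+        "properties": {
+            "sides": {
+                "description": "The number of sides on the die.",
+                "type": "integer",
+            },
+        },
+        "required": ["sides"],
+        "additionalProperties": False,
+    },
+)
+```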
+Here are fuller examples of both approaches:
+
+```{literalinclude} llm-sampletools/llm_sampletools.py
+:language: python
+```
+
+## Using tools in models
+
+If your plugin implements a new `llm.Model` subclass that supports tool calling,
+set `supports_tool_calling = True` on your model class.
+
+You can then use the `Model.tools` property to access tools registered by your plugin or by other plugins.
+The `tools` property contains a `dict` of tool names mapped to `llm.Tool` instances.
+The `Tool.schema` property contains a Python dict representing the JSON schema for that tool function.
+`Tool` is callable - it should be passed a JSON string representing the callable's parameters.
+`Tool` catches any exception raised other than `llm.ModelError` and returns it to the model as an error result.
+
+Here is a skeleton implementation for a hypothetical LLM API that supports tool calling:
+```python
+import llm
+import llmapi  # hypothetical API
+
+class MyToolCallingModel(llm.Model):
+    model_id = "toolcaller"
+    supports_tool_calling = True
+
+    def execute(self, prompt, stream, response, conversation):
+        messages = [{"role": "user", "content": prompt.prompt}]
+        # Invoke our hypothetical LLM API, passing in all registered tool schemas.
+        completion = llmapi.chat.completion(
+            messages=messages,
+            tools=[tool.schema for tool in self.tools.values()]
+        )
+        if completion.tool_calls:
+            messages.append({"role": "assistant", "tool_calls": completion.tool_calls})
+            for tool_call in completion.tool_calls:
+                # Find the named tool and invoke it, adding the result to messages
+                tool = self.tools.get(tool_call.function.name)
+                if tool:
+                    # Invoke the tool with the JSON string arguments.
+                    tool_response = tool(tool_call.function.arguments)
+                    messages.append({"role": "tool", "content": tool_response, "tool_call_id": tool_call.id})
+            # Send the tool results back to the LLM
+            completion = llmapi.chat.completion(messages=messages)
+            yield completion.content
+        else:
+            yield completion.content
+```
+
+A number of LLM APIs support tool function calling using JSON schemas to define the tools.
+For example [OpenAI](https://platform.openai.com/docs/guides/function-calling),
+[Anthropic](https://docs.anthropic.com/en/docs/build-with-claude/tool-use),
+[Google Gemini](https://ai.google.dev/gemini-api/docs/function-calling#function_declarations)
+and [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools).
\ No newline at end of file
diff --git a/docs/usage.md b/docs/usage.md
index 005a1690..8123faa9 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -225,7 +225,7 @@ result = CliRunner().invoke(cli, ["models", "list", "--options"])
 cog.out("```\n{}\n```".format(result.output))
 ]]] -->
 ```
-OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt)
+OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt) (supports tool calling)
   temperature: float
     What sampling temperature to use, between 0 and 2. Higher values like
     0.8 will make the output more random, while lower values like 0.2 will
@@ -255,7 +255,7 @@ OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt)
     Integer seed to attempt to sample deterministically
   json_object: boolean
     Output a valid JSON object {...}. Prompt must mention JSON.
-OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k) +OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -265,7 +265,7 @@ OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4 (aliases: 4, gpt4) +OpenAI Chat: gpt-4 (aliases: 4, gpt4) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -275,7 +275,7 @@ OpenAI Chat: gpt-4 (aliases: 4, gpt4) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4-32k (aliases: 4-32k) +OpenAI Chat: gpt-4-32k (aliases: 4-32k) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -285,7 +285,7 @@ OpenAI Chat: gpt-4-32k (aliases: 4-32k) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4-1106-preview +OpenAI Chat: gpt-4-1106-preview (supports tool calling) temperature: float max_tokens: int top_p: float @@ -295,7 +295,7 @@ OpenAI Chat: gpt-4-1106-preview logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4-0125-preview +OpenAI Chat: gpt-4-0125-preview (supports tool calling) temperature: float max_tokens: int top_p: float @@ -305,7 +305,7 @@ OpenAI Chat: gpt-4-0125-preview logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4-turbo-2024-04-09 +OpenAI Chat: gpt-4-turbo-2024-04-09 (supports tool calling) temperature: float max_tokens: int top_p: float @@ -315,7 +315,7 @@ OpenAI Chat: gpt-4-turbo-2024-04-09 logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4-turbo (aliases: gpt-4-turbo-preview, 4-turbo, 4t) +OpenAI Chat: gpt-4-turbo (aliases: gpt-4-turbo-preview, 4-turbo, 4t) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -325,7 +325,7 @@ OpenAI Chat: gpt-4-turbo (aliases: gpt-4-turbo-preview, 4-turbo, 4t) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4o (aliases: 4o) +OpenAI Chat: gpt-4o (aliases: 4o) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -335,7 +335,7 @@ OpenAI Chat: gpt-4o (aliases: 4o) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: gpt-4o-mini (aliases: 4o-mini) +OpenAI Chat: gpt-4o-mini (aliases: 4o-mini) (supports tool calling) temperature: float max_tokens: int top_p: float @@ -345,7 +345,7 @@ OpenAI Chat: gpt-4o-mini (aliases: 4o-mini) logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: o1-preview +OpenAI Chat: o1-preview (supports tool calling) temperature: float max_tokens: int top_p: float @@ -355,7 +355,7 @@ OpenAI Chat: o1-preview logit_bias: dict, str seed: int json_object: boolean -OpenAI Chat: o1-mini +OpenAI Chat: o1-mini (supports tool calling) temperature: float max_tokens: int top_p: float diff --git a/llm/__init__.py b/llm/__init__.py index 9e8afacb..8b03552d 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -3,6 +3,7 @@ ModelError, NeedsKeyException, ) +from .tool import Tool from .models import ( Conversation, Model, @@ -37,6 +38,7 @@ "Template", "ModelError", "NeedsKeyException", + "Tool", ] DEFAULT_MODEL = "gpt-4o-mini" @@ -115,6 +117,16 @@ def register(model, aliases=None): return models +def get_tools() -> Dict[str, Tool]: + tools = {} + + def register(tool: Tool): + tools[tool.name] = tool + + pm.hook.register_tools(register=register) + return tools + + def get_embedding_model(name): aliases = get_embedding_model_aliases() try: diff --git a/llm/cli.py b/llm/cli.py index a1b14576..eb45fc58 100644 
--- a/llm/cli.py +++ b/llm/cli.py @@ -20,6 +20,7 @@ get_model, get_model_aliases, get_models_with_aliases, + get_tools, user_dir, set_alias, set_default_model, @@ -60,6 +61,19 @@ def _validate_metadata_json(ctx, param, value): raise click.BadParameter("Metadata must be valid JSON") +def _validate_tools(model, enable_tools: bool): + if enable_tools: + if not model.supports_tool_calling: + click.secho( + f"Model {model.model_id} does not support tool calling", + err=True, + dim=True, + italic=True, + ) + else: + model.supports_tool_calling = False + + @click.group( cls=DefaultGroup, default="prompt", @@ -104,6 +118,9 @@ def cli(): type=(str, str), help="Parameters for template", ) +@click.option( + "--enable-tools", is_flag=True, help="Enable tool usage for supported models" +) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("-n", "--no-log", is_flag=True, help="Don't log to database") @click.option("--log", is_flag=True, help="Log prompt and response to the database") @@ -130,6 +147,7 @@ def prompt( options, template, param, + enable_tools, no_stream, no_log, log, @@ -241,6 +259,8 @@ def read_prompt(): except KeyError: raise click.ClickException("'{}' is not a known model".format(model_id)) + _validate_tools(model, enable_tools) + # Provide the API key, if one is needed and has been provided if model.needs_key: model.key = get_key(key, model.needs_key, model.key_env_var) @@ -326,6 +346,9 @@ def read_prompt(): multiple=True, help="key/value options for the model", ) +@click.option( + "--enable-tools", is_flag=True, help="Enable tool usage for supported models" +) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("--key", help="API key to use") def chat( @@ -336,6 +359,7 @@ def chat( template, param, options, + enable_tools, no_stream, key, ): @@ -381,6 +405,8 @@ def chat( except KeyError: raise click.ClickException("'{}' is not a known model".format(model_id)) + _validate_tools(model, enable_tools) + # Provide the API key, if one is needed and has been provided if model.needs_key: model.key = get_key(key, model.needs_key, model.key_env_var) @@ -811,6 +837,8 @@ def models_list(options): extra = "" if model_with_aliases.aliases: extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases)) + if model_with_aliases.model.supports_tool_calling: + extra += " (supports tool calling)" output = str(model_with_aliases.model) + extra if options and model_with_aliases.model.Options.schema()["properties"]: for name, field in model_with_aliases.model.Options.schema()[ @@ -855,6 +883,25 @@ def models_default(model): raise click.ClickException("Unknown model: {}".format(model)) +@cli.group( + cls=DefaultGroup, + default="list", + default_if_no_args=True, +) +def tools(): + "Manage available tools" + + +@tools.command(name="list") +@click.option("--schema", is_flag=True, help="Show JSON schema for each tool") +def tools_list(schema): + "List available tools" + for name, tool in get_tools().items(): + click.echo(f"{name}: {tool.description}") + if schema: + click.echo(json.dumps(tool.schema, indent=2)) + + @cli.group( cls=DefaultGroup, default="list", diff --git a/llm/default_plugins/file_tools.py b/llm/default_plugins/file_tools.py new file mode 100644 index 00000000..f73b53cb --- /dev/null +++ b/llm/default_plugins/file_tools.py @@ -0,0 +1,24 @@ +import json +import glob +from typing import Annotated +import llm + + +@llm.hookimpl +def register_tools(register): + register(llm.Tool(read_files)) + + +def read_files( + 
filenames: Annotated[ + list[str], + "A list of file paths to read. Paths can be a Python `glob.glob()` pattern.", + ] +) -> str: + """Read the given filenames and return the contents.""" + result = [] + for path in filenames: + for filename in glob.glob(path): + with open(filename, "r") as f: + result.append({"filename": filename, "contents": f.read()}) + return json.dumps(result) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 657c0d20..e32ef032 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -5,6 +5,7 @@ import datetime import httpx import openai +from openai.types.chat import ChatCompletion import os try: @@ -19,6 +20,7 @@ from typing import List, Iterable, Iterator, Optional, Union import json import yaml +from ..tool import format_error @hookimpl @@ -248,9 +250,43 @@ def validate_logit_bias(cls, logit_bias): return validated_logit_bias +class ChatCompletionHandler: + client: openai.OpenAI + model: str + stream: bool + completion: Optional[ChatCompletion] = None + + def __init__(self, client: openai.OpenAI, model: str, stream: bool): + self.client = client + self.model = model + self.stream = stream + + def run(self, messages, **kwargs): + if self.stream: + with self.client.beta.chat.completions.stream( + model=self.model, + messages=messages, + **kwargs, + ) as stream: + for event in stream: + if event.type == "content.delta": + yield event.delta + self.completion = stream.get_final_completion() + else: + self.completion = self.client.chat.completions.create( + model=self.model, + messages=messages, + stream=False, + **kwargs, + ) + if self.completion.choices[0].message.content: + yield self.completion.choices[0].message.content + + class Chat(Model): needs_key = "openai" key_env_var = "OPENAI_API_KEY" + supports_tool_calling = True default_max_tokens = None @@ -311,32 +347,52 @@ def execute(self, prompt, stream, response, conversation=None): messages.append({"role": "user", "content": prompt.prompt}) response._prompt_json = {"messages": messages} kwargs = self.build_kwargs(prompt) + if self.tools: + kwargs["tools"] = [tool.schema for tool in self.tools.values()] client = self.get_client() - if stream: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=True, - **kwargs, + + handler = ChatCompletionHandler( + client, self.model_name or self.model_id, stream + ) + yield from handler.run(messages, **kwargs) + + response_json = remove_dict_none_values(handler.completion.model_dump()) + tool_calls = handler.completion.choices[0].message.tool_calls + + if not tool_calls: + response.response_json = response_json + return + + response.response_json = [response_json] + while tool_calls: + messages.append( + { + "role": "assistant", + "tool_calls": [tc.model_dump() for tc in tool_calls], + } ) - chunks = [] - for chunk in completion: - chunks.append(chunk) - content = chunk.choices[0].delta.content - if content is not None: - yield content - response.response_json = remove_dict_none_values(combine_chunks(chunks)) - else: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=False, - **kwargs, + for tool_call in tool_calls: + tool = self.tools.get(tool_call.function.name) + if not tool: + tool_response = format_error( + f"Attempt to call non-existent function '{tool_call.function.name}'" + ) + else: + tool_response = tool(tool_call.function.arguments) + messages.append( 
+ { + "role": "tool", + "content": tool_response, + "tool_call_id": tool_call.id, + } + ) + yield from handler.run(messages, **kwargs) + response.response_json.append( + remove_dict_none_values(handler.completion.model_dump()) ) - response.response_json = remove_dict_none_values(completion.model_dump()) - yield completion.choices[0].message.content + tool_calls = handler.completion.choices[0].message.tool_calls - def get_client(self): + def get_client(self) -> openai.OpenAI: kwargs = {} if self.api_base: kwargs["base_url"] = self.api_base @@ -369,6 +425,8 @@ def build_kwargs(self, prompt): class Completion(Chat): + supports_tool_calling = False + class Options(SharedOptions): logprobs: Optional[int] = Field( description="Include the log probabilities of most likely N per token", diff --git a/llm/hookspecs.py b/llm/hookspecs.py index e7f806be..ee6c757b 100644 --- a/llm/hookspecs.py +++ b/llm/hookspecs.py @@ -18,3 +18,8 @@ def register_models(register): @hookspec def register_embedding_models(register): "Register additional model instances that can be used for embedding" + + +@hookspec +def register_tools(register): + "Register tool functions that the LLM model can invoke" diff --git a/llm/models.py b/llm/models.py index 0e47bb60..491ba1cc 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import datetime from .errors import NeedsKeyException +from functools import cached_property from itertools import islice import re import time @@ -9,6 +10,7 @@ import json from pydantic import BaseModel from ulid import ULID +import llm CONVERSATION_NAME_LENGTH = 32 @@ -246,10 +248,17 @@ class Model(ABC, _get_key_mixin): needs_key: Optional[str] = None key_env_var: Optional[str] = None can_stream: bool = False + supports_tool_calling: bool = False class Options(_Options): pass + @cached_property + def tools(self) -> Dict[str, llm.Tool]: + if self.supports_tool_calling: + return llm.get_tools() + return {} + def conversation(self): return Conversation(model=self) diff --git a/llm/plugins.py b/llm/plugins.py index 933725c7..223b383f 100644 --- a/llm/plugins.py +++ b/llm/plugins.py @@ -5,7 +5,10 @@ import sys from . 
import hookspecs
 
-DEFAULT_PLUGINS = ("llm.default_plugins.openai_models",)
+DEFAULT_PLUGINS = (
+    "llm.default_plugins.openai_models",
+    "llm.default_plugins.file_tools",
+)
 
 pm = pluggy.PluginManager("llm")
 pm.add_hookspecs(hookspecs)
diff --git a/llm/tool.py b/llm/tool.py
new file mode 100644
index 00000000..5e154a0b
--- /dev/null
+++ b/llm/tool.py
@@ -0,0 +1,139 @@
+import enum
+import inspect
+import json
+from typing import Any, Annotated, Union, Optional, get_origin, get_args
+from collections.abc import Callable
+import click
+import llm
+
+
+# __origin__ could be a types.UnionType instance (for optional parameters that have a None default) or a class
+TYPEMAP = {
+    int: "integer",
+    Union[int, None]: "integer",
+    float: "number",
+    Union[float, None]: "number",
+    list: "array",
+    Union[list, None]: "array",
+    bool: "boolean",
+    Union[bool, None]: "boolean",
+    str: "string",
+    Union[str, None]: "string",
+    type(None): "null",  # types.NoneType is Python 3.10+
+}
+
+
+def convert_parameter(param: inspect.Parameter) -> dict[str, Any]:
+    """Convert a function parameter to a JSON schema parameter."""
+    annotation = param.annotation
+    # This will return Annotated, or None for inspect.Parameter.empty or other types
+    unsubscripted_type = get_origin(annotation)
+    if not (
+        unsubscripted_type is Annotated
+        and len(annotation.__metadata__) == 1
+        and isinstance(annotation.__metadata__[0], str)
+    ):
+        raise ValueError(
+            "Function parameters must be annotated with typing.Annotated[<type>, 'description']"
+        )
+
+    schema: dict[str, Any] = {
+        "description": annotation.__metadata__[0],
+    }
+
+    origin = annotation.__origin__
+    type_ = TYPEMAP.get(get_origin(origin)) or TYPEMAP.get(origin)
+    if type_:
+        schema["type"] = type_
+        if type_ == "array":
+            args = get_args(origin)
+            if args:
+                if len(args) == 1 and (arg := TYPEMAP.get(args[0])):
+                    schema["items"] = {"type": arg}
+                else:
+                    raise TypeError(f"Annotated parameter type {origin} not supported")
+    elif issubclass(origin, enum.Enum):
+        # str values only for now, e.g.
enum.StrEnum + schema["type"] = "string" + schema["enum"] = [m.value for m in origin if isinstance(m.value, str)] + else: + raise TypeError(f"Annotated parameter type {origin} not supported") + + return schema + + +def format_exception(e: Exception) -> str: + return json.dumps({"is_error": True, "exception": repr(e)}) + + +def format_error(message: str) -> str: + return json.dumps({"is_error": True, "error": message}) + + +class Tool: + name: str + schema: dict[str, Any] + function: Callable[..., str] + + def __init__( + self, + function: Callable[..., str], + parameters_schema: Optional[dict[str, Any]] = None, + description: Optional[str] = None, + name: Optional[str] = None, + ): + self.function = function + self.name = name or function.__name__ + self.description = description or function.__doc__ + if not self.description: + raise ValueError( + "Tool functions must have a docstring or provide a description" + ) + self.schema = { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + }, + } + signature = inspect.signature(self.function) + if ( + signature.return_annotation is not inspect.Parameter.empty + and signature.return_annotation is not str + ): + raise ValueError("Tool functions must return a string") + if not parameters_schema and signature.parameters: + parameters_schema = self.introspect_function(signature) + if parameters_schema: + self.schema["function"]["parameters"] = parameters_schema + + def __call__(self, json_parameters: str) -> str: + try: + args = json.loads(json_parameters) + params = ", ".join(f"{k}={v}" for k, v in args.items()) + click.secho( + f"Tool: {self.name}({params})", + err=True, + italic=True, + dim=True, + ) + return self.function(**args) + except llm.ModelError: + raise + except Exception as e: + return format_exception(e) + + def introspect_function(self, signature: inspect.Signature) -> dict[str, Any]: + return { + "type": "object", + "properties": { + name: convert_parameter(param) + for name, param in signature.parameters.items() + }, + "required": [ + name + for name, param in signature.parameters.items() + if param.default is inspect.Parameter.empty + ], + "additionalProperties": False, + } diff --git a/setup.py b/setup.py index 1f6adcd7..e6edc77b 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_long_description(): """, install_requires=[ "click", - "openai>=1.0", + "openai>=1.40.0", "click-default-group>=1.2.3", "sqlite-utils>=3.37", "sqlite-migrate>=0.1a2", @@ -53,15 +53,14 @@ def get_long_description(): "test": [ "pytest", "numpy", - "pytest-httpx", + "pytest-httpx==0.30.0", # XXX hold back until https://github.com/simonw/llm/pull/580 "cogapp", "mypy>=1.10.0", "black>=24.1.0", "ruff", - "types-click", "types-PyYAML", "types-setuptools", ] }, - python_requires=">=3.8", + python_requires=">=3.9", ) diff --git a/tests/fixtures/stream_tool_call.txt b/tests/fixtures/stream_tool_call.txt new file mode 100644 index 00000000..c6b60f0d --- /dev/null +++ b/tests/fixtures/stream_tool_call.txt @@ -0,0 +1,11 @@ +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_eHh2kRnDdLBSQJ0HaSA4A4uF","type":"function","function":{"name":"read_files","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} +data: 
{"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"fil"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"enames"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":[\""}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"LICENSE"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":".txt"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"]"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNrkc8I43YTP6kbGNFCr9FtiBsJ","object":"chat.completion.chunk","created":1728501095,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_e2bde53e6e","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} +data: [DONE] \ No newline at end of file diff --git a/tests/fixtures/stream_tool_call_result.txt b/tests/fixtures/stream_tool_call_result.txt new file mode 100644 index 00000000..a0d125c5 --- /dev/null +++ b/tests/fixtures/stream_tool_call_result.txt @@ -0,0 +1,50 @@ +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]} +data: 
{"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" file"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" LICENSE"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":".txt"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" states"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" that"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" software"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" distributed"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" under"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" this"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" License"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" provided"},"logprobs":null,"finish_reason":null}]} +data: 
{"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" \""},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"AS"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" IS"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":",\""},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" without"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" any"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" warranties"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" conditions"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" either"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" express"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}]} +data: 
{"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" implied"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" emphasizes"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" that"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" user"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" should"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" refer"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" License"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" for"},"logprobs":null,"finish_reason":null}]} +data: 
{"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" specific"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" permissions"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" and"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" limitations"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" regarding"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" software"},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]} +data: {"id":"chatcmpl-AGWNspufiXQn3iYXzKEo45Ync4IUz","object":"chat.completion.chunk","created":1728501096,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} +data: [DONE] \ No newline at end of file diff --git a/tests/test_llm.py b/tests/test_llm.py index c303061d..db504e8b 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -514,7 +514,7 @@ def test_openai_localai_configuration(mocked_localai, user_path): EXPECTED_OPTIONS = """ -OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt) +OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt) (supports tool calling) temperature: float What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while lower values like 0.2 will diff --git a/tests/test_tool.py b/tests/test_tool.py new file mode 100644 index 00000000..63744147 --- /dev/null +++ b/tests/test_tool.py @@ -0,0 +1,392 @@ +import enum +from typing import Annotated, Union +import json +import os +import functools + +import pytest +from pytest_httpx import IteratorStream +from click.testing import CliRunner + +from llm.tool import Tool +from llm.cli import cli +from llm.default_plugins import file_tools + + +def test_no_parameters(): + @Tool + def tool() -> str: + "tool description" + return "output" + + assert tool.schema == { + "type": "function", + "function": {"name": "tool", "description": "tool description"}, + } + assert tool("{}") == "output" + + +def test_missing_description(): + with pytest.raises(ValueError, match=" description"): + + @Tool + def tool() -> str: + return "output" + + +def test_invalid_return(): + with pytest.raises(ValueError, match=" return"): + + @Tool + def tool() -> int: + "tool description" + + +def test_missing_annotated(): + with pytest.raises(ValueError, match=" annotated"): + + @Tool + def tool(a: int) -> str: + "tool description" + + +def test_missing_annotated_description(): + with pytest.raises(TypeError, match=" at least two arguments"): + + @Tool + def tool(a: Annotated[int]) -> str: + "tool description" + + +def test_unsupported_parameters(): + with pytest.raises(TypeError, match=" parameter type"): + + @Tool + def tool(a: Annotated[object, "a desc"]) -> str: + "tool description" + + +def test_call(): + @Tool + def tool(a: Annotated[int, "a desc"]) -> str: + "tool description" + return "output" + + assert tool(json.dumps({"a": 1})) == "output" + + assert "exception" in tool("{}") + assert "exception" in tool(json.dumps({"a": 1, "b": 2})) + + +def test_annotated_parameters(): + @Tool + def tool( + a: Annotated[bool, "a desc"], + b: Annotated[int, "b desc"] = 1, + c: Annotated[Union[str, None], "c desc"] = "2", + ) -> str: + "tool description" + return "output" + + assert tool.schema == { + "type": "function", + "function": { + "name": "tool", + "description": "tool description", + "parameters": { + "type": "object", + "properties": { + "a": {"description": "a desc", "type": "boolean"}, + "b": {"description": "b desc", "type": "integer"}, + "c": {"description": "c desc", "type": "string"}, + }, + "required": ["a"], + "additionalProperties": False, + }, + }, + } + assert tool(json.dumps({"a": True})) == "output" + + +def test_enum_parameters(): + class MyEnum(enum.Enum): + A = "a" + B = "b" + + @Tool + def tool( + a: Annotated[MyEnum, "a enum desc"], + b: Annotated[int, "b desc"] = 1, + ) -> str: + "tool description" + return "output" + + assert tool.schema == { + "type": "function", + "function": { + "name": "tool", + "description": "tool description", + "parameters": { + "type": "object", + "properties": { + "a": { + "description": "a enum desc", + "type": "string", + "enum": ["a", "b"], + }, + "b": {"description": "b desc", "type": "integer"}, + }, + "required": ["a"], + "additionalProperties": False, + }, + }, + } + assert tool(json.dumps({"a": MyEnum.A.value})) == "output" + + +def test_list_parameters(): + @Tool + def tool( + a: Annotated[list, "a enum desc"], + b: Annotated[list[int], "b desc"], + c: Annotated[list[str], "c desc"], + ) -> str: + "tool description" + return "output" + + assert tool.schema == { + "type": "function", + "function": { + "name": "tool", + "description": "tool description", + "parameters": { + 
"type": "object", + "properties": { + "a": {"description": "a enum desc", "type": "array"}, + "b": { + "description": "b desc", + "type": "array", + "items": {"type": "integer"}, + }, + "c": { + "description": "c desc", + "type": "array", + "items": {"type": "string"}, + }, + }, + "required": ["a", "b", "c"], + "additionalProperties": False, + }, + }, + } + assert tool(json.dumps({"a": [], "b": [1], "c": ["s"]})) == "output" + + +def test_unsupported_list_parameters(): + with pytest.raises(TypeError, match=" parameter type"): + + @Tool + def tool( + a: Annotated[list[Union[str, int]], "a enum desc"], + ) -> str: + "tool description" + return "output" + + +def test_object_tool(): + class MyTool: + "tool description" + + __name__ = "tool" + + def __call__( + self, + a: Annotated[bool, "a desc"], + b: Annotated[int, "b desc"] = 1, + ) -> str: + return "output" + + tool = Tool(MyTool()) + + assert tool.schema == { + "type": "function", + "function": { + "name": "tool", + "description": "tool description", + "parameters": { + "type": "object", + "properties": { + "a": {"description": "a desc", "type": "boolean"}, + "b": {"description": "b desc", "type": "integer"}, + }, + "required": ["a"], + "additionalProperties": False, + }, + }, + } + assert tool(json.dumps({"a": True, "b": 3})) == "output" + + +def stream_tool_call(datafile): + with open(datafile) as f: + for line in f: + yield f"{line}\n\n".encode("utf-8") + + +@pytest.fixture +def read_files_mock(monkeypatch): + def mock_read_files(filenames): + return "some license text" + + monkeypatch.setattr( + file_tools, + "read_files", + functools.update_wrapper(mock_read_files, file_tools.read_files), + ) + + +def test_tool_completion_stream(httpx_mock, read_files_mock, logs_db): + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + stream=IteratorStream( + stream_tool_call( + os.path.join(os.path.dirname(__file__), "fixtures/stream_tool_call.txt") + ) + ), + headers={"Content-Type": "text/event-stream"}, + ) + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + stream=IteratorStream( + stream_tool_call( + os.path.join( + os.path.dirname(__file__), "fixtures/stream_tool_call_result.txt" + ) + ) + ), + headers={"Content-Type": "text/event-stream"}, + ) + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + cli, + [ + "--enable-tools", + "-m", + "4o-mini", + "--key", + "x", + "Summarize this file LICENSE.txt", + ], + ) + assert result.exit_code == 0 + assert result.output == ( + 'The file LICENSE.txt states that software distributed under this License is provided "AS IS," ' + "without any warranties or conditions, either express or implied. 
" + "It emphasizes that the user should refer to the License for specific permissions and " + "limitations regarding the software.\n" + ) + rows = list(logs_db["responses"].rows_where(select="response_json")) + assert ( + len(json.loads(rows[0]["response_json"])) == 2 + ) # two response_jsons for tools + + +def test_tool_completion_nostream(httpx_mock, read_files_mock, logs_db): + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + json={ + "id": "chatcmpl-AGWNZTKcKeVOqSmRraGyzeEnOzs4O", + "object": "chat.completion", + "created": 1728501077, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_SSoLPi8JuIZ1WDygNI5CSCkx", + "type": "function", + "function": { + "name": "read_files", + "arguments": '{"filenames":["LICENSE.txt"]}', + }, + } + ], + "refusal": None, + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + "usage": { + "prompt_tokens": 74, + "completion_tokens": 17, + "total_tokens": 91, + "prompt_tokens_details": {"cached_tokens": 0}, + "completion_tokens_details": {"reasoning_tokens": 0}, + }, + "system_fingerprint": "fp_74ba47b4ac", + }, + headers={"Content-Type": "application/json"}, + ) + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + json={ + "id": "chatcmpl-AGWNa4MUDJ7q6pm2KZqutUqPWlQnX", + "object": "chat.completion", + "created": 1728501078, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": 'The LICENSE.txt file states that the software is distributed "AS IS," without any warranties or conditions, either express or implied. It advises the reader to refer to the License for specific terms regarding permissions and limitations.', + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 174, + "completion_tokens": 43, + "total_tokens": 217, + "prompt_tokens_details": {"cached_tokens": 0}, + "completion_tokens_details": {"reasoning_tokens": 0}, + }, + "system_fingerprint": "fp_f85bea6784", + }, + headers={"Content-Type": "application/json"}, + ) + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + cli, + [ + "--no-stream", + "--enable-tools", + "-m", + "4o-mini", + "--key", + "x", + "Summarize this file LICENSE.txt", + ], + ) + assert result.exit_code == 0 + assert result.output == ( + 'The LICENSE.txt file states that the software is distributed "AS IS," ' + "without any warranties or conditions, either express or implied. " + "It advises the reader to refer to the License for specific terms regarding " + "permissions and limitations.\n" + ) + rows = list(logs_db["responses"].rows_where(select="response_json")) + assert ( + len(json.loads(rows[0]["response_json"])) == 2 + ) # two response_jsons for tools