From eaf50d8c359f453a9414a9aa75bfdd303cfbfa09 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 4 Mar 2024 13:12:53 -0800 Subject: [PATCH 1/2] First -i/--image input prototype, refs #331 --- llm/cli.py | 24 ++++++++++++++++++++++-- llm/default_plugins/openai_models.py | 21 +++++++++++++++++++++ llm/models.py | 26 ++++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 7e72283a..3aeba7d6 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -22,12 +22,12 @@ set_alias, remove_alias, ) - from .migrations import migrate from .plugins import pm import base64 import pathlib import pydantic +import re import readline from runpy import run_module import shutil @@ -36,6 +36,7 @@ import sys import textwrap from typing import cast, Optional, Iterable, Union, Tuple +import urllib.request import warnings import yaml @@ -83,10 +84,28 @@ def cli(): """ +class FileOrUrl(click.ParamType): + name = "file_or_url" + + def convert(self, value, param, ctx): + if value == "-": + return sys.stdin + if re.match(r"^https?://", value): + return urllib.request.urlopen(value) + # Use pathlib to detect if it is a readable file + path = pathlib.Path(value) + if path.exists() and path.is_file(): + return path.open("rb") + self.fail(f"{value} is not a valid file path or URL", param, ctx) + + @cli.command(name="prompt") @click.argument("prompt", required=False) @click.option("-s", "--system", help="System prompt to use") @click.option("model_id", "-m", "--model", help="Model to use") +@click.option( + "images", "-i", "--image", type=FileOrUrl(), multiple=True, help="Images for prompt" +) @click.option( "options", "-o", @@ -126,6 +145,7 @@ def prompt( prompt, system, model_id, + images, options, template, param, @@ -272,7 +292,7 @@ def read_prompt(): prompt_method = conversation.prompt try: - response = prompt_method(prompt, system, **validated_options) + response = prompt_method(prompt, system, images=images, **validated_options) if
should_stream: for chunk in response: print(chunk, end="") diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 88958e94..f0d804b5 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,3 +1,4 @@ +import base64 from llm import EmbeddingModel, Model, hookimpl import llm from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client @@ -31,6 +32,7 @@ def register_models(register): register(Chat("gpt-4-1106-preview")) register(Chat("gpt-4-0125-preview")) register(Chat("gpt-4-turbo-preview"), aliases=("gpt-4-turbo", "4-turbo", "4t")) + register(Chat("gpt-4-vision-preview", images=True), aliases=("4v",)) # The -instruct completion model register( Completion("gpt-3.5-turbo-instruct", default_max_tokens=256), @@ -264,6 +266,7 @@ def __init__( api_version=None, api_engine=None, headers=None, + images=False, ): self.model_id = model_id self.key = key @@ -273,6 +276,7 @@ def __init__( self.api_version = api_version self.api_engine = api_engine self.headers = headers + self.supports_images = images def __str__(self): return "OpenAI Chat: {}".format(self.model_id) @@ -297,6 +301,23 @@ def execute(self, prompt, stream, response, conversation=None): if prompt.system and prompt.system != current_system: messages.append({"role": "system", "content": prompt.system}) messages.append({"role": "user", "content": prompt.prompt}) + if prompt.images: + for image in prompt.images: + messages.append( + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,{}".format( + base64.b64encode(image.read()).decode("utf-8") + ) + }, + } + ], + } + ) response._prompt_json = {"messages": messages} kwargs = self.build_kwargs(prompt) client = self.get_client() diff --git a/llm/models.py b/llm/models.py index e3e54b87..b4a4fa17 100644 --- a/llm/models.py +++ b/llm/models.py @@ -7,12 +7,20 @@ from typing import Any, Dict, Iterable, Iterator, 
List, Optional, Set, Union from abc import ABC, abstractmethod import json +from pathlib import Path from pydantic import BaseModel from ulid import ULID CONVERSATION_NAME_LENGTH = 32 +@dataclass +class PromptImage: + filepath: Optional[Path] + url: Optional[str] + bytes: Optional[bytes] + + @dataclass class Prompt: prompt: str @@ -20,13 +28,17 @@ class Prompt: system: Optional[str] prompt_json: Optional[str] options: "Options" + images: Optional[List[PromptImage]] - def __init__(self, prompt, model, system=None, prompt_json=None, options=None): + def __init__( + self, prompt, model, system=None, images=None, prompt_json=None, options=None + ): self.prompt = prompt self.model = model self.system = system self.prompt_json = prompt_json self.options = options or {} + self.images = images @dataclass @@ -246,6 +258,7 @@ class Model(ABC, _get_key_mixin): needs_key: Optional[str] = None key_env_var: Optional[str] = None can_stream: bool = False + supports_images: bool = False class Options(_Options): pass @@ -272,10 +285,19 @@ def prompt( prompt: Optional[str], system: Optional[str] = None, stream: bool = True, + images: Optional[List[PromptImage]] = None, **options ): + if images and not self.supports_images: + raise ValueError("This model does not support images") return self.response( - Prompt(prompt, system=system, model=self, options=self.Options(**options)), + Prompt( + prompt, + system=system, + model=self, + images=images, + options=self.Options(**options), + ), stream=stream, ) From 96bb174eeaef7f9cc45f3f3d7a801d2c951b15f0 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 13 May 2024 12:39:29 -0700 Subject: [PATCH 2/2] WIP image changes --- llm/cli.py | 37 +++++++++++++++++++++++++++++++------ llm/migrations.py | 12 ++++++++++++ llm/models.py | 1 + 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 3aeba7d6..9f182d77 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -84,18 +84,35 @@ def cli(): """ -class 
FileOrUrl(click.ParamType): - name = "file_or_url" +class ImageOrUrl: + def __init__(self, fp=None, url=None): + self._fp = fp + self.url = url + + def fp(self): + if self._fp: + return self._fp + else: + return urllib.request.urlopen(self.url) + + +class ImageOption(click.ParamType): + """ + Images can be specified as a filepath or a + URL or a '-' to read from standard input + """ + + name = "image" def convert(self, value, param, ctx): if value == "-": - return sys.stdin + return ImageOrUrl(sys.stdin.buffer) if re.match(r"^https?://", value): - return urllib.request.urlopen(value) + return ImageOrUrl(url=value) # Use pathlib to detect if it is a readable file path = pathlib.Path(value) if path.exists() and path.is_file(): - return path.open("rb") + return ImageOrUrl(path.open("rb")) self.fail(f"{value} is not a valid file path or URL", param, ctx) @@ -104,7 +121,12 @@ def convert(self, value, param, ctx): @click.option("-s", "--system", help="System prompt to use") @click.option("model_id", "-m", "--model", help="Model to use") @click.option( - "images", "-i", "--image", type=FileOrUrl(), multiple=True, help="Images for prompt" + "images", + "-i", + "--image", + type=ImageOption(), + multiple=True, + help="Images for prompt", ) @click.option( "options", @@ -125,6 +147,7 @@ def convert(self, value, param, ctx): @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("-n", "--no-log", is_flag=True, help="Don't log to database") @click.option("--log", is_flag=True, help="Log prompt and response to the database") +@click.option("--store", is_flag=True, help="Store image files in the database") @click.option( "_continue", "-c", @@ -152,6 +175,7 @@ def prompt( no_stream, no_log, log, + store, _continue, conversation_id, key, @@ -194,6 +218,7 @@ def read_prompt(): ("--template", template), ("--continue", _continue), ("--cid", conversation_id), + ("--image", images), ): if var: disallowed_options.append(option) diff --git a/llm/migrations.py
b/llm/migrations.py index 008ae976..c2ddaa5b 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -201,3 +201,15 @@ def m010_create_new_log_tables(db): @migration def m011_fts_for_responses(db): db["responses"].enable_fts(["prompt", "response"], create_triggers=True) + + +@migration +def m012_images_table(db): + db["images"].create({ + "id": str, # ulid + "url": str, + "filepath": str, + "content": bytes, + "content_md5": str, # To avoid storing duplicate blobs + }) + db["images"].create_index(["content_md5"]) diff --git a/llm/models.py b/llm/models.py index b4a4fa17..4aecea85 100644 --- a/llm/models.py +++ b/llm/models.py @@ -259,6 +259,7 @@ class Model(ABC, _get_key_mixin): key_env_var: Optional[str] = None can_stream: bool = False supports_images: bool = False + supports_image_urls: bool = False class Options(_Options): pass