Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image inputs (very much a work in progress) #492

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
set_alias,
remove_alias,
)

from .migrations import migrate
from .plugins import pm
import base64
import pathlib
import pydantic
import re
import readline
from runpy import run_module
import shutil
Expand All @@ -36,6 +36,7 @@
import sys
import textwrap
from typing import cast, Optional, Iterable, Union, Tuple
import urllib
import warnings
import yaml

Expand Down Expand Up @@ -83,10 +84,50 @@ def cli():
"""


class ImageOrUrl:
    """
    Holds an image source that is either an already-open file-like
    object or a URL whose content is fetched lazily on first use.
    """

    def __init__(self, fp=None, url=None):
        # Exactly one of fp / url is expected; fp takes precedence
        # if both are provided.
        self._fp = fp
        self.url = url

    def fp(self):
        """Return a file-like object yielding the image bytes.

        For URL-backed images the HTTP request only happens when this
        method is called, so constructing an ImageOrUrl is always cheap.
        """
        if self._fp:
            return self._fp
        # Import the submodule explicitly: a bare ``import urllib`` at the
        # top of the file does NOT guarantee ``urllib.request`` exists as
        # an attribute, which would make ``urllib.request.urlopen`` raise
        # AttributeError at call time.
        from urllib.request import urlopen

        return urlopen(self.url)


class ImageOption(click.ParamType):
    """
    Images can be specified as a filepath or a
    URL or a '-' to read from standard input
    """

    name = "image"

    def convert(self, value, param, ctx):
        """Convert a raw CLI value into an :class:`ImageOrUrl`.

        Accepts ``-`` (standard input), an ``http(s)://`` URL, or a path
        to an existing file; anything else triggers ``self.fail`` which
        raises a click usage error.
        """
        if value == "-":
            # Image content is binary: read from the underlying byte
            # buffer, not the text-mode sys.stdin, which would corrupt
            # the data before it is base64-encoded downstream.
            return ImageOrUrl(sys.stdin.buffer)
        if re.match(r"^https?://", value):
            return ImageOrUrl(url=value)
        # Use pathlib to detect if it is a readable file
        # (is_file() already implies existence).
        path = pathlib.Path(value)
        if path.is_file():
            return ImageOrUrl(path.open("rb"))
        self.fail(f"{value} is not a valid file path or URL", param, ctx)


@cli.command(name="prompt")
@click.argument("prompt", required=False)
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"images",
"-i",
"--image",
type=ImageOption(),
multiple=True,
help="Images for prompt",
)
@click.option(
"options",
"-o",
Expand All @@ -106,6 +147,7 @@ def cli():
@click.option("--no-stream", is_flag=True, help="Do not stream output")
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option("--log", is_flag=True, help="Log prompt and response to the database")
@click.option("--store", is_flag=True, help="Store image files in the database")
@click.option(
"_continue",
"-c",
Expand All @@ -126,12 +168,14 @@ def prompt(
prompt,
system,
model_id,
images,
options,
template,
param,
no_stream,
no_log,
log,
store,
_continue,
conversation_id,
key,
Expand Down Expand Up @@ -174,6 +218,7 @@ def read_prompt():
("--template", template),
("--continue", _continue),
("--cid", conversation_id),
("--image", images),
):
if var:
disallowed_options.append(option)
Expand Down Expand Up @@ -272,7 +317,7 @@ def read_prompt():
prompt_method = conversation.prompt

try:
response = prompt_method(prompt, system, **validated_options)
response = prompt_method(prompt, system, images=images, **validated_options)
if should_stream:
for chunk in response:
print(chunk, end="")
Expand Down
21 changes: 21 additions & 0 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from llm import EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
Expand Down Expand Up @@ -31,6 +32,7 @@ def register_models(register):
register(Chat("gpt-4-1106-preview"))
register(Chat("gpt-4-0125-preview"))
register(Chat("gpt-4-turbo-preview"), aliases=("gpt-4-turbo", "4-turbo", "4t"))
register(Chat("gpt-4-vision-preview", images=True), aliases=("4v",))
# The -instruct completion model
register(
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
Expand Down Expand Up @@ -264,6 +266,7 @@ def __init__(
api_version=None,
api_engine=None,
headers=None,
images=False,
):
self.model_id = model_id
self.key = key
Expand All @@ -273,6 +276,7 @@ def __init__(
self.api_version = api_version
self.api_engine = api_engine
self.headers = headers
self.supports_images = images

def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)
Expand All @@ -297,6 +301,23 @@ def execute(self, prompt, stream, response, conversation=None):
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
if prompt.images:
for image in prompt.images:
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,{}".format(
base64.b64encode(image.read()).decode("utf-8")
)
},
}
],
}
)
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()
Expand Down
12 changes: 12 additions & 0 deletions llm/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,15 @@ def m010_create_new_log_tables(db):
@migration
def m011_fts_for_responses(db):
db["responses"].enable_fts(["prompt", "response"], create_triggers=True)


@migration
def m012_images_table(db):
    """Create the images table used to record image inputs to prompts.

    Each row records how an image was supplied (url or filepath) and,
    when --store is used, the raw bytes plus an MD5 of the content so
    identical blobs are not stored twice.
    """
    db["images"].create(
        {
            "id": str,  # ulid
            "url": str,
            "filepath": str,
            "content": bytes,
            "content_md5": str,  # To avoid storing duplicate blobs
        },
        # id is the ULID identity for the row; without an explicit pk the
        # table would fall back to a bare rowid and allow duplicate ids.
        pk="id",
    )
    # Index supports the duplicate-detection lookup by content hash.
    db["images"].create_index(["content_md5"])
27 changes: 25 additions & 2 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,38 @@
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from abc import ABC, abstractmethod
import json
from pathlib import Path
from pydantic import BaseModel
from ulid import ULID

CONVERSATION_NAME_LENGTH = 32


@dataclass
class PromptImage:
    """An image attached to a prompt.

    At most one of the three fields is normally populated, depending on
    how the image was supplied; all default to None so callers only need
    to pass the one they have.
    """

    filepath: Optional[Path] = None
    url: Optional[str] = None
    # String annotation: the field name shadows the builtin ``bytes`` in
    # the class namespace, so the annotation must not be evaluated there.
    bytes: "Optional[bytes]" = None


@dataclass
class Prompt:
    """A single prompt: the user text plus the model it targets, an
    optional system prompt, model options and any attached images."""

    prompt: str
    model: "Model"
    system: Optional[str]
    prompt_json: Optional[str]
    options: "Options"
    images: Optional[List[PromptImage]]

    def __init__(
        self, prompt, model, system=None, images=None, prompt_json=None, options=None
    ):
        # Keyword order (images before prompt_json) is part of the public
        # signature and deliberately differs from the field declaration order.
        self.prompt = prompt
        self.model = model
        self.system = system
        self.images = images
        self.prompt_json = prompt_json
        # Any falsy options value (None included) collapses to an empty dict.
        self.options = {} if not options else options


@dataclass
Expand Down Expand Up @@ -246,6 +258,8 @@ class Model(ABC, _get_key_mixin):
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
supports_images: bool = False
supports_image_urls: bool = False

class Options(_Options):
pass
Expand All @@ -272,10 +286,19 @@ def prompt(
prompt: Optional[str],
system: Optional[str] = None,
stream: bool = True,
images: Optional[List[PromptImage]] = None,
**options
):
if images and not self.supports_images:
raise ValueError("This model does not support images")
return self.response(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
Prompt(
prompt,
system=system,
model=self,
images=images,
options=self.Options(**options),
),
stream=stream,
)

Expand Down
Loading