Skip to content

Commit

Permalink
First working prototype of new attachments feature, refs #587
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Oct 27, 2024
1 parent d654c95 commit a1ee8ac
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 14 deletions.
1 change: 1 addition & 0 deletions llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NeedsKeyException,
)
from .models import (
Attachment,
Conversation,
Model,
ModelWithAliases,
Expand Down
84 changes: 83 additions & 1 deletion llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
import json
from llm import (
Attachment,
Collection,
Conversation,
Response,
Expand All @@ -30,7 +31,9 @@
from .migrations import migrate
from .plugins import pm
import base64
import httpx
import pathlib
import puremagic
import pydantic
import readline
from runpy import run_module
Expand All @@ -48,6 +51,54 @@
DEFAULT_TEMPLATE = "prompt: "


class AttachmentType(click.ParamType):
    """Click parameter type that turns an ``-a/--attachment`` value into an Attachment.

    Three kinds of value are accepted:

    * ``-``: the attachment bytes are read from standard input
    * anything containing ``://``: treated as a URL and checked with a HEAD request
    * anything else: treated as a path to a file on disk

    In every case the mimetype is detected automatically; a usage error is
    reported if the source is missing or its type cannot be determined.
    """

    name = "attachment"

    def convert(self, value, param, ctx):
        """Convert the raw command-line ``value`` into an ``Attachment``.

        Failures are reported through ``self.fail`` so that click renders
        them as a normal "Invalid value for ..." usage error.
        """
        if value == "-":
            content = sys.stdin.buffer.read()
            # Try to guess the type from the bytes themselves
            try:
                mimetype = puremagic.from_string(content, mime=True)
            except (puremagic.PureError, ValueError):
                # PureError: no signature matched. ValueError: input was empty.
                self.fail("Could not determine mimetype of stdin", param, ctx)
            return Attachment(mimetype, None, None, content)

        if "://" in value:
            # Confirm the URL exists and use the server-reported type
            try:
                response = httpx.head(value)
                response.raise_for_status()
            except httpx.HTTPError as ex:
                self.fail(str(ex), param, ctx)
            mimetype = response.headers.get("content-type")
            if mimetype:
                # Drop parameters such as "; charset=binary" so that the value
                # can be compared against a plain set of supported mimetypes
                mimetype = mimetype.split(";", 1)[0].strip()
            return Attachment(mimetype, None, value, None)

        # Otherwise the value is a path on disk
        path = pathlib.Path(value)
        if not path.exists():
            self.fail(f"File {value} does not exist", param, ctx)
        if path.is_dir():
            self.fail(f"{value} is a directory, not a file", param, ctx)
        try:
            mimetype = puremagic.from_file(str(path), mime=True)
        except (puremagic.PureError, ValueError):
            # Same failure modes as the stdin branch: unknown or empty file
            self.fail(f"Could not determine mimetype of {value}", param, ctx)
        return Attachment(mimetype, str(path), None, None)


def attachment_types_callback(ctx, param, values):
    """Click callback for ``--at/--attachment-type VALUE MIMETYPE`` pairs.

    Each ``(value, mimetype)`` pair becomes an ``Attachment`` carrying the
    mimetype the user supplied, so no type detection is performed here.
    ``value`` may be a URL (contains ``://``), ``-`` for standard input, or a
    path to a file on disk.

    Returns the list of attachments in the order they were given. Raises
    ``click.BadParameter`` when a path does not exist or is a directory.
    """
    collected = []
    for value, mimetype in values:
        if "://" in value:
            attachment = Attachment(mimetype, None, value, None)
        elif value == "-":
            content = sys.stdin.buffer.read()
            attachment = Attachment(mimetype, None, None, content)
        else:
            # Look for file
            path = pathlib.Path(value)
            if not path.exists():
                raise click.BadParameter(f"File {value} does not exist")
            if path.is_dir():
                # Reject directories up front rather than failing later with
                # IsADirectoryError when the attachment is read
                raise click.BadParameter(f"{value} is a directory, not a file")
            attachment = Attachment(mimetype, str(path), None, None)
        collected.append(attachment)
    return collected


def _validate_metadata_json(ctx, param, value):
if value is None:
return value
Expand Down Expand Up @@ -88,6 +139,23 @@ def cli():
@click.argument("prompt", required=False)
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"attachments",
"-a",
"--attachment",
type=AttachmentType(),
multiple=True,
help="Attachment path or URL or -",
)
@click.option(
"attachment_types",
"--at",
"--attachment-type",
type=(str, str),
multiple=True,
callback=attachment_types_callback,
help="Attachment with explicit mimetype",
)
@click.option(
"options",
"-o",
Expand Down Expand Up @@ -127,6 +195,8 @@ def prompt(
prompt,
system,
model_id,
attachments,
attachment_types,
options,
template,
param,
Expand All @@ -142,6 +212,14 @@ def prompt(
Execute a prompt
Documentation: https://llm.datasette.io/en/stable/usage.html
Examples:
\b
llm 'Capital of France?'
llm 'Capital of France?' -m gpt-4o
llm 'Capital of France?' -s 'answer in Spanish'
llm 'Extract text from this image' -a image.jpg
"""
if log and no_log:
raise click.ClickException("--log and --no-log are mutually exclusive")
Expand Down Expand Up @@ -262,6 +340,8 @@ def read_prompt():
except pydantic.ValidationError as ex:
raise click.ClickException(render_errors(ex.errors()))

resolved_attachments = [*attachments, *attachment_types]

should_stream = model.can_stream and not no_stream
if not should_stream:
validated_options["stream"] = False
Expand All @@ -273,7 +353,9 @@ def read_prompt():
prompt_method = conversation.prompt

try:
response = prompt_method(prompt, system, **validated_options)
response = prompt_method(
prompt, *resolved_attachments, system=system, **validated_options
)
if should_stream:
for chunk in response:
print(chunk, end="")
Expand Down
27 changes: 24 additions & 3 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def register_models(register):
register(Chat("gpt-4-turbo-2024-04-09"))
register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
# GPT-4o
register(Chat("gpt-4o"), aliases=("4o",))
register(Chat("gpt-4o-mini"), aliases=("4o-mini",))
register(Chat("gpt-4o", vision=True), aliases=("4o",))
register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
# o1
register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
Expand Down Expand Up @@ -271,6 +271,7 @@ def __init__(
api_engine=None,
headers=None,
can_stream=True,
vision=False,
allows_system_prompt=True,
):
self.model_id = model_id
Expand All @@ -282,8 +283,17 @@ def __init__(
self.api_engine = api_engine
self.headers = headers
self.can_stream = can_stream
self.vision = vision
self.allows_system_prompt = allows_system_prompt

if vision:
self.attachment_types = {
"image/png",
"image/jpeg",
"image/webp",
"image/gif",
}

def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)

Expand All @@ -308,7 +318,18 @@ def execute(self, prompt, stream, response, conversation=None):
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
if not prompt.attachments:
messages.append({"role": "user", "content": prompt.prompt})
else:
vision_message = [{"type": "text", "text": prompt.prompt}]
for attachment in prompt.attachments:
url = attachment.url
if not url:
base64_image = attachment.base64_content()
url = f"data:{attachment.resolve_type()};base64,{base64_image}"
vision_message.append({"type": "image_url", "image_url": {"url": url}})
messages.append({"role": "user", "content": vision_message})

response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()
Expand Down
94 changes: 84 additions & 10 deletions llm/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import base64
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import httpx
from itertools import islice
import puremagic
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
Expand All @@ -13,17 +16,52 @@
CONVERSATION_NAME_LENGTH = 32


@dataclass
class Attachment:
    """A file-like input to a prompt, identified by path, URL or raw bytes.

    All fields are optional; methods use whichever of ``path``, ``url`` or
    ``content`` is set to find the data.
    """

    # Mimetype, e.g. "image/png"; detected on demand when left as None
    type: Optional[str] = None
    # Location of the attachment on the local filesystem
    path: Optional[str] = None
    # Remote location of the attachment
    url: Optional[str] = None
    # Raw bytes of the attachment
    content: Optional[bytes] = None

    def resolve_type(self):
        """Return the attachment's mimetype, detecting it if it was not set.

        Detection order is path (file signature), then url (``content-type``
        header from a HEAD request), then content (byte signature).

        Raises ``ValueError`` if there is nothing to derive the type from or
        the URL does not report one. Errors from puremagic or httpx are not
        caught here.
        """
        if self.type:
            return self.type
        # Derive it from path or url or content
        if self.path:
            return puremagic.from_file(self.path, mime=True)
        if self.url:
            # Ask the server for the type instead of downloading the body
            response = httpx.head(self.url)
            response.raise_for_status()
            content_type = response.headers.get("content-type")
            if not content_type:
                raise ValueError(
                    "Could not determine type of attachment at {}".format(self.url)
                )
            # Strip parameters such as "; charset=binary"
            return content_type.split(";", 1)[0].strip()
        if self.content:
            return puremagic.from_string(self.content, mime=True)
        raise ValueError("Attachment has no type and no content to derive it from")

    def base64_content(self):
        """Return the attachment's bytes as a base64-encoded ``str``.

        Uses ``content`` when it is set and non-empty, otherwise reads the
        file at ``path``, otherwise downloads ``url``.

        Raises ``ValueError`` if none of content, path or url is set.
        """
        content = self.content
        if not content:
            if self.path:
                # Context manager so the file handle is closed promptly
                with open(self.path, "rb") as fp:
                    content = fp.read()
            elif self.url:
                response = httpx.get(self.url)
                response.raise_for_status()
                content = response.content
        if content is None:
            raise ValueError("Attachment has no content, path or url to read from")
        return base64.b64encode(content).decode("utf-8")


@dataclass
class Prompt:
prompt: str
model: "Model"
system: Optional[str]
prompt_json: Optional[str]
options: "Options"
attachments: Optional[List[Attachment]] = field(default_factory=list)
system: Optional[str] = None
prompt_json: Optional[str] = None
options: "Options" = field(default_factory=dict)

def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
def __init__(
self, prompt, model, attachments, system=None, prompt_json=None, options=None
):
self.prompt = prompt
self.model = model
self.attachments = list(attachments)
self.system = system
self.prompt_json = prompt_json
self.options = options or {}
Expand All @@ -39,15 +77,17 @@ class Conversation:
def prompt(
self,
prompt: Optional[str],
*attachments: Attachment,
system: Optional[str] = None,
stream: bool = True,
**options
):
return Response(
Prompt(
prompt,
system=system,
model=self.model,
attachments=attachments,
system=system,
options=self.model.Options(**options),
),
self.model,
Expand Down Expand Up @@ -158,14 +198,22 @@ def log_to_db(self, db):
db["responses"].insert(response)

@classmethod
def fake(cls, model: "Model", prompt: str, system: str, response: str):
def fake(
cls,
model: "Model",
prompt: str,
*attachments: List[Attachment],
system: str,
response: str
):
"Utility method to help with writing tests"
response_obj = cls(
model=model,
prompt=Prompt(
prompt,
system=system,
model=model,
attachments=attachments,
system=system,
),
stream=False,
)
Expand All @@ -183,8 +231,9 @@ def from_row(cls, row):
model=model,
prompt=Prompt(
prompt=row["prompt"],
system=row["system"],
model=model,
attachments=[],
system=row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
stream=False,
Expand Down Expand Up @@ -242,10 +291,15 @@ def get_key(self):

class Model(ABC, _get_key_mixin):
model_id: str

# API key handling
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None

# Model characteristics
can_stream: bool = False
attachment_types = set()

class Options(_Options):
pass
Expand All @@ -269,13 +323,33 @@ def execute(

def prompt(
self,
prompt: Optional[str],
prompt: str,
*attachments: Attachment,
system: Optional[str] = None,
stream: bool = True,
**options
):
# Validate attachments
if attachments and not self.attachment_types:
raise ValueError(
"This model does not support attachments, but some were provided"
)
for attachment in attachments:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
"This model does not support attachments of type '{}', only {}".format(
attachment_type, ", ".join(self.attachment_types)
)
)
return self.response(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
Prompt(
prompt,
attachments=attachments,
system=system,
model=self,
options=self.Options(**options),
),
stream=stream,
)

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get_long_description():
"setuptools",
"pip",
"pyreadline3; sys_platform == 'win32'",
"puremagic",
],
extras_require={
"test": [
Expand Down

0 comments on commit a1ee8ac

Please sign in to comment.