From a1ee8ac170d50dab03b9fdf0f124c6d1fb779753 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sat, 26 Oct 2024 17:40:23 -0700
Subject: [PATCH] First working prototype of new attachments feature, refs #587

---
 llm/__init__.py                      |  1 +
 llm/cli.py                           | 84 ++++++++++++++++++++++++-
 llm/default_plugins/openai_models.py | 27 +++++++-
 llm/models.py                        | 94 +++++++++++++++++++++++++---
 setup.py                             |  1 +
 5 files changed, 193 insertions(+), 14 deletions(-)

diff --git a/llm/__init__.py b/llm/__init__.py
index 9e8afacb..f76e2728 100644
--- a/llm/__init__.py
+++ b/llm/__init__.py
@@ -4,6 +4,7 @@
     NeedsKeyException,
 )
 from .models import (
+    Attachment,
     Conversation,
     Model,
     ModelWithAliases,
diff --git a/llm/cli.py b/llm/cli.py
index a1b14576..33e14f09 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -4,6 +4,7 @@
 import io
 import json
 from llm import (
+    Attachment,
     Collection,
     Conversation,
     Response,
@@ -30,7 +31,9 @@
 from .migrations import migrate
 from .plugins import pm
 import base64
+import httpx
 import pathlib
+import puremagic
 import pydantic
 import readline
 from runpy import run_module
@@ -48,6 +51,54 @@
 DEFAULT_TEMPLATE = "prompt: "


+class AttachmentType(click.ParamType):
+    name = "attachment"
+
+    def convert(self, value, param, ctx):
+        if value == "-":
+            content = sys.stdin.buffer.read()
+            # Try to guess type
+            try:
+                mimetype = puremagic.from_string(content, mime=True)
+            except puremagic.PureError:
+                raise click.BadParameter("Could not determine mimetype of stdin")
+            return Attachment(mimetype, None, None, content)
+        if "://" in value:
+            # Confirm URL exists and try to guess type
+            try:
+                response = httpx.head(value)
+                response.raise_for_status()
+                mimetype = response.headers.get("content-type")
+            except httpx.HTTPError as ex:
+                raise click.BadParameter(str(ex))
+            return Attachment(mimetype, None, value, None)
+        # Check that the file exists
+        path = pathlib.Path(value)
+        if not path.exists():
+            self.fail(f"File {value} does not exist", param, ctx)
+        # Try to guess type
+        mimetype = puremagic.from_file(str(path), mime=True)
+        return Attachment(mimetype, str(path), None, None)
+
+
+def attachment_types_callback(ctx, param, values):
+    collected = []
+    for value, mimetype in values:
+        if "://" in value:
+            attachment = Attachment(mimetype, None, value, None)
+        elif value == "-":
+            content = sys.stdin.buffer.read()
+            attachment = Attachment(mimetype, None, None, content)
+        else:
+            # Look for file
+            path = pathlib.Path(value)
+            if not path.exists():
+                raise click.BadParameter(f"File {value} does not exist")
+            attachment = Attachment(mimetype, str(path), None, None)
+        collected.append(attachment)
+    return collected
+
+
 def _validate_metadata_json(ctx, param, value):
     if value is None:
         return value
@@ -88,6 +139,23 @@ def cli():
 @click.argument("prompt", required=False)
 @click.option("-s", "--system", help="System prompt to use")
 @click.option("model_id", "-m", "--model", help="Model to use")
+@click.option(
+    "attachments",
+    "-a",
+    "--attachment",
+    type=AttachmentType(),
+    multiple=True,
+    help="Attachment path or URL or -",
+)
+@click.option(
+    "attachment_types",
+    "--at",
+    "--attachment-type",
+    type=(str, str),
+    multiple=True,
+    callback=attachment_types_callback,
+    help="Attachment with explicit mimetype",
+)
 @click.option(
     "options",
     "-o",
@@ -127,6 +195,8 @@ def prompt(
     prompt,
     system,
     model_id,
+    attachments,
+    attachment_types,
     options,
     template,
     param,
@@ -142,6 +212,14 @@ def prompt(
     Execute a prompt

     Documentation: https://llm.datasette.io/en/stable/usage.html
+
+    Examples:
+
+    \b
+        llm 'Capital of France?'
+        llm 'Capital of France?' -m gpt-4o
+        llm 'Capital of France?' -s 'answer in Spanish'
+        llm 'Extract text from this image' -a image.jpg
     """
     if log and no_log:
         raise click.ClickException("--log and --no-log are mutually exclusive")
@@ -262,6 +340,8 @@ def read_prompt():
     except pydantic.ValidationError as ex:
         raise click.ClickException(render_errors(ex.errors()))

+    resolved_attachments = [*attachments, *attachment_types]
+
     should_stream = model.can_stream and not no_stream
     if not should_stream:
         validated_options["stream"] = False
@@ -273,7 +353,9 @@ def read_prompt():
         prompt_method = conversation.prompt

     try:
-        response = prompt_method(prompt, system, **validated_options)
+        response = prompt_method(
+            prompt, *resolved_attachments, system=system, **validated_options
+        )
         if should_stream:
             for chunk in response:
                 print(chunk, end="")
diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 657c0d20..913e7545 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -33,8 +33,8 @@ def register_models(register):
     register(Chat("gpt-4-turbo-2024-04-09"))
     register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
     # GPT-4o
-    register(Chat("gpt-4o"), aliases=("4o",))
-    register(Chat("gpt-4o-mini"), aliases=("4o-mini",))
+    register(Chat("gpt-4o", vision=True), aliases=("4o",))
+    register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
     # o1
     register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
     register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
@@ -271,6 +271,7 @@ def __init__(
         api_engine=None,
         headers=None,
         can_stream=True,
+        vision=False,
         allows_system_prompt=True,
     ):
         self.model_id = model_id
@@ -282,8 +283,17 @@ def __init__(
         self.api_engine = api_engine
         self.headers = headers
         self.can_stream = can_stream
+        self.vision = vision
         self.allows_system_prompt = allows_system_prompt

+        if vision:
+            self.attachment_types = {
+                "image/png",
+                "image/jpeg",
+                "image/webp",
+                "image/gif",
+            }
+
     def __str__(self):
         return "OpenAI Chat: {}".format(self.model_id)

@@ -308,7 +318,18 @@ def execute(self, prompt, stream, response, conversation=None):
                 messages.append({"role": "assistant", "content": prev_response.text()})
         if prompt.system and prompt.system != current_system:
             messages.append({"role": "system", "content": prompt.system})
-        messages.append({"role": "user", "content": prompt.prompt})
+        if not prompt.attachments:
+            messages.append({"role": "user", "content": prompt.prompt})
+        else:
+            vision_message = [{"type": "text", "text": prompt.prompt}]
+            for attachment in prompt.attachments:
+                url = attachment.url
+                if not url:
+                    base64_image = attachment.base64_content()
+                    url = f"data:{attachment.resolve_type()};base64,{base64_image}"
+                vision_message.append({"type": "image_url", "image_url": {"url": url}})
+            messages.append({"role": "user", "content": vision_message})
+
         response._prompt_json = {"messages": messages}
         kwargs = self.build_kwargs(prompt)
         client = self.get_client()
diff --git a/llm/models.py b/llm/models.py
index 0e47bb60..77bdb8e9 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -1,7 +1,10 @@
+import base64
 from dataclasses import dataclass, field
 import datetime
 from .errors import NeedsKeyException
+import httpx
 from itertools import islice
+import puremagic
 import re
 import time
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
@@ -13,17 +16,52 @@
 CONVERSATION_NAME_LENGTH = 32


+@dataclass
+class Attachment:
+    type: Optional[str] = None
+    path: Optional[str] = None
+    url: Optional[str] = None
+    content: Optional[bytes] = None
+
+    def resolve_type(self):
+        if self.type:
+            return self.type
+        # Derive it from path or url or content
+        if self.path:
+            return puremagic.from_file(self.path, mime=True)
+        if self.url:
+            return puremagic.from_url(self.url, mime=True)
+        if self.content:
+            return puremagic.from_string(self.content, mime=True)
+        raise ValueError("Attachment has no type and no content to derive it from")
+
+    def base64_content(self):
+        content = self.content
+        if not content:
+            if self.path:
+                content = open(self.path, "rb").read()
+            elif self.url:
+                response = httpx.get(self.url)
+                response.raise_for_status()
+                content = response.content
+        return base64.b64encode(content).decode("utf-8")
+
+
 @dataclass
 class Prompt:
     prompt: str
     model: "Model"
-    system: Optional[str]
-    prompt_json: Optional[str]
-    options: "Options"
+    attachments: Optional[List[Attachment]] = field(default_factory=list)
+    system: Optional[str] = None
+    prompt_json: Optional[str] = None
+    options: "Options" = field(default_factory=dict)

-    def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
+    def __init__(
+        self, prompt, model, attachments, system=None, prompt_json=None, options=None
+    ):
         self.prompt = prompt
         self.model = model
+        self.attachments = list(attachments)
         self.system = system
         self.prompt_json = prompt_json
         self.options = options or {}
@@ -39,6 +77,7 @@ class Conversation:
     def prompt(
         self,
         prompt: Optional[str],
+        *attachments: Attachment,
         system: Optional[str] = None,
         stream: bool = True,
         **options
@@ -46,8 +85,9 @@ def prompt(
         return Response(
             Prompt(
                 prompt,
-                system=system,
                 model=self.model,
+                attachments=attachments,
+                system=system,
                 options=self.model.Options(**options),
             ),
             self.model,
@@ -158,14 +198,22 @@ def log_to_db(self, db):
         db["responses"].insert(response)

     @classmethod
-    def fake(cls, model: "Model", prompt: str, system: str, response: str):
+    def fake(
+        cls,
+        model: "Model",
+        prompt: str,
+        *attachments: List[Attachment],
+        system: str,
+        response: str
+    ):
         "Utility method to help with writing tests"
         response_obj = cls(
             model=model,
             prompt=Prompt(
                 prompt,
-                system=system,
                 model=model,
+                attachments=attachments,
+                system=system,
             ),
             stream=False,
         )
@@ -183,8 +231,9 @@ def from_row(cls, row):
             model=model,
             prompt=Prompt(
                 prompt=row["prompt"],
-                system=row["system"],
                 model=model,
+                attachments=[],
+                system=row["system"],
                 options=model.Options(**json.loads(row["options_json"])),
             ),
             stream=False,
@@ -242,10 +291,15 @@ def get_key(self):
 class Model(ABC, _get_key_mixin):
     model_id: str
+
+    # API key handling
     key: Optional[str] = None
     needs_key: Optional[str] = None
     key_env_var: Optional[str] = None
+
+    # Model characteristics
     can_stream: bool = False
+    attachment_types = set()

     class Options(_Options):
         pass
@@ -269,13 +323,33 @@ def execute(

     def prompt(
         self,
-        prompt: Optional[str],
+        prompt: str,
+        *attachments: Attachment,
         system: Optional[str] = None,
         stream: bool = True,
         **options
     ):
+        # Validate attachments
+        if attachments and not self.attachment_types:
+            raise ValueError(
+                "This model does not support attachments, but some were provided"
+            )
+        for attachment in attachments:
+            attachment_type = attachment.resolve_type()
+            if attachment_type not in self.attachment_types:
+                raise ValueError(
+                    "This model does not support attachments of type '{}', only {}".format(
+                        attachment_type, ", ".join(self.attachment_types)
+                    )
+                )
         return self.response(
-            Prompt(prompt, system=system, model=self, options=self.Options(**options)),
+            Prompt(
+                prompt,
+                attachments=attachments,
+                system=system,
+                model=self,
+                options=self.Options(**options),
+            ),
             stream=stream,
         )
diff --git a/setup.py b/setup.py
index 1f6adcd7..b8b55bf8 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ def get_long_description():
         "setuptools",
         "pip",
         "pyreadline3; sys_platform == 'win32'",
+        "puremagic",
     ],
     extras_require={
         "test": [
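
Reviewer note (not part of the patch): a minimal usage sketch of the API this diff introduces, to make the intended call pattern concrete. It assumes the pieces land exactly as shown above -- the Attachment(type, path, url, content) dataclass, Model.prompt(prompt, *attachments, system=None, ...), and the vision-enabled "gpt-4o" registration. The file name image.jpg and the prompt text are placeholders.

    # sketch.py -- illustrative only, not part of this commit
    # Assumes OPENAI_API_KEY is set in the environment.
    import llm
    from llm import Attachment

    model = llm.get_model("gpt-4o")  # registered with vision=True in this patch

    # Field order is (type, path, url, content); unused slots stay None.
    # Leaving type as None would make resolve_type() guess it via puremagic.
    attachment = Attachment("image/jpeg", "image.jpg", None, None)

    # Attachments are passed positionally after the prompt string.
    response = model.prompt("Describe this image", attachment, system="Be concise")
    print(response.text())

The equivalent call through the new -a/--attachment option added to llm prompt:

    llm 'Describe this image' -a image.jpg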