diff --git a/llm/__init__.py b/llm/__init__.py
index f7e3662d..57855747 100644
--- a/llm/__init__.py
+++ b/llm/__init__.py
@@ -13,6 +13,7 @@
     Prompt,
     Response,
 )
+from .embeddings import Collection
 from .templates import Template
 from .plugins import pm
 import click
@@ -20,12 +21,14 @@
 import json
 import os
 import pathlib
+import struct
 
 __all__ = [
     "hookimpl",
     "get_model",
     "get_key",
     "user_dir",
+    "Collection",
     "Conversation",
     "Model",
     "Options",
@@ -226,3 +229,11 @@ def remove_alias(alias):
         raise KeyError("No such alias: {}".format(alias))
     del current[alias]
     path.write_text(json.dumps(current, indent=4) + "\n")
+
+
+def encode(values):
+    return struct.pack("<" + "f" * len(values), *values)
+
+
+def decode(binary):
+    return struct.unpack("<" + "f" * (len(binary) // 4), binary)
diff --git a/llm/cli.py b/llm/cli.py
index b6c2460d..c99b9e92 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -6,6 +6,8 @@
     Response,
     Template,
     UnknownModelError,
+    decode,
+    encode,
     get_embedding_models_with_aliases,
     get_embedding_model,
     get_key,
@@ -28,7 +30,6 @@
 import shutil
 import sqlite_utils
 from sqlite_utils.db import NotFoundError
-import struct
 import sys
 import textwrap
 from typing import cast, Optional
@@ -1288,14 +1289,6 @@ def logs_on():
     return not (user_dir() / "logs-off").exists()
 
 
-def encode(values):
-    return struct.pack("<" + "f" * len(values), *values)
-
-
-def decode(binary):
-    return struct.unpack("<" + "f" * (len(binary) // 4), binary)
-
-
 def cosine_similarity(a, b):
     dot_product = sum(x * y for x, y in zip(a, b))
     magnitude_a = sum(x * x for x in a) ** 0.5
diff --git a/llm/embeddings.py b/llm/embeddings.py
new file mode 100644
index 00000000..5a465aae
--- /dev/null
+++ b/llm/embeddings.py
@@ -0,0 +1,194 @@
+from .models import EmbeddingModel
+from .embeddings_migrations import embeddings_migrations
+import json
+from sqlite_utils import Database
+from typing import Any, Dict, List, Tuple, Optional, Union
+
+
+class Collection:
+    """
+    A named set of embeddings stored in a SQLite database.
+    """
+
+    def __init__(
+        self,
+        db: Database,
+        name: str,
+        *,
+        model: Optional[EmbeddingModel] = None,
+        model_id: Optional[str] = None,
+    ) -> None:
+        """
+        Args:
+            db (sqlite_utils.Database): Database holding the collection
+            name (str): Name of the collection
+            model (EmbeddingModel, optional): Model used to embed content
+            model_id (str, optional): ID of the model, resolved if model is omitted
+
+        Raises:
+            ValueError: If both are given and model.model_id differs from model_id
+        """
+        from llm import get_embedding_model
+
+        self.db = db
+        self.name = name
+        if model and model_id and model.model_id != model_id:
+            raise ValueError("model_id does not match model.model_id")
+        if model_id and not model:
+            model = get_embedding_model(model_id)
+        self.model = model
+        self._id = None
+
+    def id(self) -> int:
+        """
+        Get the ID of the collection, creating it in the DB if necessary.
+
+        Returns:
+            int: ID of the collection
+
+        Raises:
+            ValueError: If the collection must be created but no model was given
+        """
+        if self._id is not None:
+            return self._id
+        if not self.db["collections"].exists():
+            embeddings_migrations.apply(self.db)
+        rows = self.db["collections"].rows_where("name = ?", [self.name])
+        try:
+            row = next(rows)
+            self._id = row["id"]
+        except StopIteration:
+            # Creating the row records the model ID, so a model is required
+            if self.model is None:
+                raise ValueError(
+                    "model or model_id is required to create a collection"
+                ) from None
+            self._id = (
+                self.db["collections"]
+                .insert(
+                    {
+                        "name": self.name,
+                        "model": self.model.model_id,
+                    }
+                )
+                .last_pk
+            )
+        return self._id
+
+    def exists(self) -> bool:
+        """
+        Check if the collection exists in the DB.
+
+        Returns:
+            bool: True if exists, False otherwise
+        """
+        # A database without the collections table cannot contain the collection
+        if not self.db["collections"].exists():
+            return False
+        matches = list(
+            self.db.query("select 1 from collections where name = ?", (self.name,))
+        )
+        return bool(matches)
+
+    def count(self) -> int:
+        """
+        Count the number of items in the collection.
+
+        Returns:
+            int: Number of items in the collection
+        """
+        # Nothing can be stored until the tables exist
+        if not (self.db["collections"].exists() and self.db["embeddings"].exists()):
+            return 0
+        return next(
+            self.db.query(
+                """
+            select count(*) as c from embeddings where collection_id = (
+                select id from collections where name = ?
+            )
+            """,
+                (self.name,),
+            )
+        )["c"]
+
+    def embed(
+        self,
+        id: str,
+        text: str,
+        metadata: Optional[Dict[str, Any]] = None,
+        store: bool = False,
+    ) -> None:
+        """
+        Embed a text and store it in the collection with a given ID.
+
+        Args:
+            id (str): ID for the text
+            text (str): Text to be embedded
+            metadata (dict, optional): Metadata to be stored
+            store (bool, optional): Whether to store the text in the content column
+
+        Raises:
+            ValueError: If the collection has no embedding model
+        """
+        from llm import encode
+
+        if self.model is None:
+            raise ValueError("model or model_id is required to embed content")
+        embedding = self.model.embed(text)
+        self.db["embeddings"].insert(
+            {
+                "collection_id": self.id(),
+                "id": id,
+                "embedding": encode(embedding),
+                "content": text if store else None,
+                "metadata": json.dumps(metadata) if metadata else None,
+            }
+        )
+
+    def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None:
+        """
+        Embed multiple texts and store them in the collection with given IDs.
+
+        Args:
+            id_text_map (dict): Dictionary mapping IDs to texts
+            store (bool, optional): Whether to store the text in the content column
+        """
+        raise NotImplementedError
+
+    def embed_multi_with_metadata(
+        self,
+        id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]],
+    ) -> None:
+        """
+        Embed multiple texts along with metadata and store them in the collection with given IDs.
+
+        Args:
+            id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples
+        """
+        raise NotImplementedError
+
+    def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
+        """
+        Find similar items in the collection by a given ID.
+
+        Args:
+            id (str): ID to search by
+            number (int, optional): Number of similar items to return
+
+        Returns:
+            list: List of (id, score) tuples
+        """
+        raise NotImplementedError
+
+    def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
+        """
+        Find similar items in the collection by a given text.
+
+        Args:
+            text (str): Text to search by
+            number (int, optional): Number of similar items to return
+
+        Returns:
+            list: List of (id, score) tuples
+        """
+        raise NotImplementedError
diff --git a/tests/test_embed.py b/tests/test_embed.py
index efd4ab89..409e8aed 100644
--- a/tests/test_embed.py
+++ b/tests/test_embed.py
@@ -1,4 +1,5 @@
 import llm
+import sqlite_utils
 
 
 def test_demo_plugin():
@@ -18,3 +19,32 @@ def test_embed_huge_list():
     assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
     # Should have happened in 100 batches
     assert model.batch_count == 100
+
+
+def test_collection():
+    db = sqlite_utils.Database(memory=True)
+    collection = llm.Collection(db, "test", model_id="embed-demo")
+    assert collection.id() == 1
+    assert collection.count() == 0
+    # Embed some stuff
+    collection.embed(1, "hello world")
+    collection.embed(2, "goodbye world")
+    assert collection.count() == 2
+    # Check that the embeddings are there
+    rows = list(db["embeddings"].rows)
+    assert rows == [
+        {
+            "collection_id": 1,
+            "id": "1",
+            "embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+            "content": None,
+            "metadata": None,
+        },
+        {
+            "collection_id": 1,
+            "id": "2",
+            "embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+            "content": None,
+            "metadata": None,
+        },
+    ]