Skip to content

Commit

Permalink
Initial Collection class plus test, refs #191
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 1, 2023
1 parent c25e7c4 commit 6f76170
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 9 deletions.
11 changes: 11 additions & 0 deletions llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
Prompt,
Response,
)
from .embeddings import Collection
from .templates import Template
from .plugins import pm
import click
from typing import Dict, List, Optional
import json
import os
import pathlib
import struct

__all__ = [
"hookimpl",
"get_model",
"get_key",
"user_dir",
"Collection",
"Conversation",
"Model",
"Options",
Expand Down Expand Up @@ -226,3 +229,11 @@ def remove_alias(alias):
raise KeyError("No such alias: {}".format(alias))
del current[alias]
path.write_text(json.dumps(current, indent=4) + "\n")


def encode(values):
return struct.pack("<" + "f" * len(values), *values)


def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
11 changes: 2 additions & 9 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Response,
Template,
UnknownModelError,
decode,
encode,
get_embedding_models_with_aliases,
get_embedding_model,
get_key,
Expand All @@ -28,7 +30,6 @@
import shutil
import sqlite_utils
from sqlite_utils.db import NotFoundError
import struct
import sys
import textwrap
from typing import cast, Optional
Expand Down Expand Up @@ -1288,14 +1289,6 @@ def logs_on():
return not (user_dir() / "logs-off").exists()


def encode(values):
return struct.pack("<" + "f" * len(values), *values)


def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)


def cosine_similarity(a, b):
dot_product = sum(x * y for x, y in zip(a, b))
magnitude_a = sum(x * x for x in a) ** 0.5
Expand Down
162 changes: 162 additions & 0 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
import json
from sqlite_utils import Database
from typing import Any, Dict, List, Tuple, Optional, Union


class Collection:
def __init__(
self,
db: Database,
name: str,
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
) -> None:
from llm import get_embedding_model

self.db = db
self.name = name
if model and model_id and model.model_id != model_id:
raise ValueError("model_id does not match model.model_id")
if model_id and not model:
model = get_embedding_model(model_id)
self.model = model
self._id = None

def id(self) -> int:
"""
Get the ID of the collection, creating it in the DB if necessary.
Returns:
int: ID of the collection
"""
if self._id is not None:
return self._id
if not self.db["collections"].exists():
embeddings_migrations.apply(self.db)
rows = self.db["collections"].rows_where("name = ?", [self.name])
try:
row = next(rows)
self._id = row["id"]
except StopIteration:
# Create it
self._id = (
self.db["collections"]
.insert(
{
"name": self.name,
"model": self.model.model_id,
}
)
.last_pk
)
return self._id

def exists(self) -> bool:
"""
Check if the collection exists in the DB.
Returns:
bool: True if exists, False otherwise
"""
matches = list(
self.db.query("select 1 from collections where name = ?", (self.name,))
)
return bool(matches)

def count(self) -> int:
"""
Count the number of items in the collection.
Returns:
int: Number of items in the collection
"""
return next(
self.db.query(
"""
select count(*) as c from embeddings where collection_id = (
select id from collections where name = ?
)
""",
(self.name,),
)
)["c"]

def embed(
self,
id: str,
text: str,
metadata: Optional[Dict[str, Any]] = None,
store: bool = False,
) -> None:
"""
Embed a text and store it in the collection with a given ID.
Args:
id (str): ID for the text
text (str): Text to be embedded
metadata (dict, optional): Metadata to be stored
store (bool, optional): Whether to store the text in the content column
"""
from llm import encode

embedding = self.model.embed(text)
self.db["embeddings"].insert(
{
"collection_id": self.id(),
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"metadata": json.dumps(metadata) if metadata else None,
}
)

def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None:
"""
Embed multiple texts and store them in the collection with given IDs.
Args:
id_text_map (dict): Dictionary mapping IDs to texts
store (bool, optional): Whether to store the text in the content column
"""
raise NotImplementedError

def embed_multi_with_metadata(
self,
id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]],
) -> None:
"""
Embed multiple texts along with metadata and store them in the collection with given IDs.
Args:
id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples
"""
raise NotImplementedError

def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given ID.
Args:
id (str): ID to search by
number (int, optional): Number of similar items to return
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError

def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given text.
Args:
text (str): Text to search by
number (int, optional): Number of similar items to return
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError
30 changes: 30 additions & 0 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import llm
import sqlite_utils


def test_demo_plugin():
Expand All @@ -18,3 +19,32 @@ def test_embed_huge_list():
assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
# Should have happened in 100 batches
assert model.batch_count == 100


def test_collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
assert collection.id() == 1
assert collection.count() == 0
# Embed some stuff
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
assert collection.count() == 2
# Check that the embeddings are there
rows = list(db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
"id": "1",
"embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"metadata": None,
},
{
"collection_id": 1,
"id": "2",
"embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"metadata": None,
},
]

0 comments on commit 6f76170

Please sign in to comment.