Skip to content

Commit

Permalink
Store content_hash in embeddings table, refs #217
Browse files Browse the repository at this point in the history
Uses new migrations feature from simonw/sqlite-migrate#9
  • Loading branch information
simonw committed Sep 3, 2023
1 parent 2633204 commit a5d6b58
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/embeddings/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ CREATE TABLE "embeddings" (
[id] TEXT,
[embedding] BLOB,
[content] TEXT,
[content_hash] BLOB,
[metadata] TEXT,
[updated] INTEGER,
PRIMARY KEY ([collection_id], [id])
Expand Down
7 changes: 7 additions & 0 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
from dataclasses import dataclass
import hashlib
from itertools import islice
import json
from sqlite_utils import Database
Expand Down Expand Up @@ -133,6 +134,7 @@ def embed(
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"content_hash": self.content_hash(text),
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
},
Expand Down Expand Up @@ -279,3 +281,8 @@ def similar(self, text: str, number: int = 10) -> List[Entry]:
"""
comparison_vector = self.model().embed(text)
return self.similar_by_vector(comparison_vector, number)

@staticmethod
def content_hash(text: str) -> bytes:
    """Return the MD5 digest of the UTF-8 encoded text, for deduplication.

    Override this method to change how content is hashed.
    """
    hasher = hashlib.md5()
    hasher.update(text.encode("utf8"))
    return hasher.digest()
47 changes: 47 additions & 0 deletions llm/embeddings_migrations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlite_migrate import Migrations
import hashlib
import time

embeddings_migrations = Migrations("llm.embeddings")
Expand Down Expand Up @@ -34,3 +35,49 @@ def m003_add_updated(db):
db.query(
"update embeddings set updated = ? where updated is null", [int(time.time())]
)


@embeddings_migrations()
def m004_store_content_hash(db):
    """Add a content_hash column to embeddings, backfill it, then index it.

    Rows with stored content receive the MD5 digest of that content (UTF-8
    encoded). Rows without stored content receive the MD5 digest of freshly
    generated random bytes, so every row ends up with a non-null hash.
    """
    # Local import: only this one-off migration needs it
    import os

    db["embeddings"].add_column("content_hash", bytes)
    # Reorder the columns so content_hash sits directly after content
    db["embeddings"].transform(
        column_order=(
            "collection_id",
            "id",
            "embedding",
            "content",
            "content_hash",
            "metadata",
            "updated",
        )
    )

    # Backfill content_hash using SQL functions registered on this connection
    @db.register_function
    def md5(text):
        return hashlib.md5(text.encode("utf8")).digest()

    @db.register_function
    def random_md5():
        # Hash random bytes rather than the current time: consecutive
        # time.time() calls can return the same value on coarse clocks,
        # which would give several rows an identical "random" hash.
        return hashlib.md5(os.urandom(16)).digest()

    with db.conn:
        db.execute(
            """
            update embeddings
            set content_hash = md5(content)
            where content is not null
            """
        )
        db.execute(
            """
            update embeddings
            set content_hash = random_md5()
            where content is null
            """
        )
    db["embeddings"].create_index(["content_hash"])
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_long_description():
"openai",
"click-default-group-wheel",
"sqlite-utils>=3.35.0",
"sqlite-migrate",
"sqlite-migrate>=0.1a2",
"pydantic>=1.10.2",
"PyYAML",
"pluggy",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_collection(collection):
"id": "1",
"embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"content_hash": collection.content_hash("hello world"),
"metadata": None,
"updated": ANY,
},
Expand All @@ -73,6 +74,7 @@ def test_collection(collection):
"id": "2",
"embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"content_hash": collection.content_hash("goodbye world"),
"metadata": None,
"updated": ANY,
},
Expand Down
2 changes: 2 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from click.testing import CliRunner
from llm.cli import cli
from llm import Collection
import json
import pytest
import sqlite_utils
Expand Down Expand Up @@ -119,6 +120,7 @@ def test_embed_store(user_path, metadata, metadata_error):
b"\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"content_hash": Collection.content_hash("hello"),
"metadata": expected_metadata,
"updated": ANY,
}
Expand Down
47 changes: 47 additions & 0 deletions tests/test_migrate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import llm
from llm.migrations import migrate
from llm.embeddings_migrations import embeddings_migrations
import pytest
Expand Down Expand Up @@ -90,8 +91,54 @@ def test_migrations_for_embeddings():
"id": str,
"embedding": bytes,
"content": str,
"content_hash": bytes,
"metadata": str,
"updated": int,
}
assert db["embeddings"].foreign_keys[0].column == "collection_id"
assert db["embeddings"].foreign_keys[0].other_table == "collections"


def test_backfill_content_hash():
    """Rows inserted before m004_store_content_hash get a backfilled hash."""
    db = sqlite_utils.Database(memory=True)
    # Apply every migration that comes before m004_store_content_hash
    embeddings_migrations.apply(db, stop_before="m004_store_content_hash")
    assert "content_hash" not in db["embeddings"].columns_dict

    embedding_one = (
        b"\x00\x00\xa0@\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    )
    embedding_two = (
        b"\x00\x00\xe0@\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    )
    row_without_content = {
        "collection_id": 1,
        "id": "1",
        "embedding": embedding_one,
        "content": None,
        "metadata": None,
        "updated": 1693763088,
    }
    row_with_content = {
        "collection_id": 1,
        "id": "2",
        "embedding": embedding_two,
        "content": "goodbye world",
        "metadata": None,
        "updated": 1693763088,
    }
    # Insert the rows directly, because llm.Collection would run migrations
    db["embeddings"].insert_all([row_without_content, row_with_content])

    # Apply the remaining migrations, including the backfill
    embeddings_migrations.apply(db)
    first, second = db["embeddings"].rows
    # No stored content, so the hash is a random placeholder
    assert first["content_hash"] is not None
    # Stored content, so the hash must match that of 'goodbye world'
    assert second["content_hash"] == llm.Collection.content_hash("goodbye world")

0 comments on commit a5d6b58

Please sign in to comment.