Skip to content

Commit

Permalink
Duplicate content is only embedded once, closes #217
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 4, 2023
1 parent 0eda99e commit 3bf781f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/embeddings/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Embeddings are much more useful if you store them somewhere, so you can calculat

LLM includes the concept of a "collection" of embeddings. A collection groups together a set of stored embeddings created using the same model, each with a unique ID within that collection.

Embeddings also store a hash of the content that was embedded. This hash is later used to avoid calculating duplicate embeddings for the same content.

First, we'll set a default model so we don't have to keep repeating it:
```bash
llm embed-models default ada-002
Expand Down
32 changes: 29 additions & 3 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,19 @@ def embed(
"""
from llm import encode

content_hash = self.content_hash(text)
if self.db["embeddings"].count_where(
"content_hash = ? and collection_id = ?", [content_hash, self.id]
):
return
embedding = self.model().embed(text)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id,
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"content_hash": self.content_hash(text),
"content_hash": content_hash,
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
},
Expand Down Expand Up @@ -183,7 +188,26 @@ def embed_multi_with_metadata(
batch = list(islice(iterator, batch_size))
if not batch:
break
embeddings = list(self.model().embed_multi(item[1] for item in batch))
# Calculate hashes first
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
# Any of those hashes already exist?
existing_ids = [
row["id"]
for row in self.db.query(
"""
select id from embeddings
where collection_id = ? and content_hash in ({})
""".format(
",".join("?" for _ in items_and_hashes)
),
[collection_id]
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
)
]
filtered_batch = [item for item in batch if item[0] not in existing_ids]
embeddings = list(
self.model().embed_multi(item[1] for item in filtered_batch)
)
with self.db.conn:
cast(Table, self.db["embeddings"]).insert_all(
(
Expand All @@ -196,7 +220,9 @@ def embed_multi_with_metadata(
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
}
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
for (embedding, (id, text, metadata)) in zip(
embeddings, filtered_batch
)
),
replace=True,
)
Expand Down
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,35 @@ class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10

def __init__(self):
self.embedded_content = []

def embed_batch(self, texts):
if not hasattr(self, "batch_count"):
self.batch_count = 0
self.batch_count += 1
for text in texts:
self.embedded_content.append(text)
words = text.split()[:16]
embedding = [len(word) for word in words]
# Pad with 0 up to 16 words
embedding += [0] * (16 - len(embedding))
yield embedding


@pytest.fixture
def embed_demo():
return EmbedDemo()


@pytest.fixture(autouse=True)
def register_embed_demo_model():
def register_embed_demo_model(embed_demo):
class EmbedDemoPlugin:
__name__ = "EmbedDemoPlugin"

@llm.hookimpl
def register_embedding_models(self, register):
register(EmbedDemo())
register(embed_demo)

pm.register(EmbedDemoPlugin(), name="undo-embed-demo-plugin")
try:
Expand Down
29 changes: 29 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,32 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
# At the end of this, there should be 2 embeddings
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
assert db["embeddings"].count == 1


def test_duplicate_content_embedded_only_once(embed_demo):
# content_hash should avoid embedding the same content twice
# per collection
db = sqlite_utils.Database(memory=True)
assert len(embed_demo.embedded_content) == 0
collection = Collection("test", db, model_id="embed-demo")
collection.embed("1", "hello world")
assert len(embed_demo.embedded_content) == 1
collection.embed("2", "goodbye world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
collection.embed("1", "hello world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
# The same string in another collection should be embedded
c2 = Collection("test2", db, model_id="embed-demo")
c2.embed("1", "hello world")
assert db["embeddings"].count == 3
assert len(embed_demo.embedded_content) == 3

# Same again for embed_multi
collection.embed_multi(
(("1", "hello world"), ("2", "goodbye world"), ("3", "this is new"))
)
# Should have only embedded one more thing
assert db["embeddings"].count == 4
assert len(embed_demo.embedded_content) == 4

0 comments on commit 3bf781f

Please sign in to comment.