Skip to content

Commit

Permalink
Collection now defaults to in-memory DB, closes #213
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 2, 2023
1 parent 8bdaca1 commit 51488c5
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 12 deletions.
12 changes: 8 additions & 4 deletions docs/embeddings/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ To work with embeddings in this way you will need an instance of a [sqlite-utils
import sqlite_utils
import llm

# This collection will use an in-memory database that will be
# discarded when the Python process exits
collection = llm.Collection("entries", model_id="ada-002")

# Or you can persist the database to disk like this:
db = sqlite_utils.Database("my-embeddings.db")
# Pass model_id= to specify a model for the collection
collection = llm.Collection(db, "entries", model_id="ada-002")
collection = llm.Collection("entries", db, model_id="ada-002")

# Or you can pass a model directly using model=
# You can pass a model directly using model= instead of model_id=
embedding_model = llm.get_embedding_model("ada-002")
collection = llm.Collection(db, "entries", model=embedding_model)
collection = llm.Collection("entries", db, model=embedding_model)
```
If the collection already exists in the database, you can omit the `model` or `model_id` argument — the model ID will be read from the `collections` table.

Expand Down
4 changes: 2 additions & 2 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ def get_db():
model_obj = None
if collection:
db = get_db()
collection_obj = Collection(db, collection, model_id=model)
collection_obj = Collection(collection, db, model_id=model)
model_obj = collection_obj.model()

if model_obj is None:
Expand Down Expand Up @@ -995,7 +995,7 @@ def similar(collection, id, input, content, number, database):
raise click.ClickException("No embeddings table found in database")

try:
collection_obj = Collection(db, collection, create=False)
collection_obj = Collection(collection, db, create=False)
except Collection.DoesNotExist:
raise click.ClickException("Collection does not exist")

Expand Down
4 changes: 2 additions & 2 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class DoesNotExist(Exception):

def __init__(
self,
db: Database,
name: str,
db: Optional[Database] = None,
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
Expand All @@ -48,7 +48,7 @@ def __init__(
"""
import llm

self.db = db
self.db = db or Database(memory=True)
self.name = name
self._model = model

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def user_path(tmpdir):
def user_path_with_embeddings(user_path):
path = str(user_path / "embeddings.db")
db = sqlite_utils.Database(path)
collection = llm.Collection(db, "demo", model_id="embed-demo")
collection = llm.Collection("demo", db, model_id="embed-demo")
collection.embed("1", "hello world")
collection.embed("2", "goodbye world")

Expand Down
5 changes: 2 additions & 3 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

@pytest.fixture
def collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
collection = llm.Collection("test", model_id="embed-demo")
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
return collection
Expand Down Expand Up @@ -95,7 +94,7 @@ def test_similar_by_id(collection):
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata):
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
collection = llm.Collection("test", db, model_id="embed-demo")
ids_and_texts = ((str(i), "hello {}".format(i)) for i in range(1000))
if with_metadata:
ids_and_texts = ((id, text, {"meta": id}) for id, text in ids_and_texts)
Expand Down

0 comments on commit 51488c5

Please sign in to comment.