diff --git a/docs/embeddings/python-api.md b/docs/embeddings/python-api.md index 67d49e32..1f9e8c9f 100644 --- a/docs/embeddings/python-api.md +++ b/docs/embeddings/python-api.md @@ -30,13 +30,17 @@ To work with embeddings in this way you will need an instance of a [sqlite-utils import sqlite_utils import llm +# This collection will use an in-memory database that will be +# discarded when the Python process exits +collection = llm.Collection("entries", model_id="ada-002") + +# Or you can persist the database to disk like this: db = sqlite_utils.Database("my-embeddings.db") -# Pass model_id= to specify a model for the collection -collection = llm.Collection(db, "entries", model_id="ada-002") +collection = llm.Collection("entries", db, model_id="ada-002") -# Or you can pass a model directly using model= +# You can pass a model directly using model= instead of model_id= embedding_model = llm.get_embedding_model("ada-002") -collection = llm.Collection(db, "entries", model=embedding_model) +collection = llm.Collection("entries", db, model=embedding_model) ``` If the collection already exists in the database you can omit the `model` or `model_id` argument - the model ID will be read from the `collections` table. diff --git a/llm/cli.py b/llm/cli.py index 05c6c3e4..298671b3 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -910,7 +910,7 @@ def get_db(): model_obj = None if collection: db = get_db() - collection_obj = Collection(db, collection, model_id=model) + collection_obj = Collection(collection, db, model_id=model) model_obj = collection_obj.model() if model_obj is None: @@ -995,7 +995,7 @@ def similar(collection, id, input, content, number, database): raise click.ClickException("No embeddings table found in database") try: - collection_obj = Collection(db, collection, create=False) + collection_obj = Collection(collection, db, create=False) except Collection.DoesNotExist: raise click.ClickException("Collection does not exist") diff --git a/llm/embeddings.py b/llm/embeddings.py index 997a61c6..2ea8e1c8 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -24,8 +24,8 @@ class DoesNotExist(Exception): def __init__( self, - db: Database, name: str, + db: Optional[Database] = None, *, model: Optional[EmbeddingModel] = None, model_id: Optional[str] = None, @@ -48,7 +48,7 @@ def __init__( """ import llm - self.db = db + self.db = db or Database(memory=True) self.name = name self._model = model diff --git a/tests/conftest.py b/tests/conftest.py index 0b7dc504..1b8443ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def user_path(tmpdir): def user_path_with_embeddings(user_path): path = str(user_path / "embeddings.db") db = sqlite_utils.Database(path) - collection = llm.Collection(db, "demo", model_id="embed-demo") + collection = llm.Collection("demo", db, model_id="embed-demo") collection.embed("1", "hello world") collection.embed("2", "goodbye world") diff --git a/tests/test_embed.py b/tests/test_embed.py index 19b60f6e..d7c34423 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -7,8 +7,7 @@ @pytest.fixture def collection(): - db = sqlite_utils.Database(memory=True) - collection = llm.Collection(db, "test", model_id="embed-demo") + collection = llm.Collection("test", model_id="embed-demo") collection.embed(1, "hello world") collection.embed(2, "goodbye world") return collection @@ -95,7 +94,7 @@ def test_similar_by_id(collection): @pytest.mark.parametrize("with_metadata", (False, True)) def test_embed_multi(with_metadata): db = sqlite_utils.Database(memory=True) - collection = llm.Collection(db, "test", model_id="embed-demo") + collection = llm.Collection("test", db, model_id="embed-demo") ids_and_texts = ((str(i), "hello {}".format(i)) for i in range(1000)) if with_metadata: ids_and_texts = ((id, text, {"meta": id}) for id, text in ids_and_texts)