diff --git a/docs/embeddings/cli.md b/docs/embeddings/cli.md index ed96fe6a..5dfa99e7 100644 --- a/docs/embeddings/cli.md +++ b/docs/embeddings/cli.md @@ -13,21 +13,24 @@ The `llm embed` command can be used to calculate embedding vectors for a string The simplest way to use this command is to pass content to it using the `-c/--content` option, like this: ```bash -llm embed -c 'This is some content' +llm embed -c 'This is some content' -m ada-002 ``` -The command will return a JSON array of floating point numbers directly to the terminal: +`-m ada-002` specifies the OpenAI `ada-002` model. You will need to have set an OpenAI API key using `llm keys set openai` for this to work. -```json -[0.123, 0.456, 0.789...] +You can install plugins to access other models. The [llm-sentence-transformers](https://github.com/simonw/llm-sentence-transformers) plugin can be used to run models on your own laptop, such as the [MiniLM-L6](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model: + +```bash +llm install llm-sentence-transformers +llm embed -c 'This is some content' -m sentence-transformers/all-MiniLM-L6-v2 ``` -By default it uses the {ref}`default embedding model `. -Use the `-m/--model` option to specify a different model: +The `llm embed` command returns a JSON array of floating point numbers directly to the terminal: -```bash -llm -m sentence-transformers/all-MiniLM-L6-v2 \ - -c 'This is some content' +```json +[0.123, 0.456, 0.789...] ``` +You can omit the `-m/--model` option if you set a {ref}`default embedding model `. + See {ref}`embeddings-binary` for options to get back embeddings in formats other than JSON. (embeddings-collections)= @@ -37,6 +40,11 @@ Embeddings are much more useful if you store them somewhere, so you can calculat LLM includes the concept of a "collection" of embeddings. A collection groups together a set of stored embeddings created using the same model, each with a unique ID within that collection. +First, we'll set a default model so we don't have to keep repeating it: +```bash +llm embed-models default ada-002 +``` + The `llm embed` command can store results directly in a named collection like this: ```bash diff --git a/docs/embeddings/python-api.md b/docs/embeddings/python-api.md index 3f3e6f23..cfe338b7 100644 --- a/docs/embeddings/python-api.md +++ b/docs/embeddings/python-api.md @@ -106,6 +106,13 @@ A collection instance has the following properties and methods: - `similar_by_vector(vector: List[float], number: int=10, skip_id: str=None)` - returns a list of entries that are most similar to the given embedding vector, optionally skipping the entry with the given ID - `delete()` - deletes the collection and its embeddings from the database +There is also a `Collection.exists(db, name)` class method which returns a boolean value and can be used to determine if a collection exists or not in a database: + +```python +if Collection.exists(db, "entries"): + print("The entries collection exists") +``` + (embeddings-python-similar)= ## Retrieving similar items diff --git a/llm/__init__.py b/llm/__init__.py index 77756cae..e2ac276d 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -115,7 +115,7 @@ def get_embedding_model(name): try: return aliases[name] except KeyError: - raise UnknownModelError("Unknown model: " + name) + raise UnknownModelError("Unknown model: " + str(name)) def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]: diff --git a/llm/cli.py b/llm/cli.py index 91c8049d..d985fe84 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -929,20 +929,28 @@ def get_db(): model_obj = None if collection: db = get_db() - collection_obj = Collection(collection, db, model_id=model) - model_obj = collection_obj.model() + if Collection.exists(db, collection): + # Load existing collection and use its model + collection_obj = Collection(collection, db) + model_obj = collection_obj.model() + else: + # We will create a new one, but that means model is required + if not model: + model = get_default_embedding_model() + if model is None: + raise click.ClickException( + "You need to specify a model (no default model is set)" + ) + collection_obj = Collection(collection, db=db, model_id=model) + model_obj = collection_obj.model() if model_obj is None: - if not model: - model = get_default_embedding_model() - if model is None: - raise click.ClickException( - "You need to specify a model (no default model is set)" - ) try: model_obj = get_embedding_model(model) - except UnknownModelError as ex: - raise click.ClickException(str(ex)) + except UnknownModelError: + raise click.ClickException( + "You need to specify a model (no default model is set)" + ) show_output = True if collection and (format_ is None): diff --git a/llm/embeddings.py b/llm/embeddings.py index 4ea8d2db..1e128ad5 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -288,6 +288,17 @@ def similar(self, text: str, number: int = 10) -> List[Entry]: comparison_vector = self.model().embed(text) return self.similar_by_vector(comparison_vector, number) + @classmethod + def exists(cls, db: Database, name: str) -> bool: + """ + Does this collection exist in the database? + + Args: + name (str): Name of the collection + """ + rows = list(db["collections"].rows_where("name = ?", [name])) + return bool(rows) + def delete(self): """ Delete the collection and its embeddings from the database diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 71d1006c..7a844c53 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -102,9 +102,8 @@ def test_embed_store(user_path, metadata, metadata_error): assert embeddings_db.exists() # Check the contents db = sqlite_utils.Database(str(embeddings_db)) - assert list(db["collections"].rows) == [ - {"id": 1, "name": "items", "model": "embed-demo"} - ] + rows = list(db["collections"].rows) + assert rows == [{"id": 1, "name": "items", "model": "embed-demo"}] expected_metadata = None if metadata and not metadata_error: expected_metadata = metadata @@ -390,3 +389,34 @@ def test_default_embedding_model(): result5 = runner.invoke(cli, ["embed-models", "default"]) assert result5.exit_code == 0 assert result5.output == "\n" + + +@pytest.mark.parametrize("default_is_set", (False, True)) +@pytest.mark.parametrize("command", ("embed", "embed-multi")) +def test_default_embed_model_errors(user_path, default_is_set, command): + runner = CliRunner() + if default_is_set: + (user_path / "default_embedding_model.txt").write_text( + "embed-demo", encoding="utf8" + ) + args = [] + input = None + if command == "embed-multi": + args = ["embed-multi", "example", "-"] + input = "id,name\n1,hello" + else: + args = ["embed", "example", "1", "-c", "hello world"] + result = runner.invoke(cli, args, input=input, catch_exceptions=False) + if default_is_set: + assert result.exit_code == 0 + else: + assert result.exit_code == 1 + assert "You need to specify a model (no default model is set)" in result.output + # Now set the default model and try again + result2 = runner.invoke(cli, ["embed-models", "default", "embed-demo"]) + assert result2.exit_code == 0 + result3 = runner.invoke(cli, args, input=input, catch_exceptions=False) + assert result3.exit_code == 0 + # At the end of this, there should be 2 embeddings + db = sqlite_utils.Database(str(user_path / "embeddings.db")) + assert db["embeddings"].count == 1