Skip to content

Commit

Permalink
Default embedding model finishing touches, closes #222
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 4, 2023
1 parent 8ce7046 commit 0eda99e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 23 deletions.
26 changes: 17 additions & 9 deletions docs/embeddings/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@ The `llm embed` command can be used to calculate embedding vectors for a string
The simplest way to use this command is to pass content to it using the `-c/--content` option, like this:

```bash
llm embed -c 'This is some content'
llm embed -c 'This is some content' -m ada-002
```
The command will return a JSON array of floating point numbers directly to the terminal:
`-m ada-002` specifies the OpenAI `ada-002` model. You will need to have set an OpenAI API key using `llm keys set openai` for this to work.

```json
[0.123, 0.456, 0.789...]
You can install plugins to access other models. The [llm-sentence-transformers](https://github.com/simonw/llm-sentence-transformers) plugin can be used to run models on your own laptop, such as the [MiniLM-L6](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model:

```bash
llm install llm-sentence-transformers
llm embed -c 'This is some content' -m sentence-transformers/all-MiniLM-L6-v2
```
By default it uses the {ref}`default embedding model <embeddings-cli-embed-models-default>`.

Use the `-m/--model` option to specify a different model:
The `llm embed` command returns a JSON array of floating point numbers directly to the terminal:

```bash
llm -m sentence-transformers/all-MiniLM-L6-v2 \
-c 'This is some content'
```json
[0.123, 0.456, 0.789...]
```
You can omit the `-m/--model` option if you set a {ref}`default embedding model <embeddings-cli-embed-models-default>`.

See {ref}`embeddings-binary` for options to get back embeddings in formats other than JSON.

(embeddings-collections)=
Expand All @@ -37,6 +40,11 @@ Embeddings are much more useful if you store them somewhere, so you can calculat

LLM includes the concept of a "collection" of embeddings. A collection groups together a set of stored embeddings created using the same model, each with a unique ID within that collection.

First, we'll set a default model so we don't have to keep repeating it:
```bash
llm embed-models default ada-002
```

The `llm embed` command can store results directly in a named collection like this:

```bash
Expand Down
7 changes: 7 additions & 0 deletions docs/embeddings/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ A collection instance has the following properties and methods:
- `similar_by_vector(vector: List[float], number: int=10, skip_id: str=None)` - returns a list of entries that are most similar to the given embedding vector, optionally skipping the entry with the given ID
- `delete()` - deletes the collection and its embeddings from the database

There is also a `Collection.exists(db, name)` class method which returns a boolean value and can be used to determine if a collection exists or not in a database:

```python
if Collection.exists(db, "entries"):
print("The entries collection exists")
```

(embeddings-python-similar)=
## Retrieving similar items

Expand Down
2 changes: 1 addition & 1 deletion llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_embedding_model(name):
try:
return aliases[name]
except KeyError:
raise UnknownModelError("Unknown model: " + name)
raise UnknownModelError("Unknown model: " + str(name))


def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
Expand Down
28 changes: 18 additions & 10 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,20 +929,28 @@ def get_db():
model_obj = None
if collection:
db = get_db()
collection_obj = Collection(collection, db, model_id=model)
model_obj = collection_obj.model()
if Collection.exists(db, collection):
# Load existing collection and use its model
collection_obj = Collection(collection, db)
model_obj = collection_obj.model()
else:
# We will create a new one, but that means model is required
if not model:
model = get_default_embedding_model()
if model is None:
raise click.ClickException(
"You need to specify a model (no default model is set)"
)
collection_obj = Collection(collection, db=db, model_id=model)
model_obj = collection_obj.model()

if model_obj is None:
if not model:
model = get_default_embedding_model()
if model is None:
raise click.ClickException(
"You need to specify a model (no default model is set)"
)
try:
model_obj = get_embedding_model(model)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
except UnknownModelError:
raise click.ClickException(
"You need to specify a model (no default model is set)"
)

show_output = True
if collection and (format_ is None):
Expand Down
11 changes: 11 additions & 0 deletions llm/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,17 @@ def similar(self, text: str, number: int = 10) -> List[Entry]:
comparison_vector = self.model().embed(text)
return self.similar_by_vector(comparison_vector, number)

@classmethod
def exists(cls, db: Database, name: str) -> bool:
"""
Does this collection exist in the database?
Args:
name (str): Name of the collection
"""
rows = list(db["collections"].rows_where("name = ?", [name]))
return bool(rows)

def delete(self):
"""
Delete the collection and its embeddings from the database
Expand Down
36 changes: 33 additions & 3 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def test_embed_store(user_path, metadata, metadata_error):
assert embeddings_db.exists()
# Check the contents
db = sqlite_utils.Database(str(embeddings_db))
assert list(db["collections"].rows) == [
{"id": 1, "name": "items", "model": "embed-demo"}
]
rows = list(db["collections"].rows)
assert rows == [{"id": 1, "name": "items", "model": "embed-demo"}]
expected_metadata = None
if metadata and not metadata_error:
expected_metadata = metadata
Expand Down Expand Up @@ -390,3 +389,34 @@ def test_default_embedding_model():
result5 = runner.invoke(cli, ["embed-models", "default"])
assert result5.exit_code == 0
assert result5.output == "<No default embedding model set>\n"


@pytest.mark.parametrize("default_is_set", (False, True))
@pytest.mark.parametrize("command", ("embed", "embed-multi"))
def test_default_embed_model_errors(user_path, default_is_set, command):
runner = CliRunner()
if default_is_set:
(user_path / "default_embedding_model.txt").write_text(
"embed-demo", encoding="utf8"
)
args = []
input = None
if command == "embed-multi":
args = ["embed-multi", "example", "-"]
input = "id,name\n1,hello"
else:
args = ["embed", "example", "1", "-c", "hello world"]
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
if default_is_set:
assert result.exit_code == 0
else:
assert result.exit_code == 1
assert "You need to specify a model (no default model is set)" in result.output
# Now set the default model and try again
result2 = runner.invoke(cli, ["embed-models", "default", "embed-demo"])
assert result2.exit_code == 0
result3 = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result3.exit_code == 0
# At the end of this, there should be 2 embeddings
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
assert db["embeddings"].count == 1

0 comments on commit 0eda99e

Please sign in to comment.