-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH]: Ollama embedding function (#1813)
## Description of changes *Summarize the changes made by this PR.* - New functionality - New Ollama embedding function (Python and JS) - Example of how to run Ollama with the embedding function ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes chroma-core/docs#222
- Loading branch information
Showing
6 changed files
with
198 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
|
||
import pytest | ||
import requests | ||
from requests import HTTPError | ||
from requests.exceptions import ConnectionError | ||
|
||
from chromadb.utils.embedding_functions import OllamaEmbeddingFunction | ||
|
||
|
||
def test_ollama() -> None: | ||
""" | ||
To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file | ||
Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables. | ||
""" | ||
if ( | ||
os.environ.get("OLLAMA_SERVER_URL") is None | ||
or os.environ.get("OLLAMA_MODEL") is None | ||
): | ||
pytest.skip( | ||
"OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test." | ||
) | ||
try: | ||
response = requests.get(os.environ.get("OLLAMA_SERVER_URL", "")) | ||
# If the response was successful, no Exception will be raised | ||
response.raise_for_status() | ||
except (HTTPError, ConnectionError): | ||
pytest.skip("Ollama server not running. Skipping test.") | ||
ef = OllamaEmbeddingFunction( | ||
model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text", | ||
url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings", | ||
) | ||
embeddings = ef(["Here is an article about llamas...", "this is another article"]) | ||
assert len(embeddings) == 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import { IEmbeddingFunction } from "./IEmbeddingFunction"; | ||
|
||
export class OllamaEmbeddingFunction implements IEmbeddingFunction { | ||
private readonly url: string; | ||
private readonly model: string; | ||
|
||
constructor({ url, model }: { url: string, model: string }) { | ||
// we used to construct the client here, but we need to async import the types | ||
// for the openai npm package, and the constructor can not be async | ||
this.url = url; | ||
this.model = model; | ||
} | ||
|
||
public async generate(texts: string[]) { | ||
let embeddings:number[][] = []; | ||
for (let text of texts) { | ||
const response = await fetch(this.url, { | ||
method: 'POST', | ||
headers: { | ||
'Content-Type': 'application/json' | ||
}, | ||
body: JSON.stringify({ 'model':this.model, 'prompt': text }) | ||
}); | ||
|
||
if (!response.ok) { | ||
throw new Error(`Failed to generate embeddings: ${response.status} (${response.statusText})`); | ||
} | ||
let finalResponse = await response.json(); | ||
embeddings.push(finalResponse['embedding']); | ||
} | ||
return embeddings; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Ollama | ||
|
||
First let's run a local docker container with Ollama. We'll pull `nomic-embed-text` model: | ||
|
||
```bash | ||
docker run -d -v ./ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama | ||
docker exec -it ollama ollama run nomic-embed-text # press Ctrl+D to exit after model downloads successfully | ||
# test it | ||
curl http://localhost:11434/api/embeddings -d '{"model": "nomic-embed-text","prompt": "Here is an article about llamas..."}' | ||
``` | ||
|
||
Now let's configure our OllamaEmbeddingFunction Embedding (python) function with the default Ollama endpoint: | ||
|
||
```python | ||
import chromadb | ||
from chromadb.utils.embedding_functions import OllamaEmbeddingFunction | ||
|
||
client = chromadb.PersistentClient(path="ollama") | ||
|
||
# create EF with custom endpoint | ||
ef = OllamaEmbeddingFunction( | ||
model_name="nomic-embed-text", | ||
url="http://127.0.0.1:11434/api/embeddings", | ||
) | ||
|
||
print(ef(["Here is an article about llamas..."])) | ||
``` | ||
|
||
For JS users, you can use the `OllamaEmbeddingFunction` class to create embeddings: | ||
|
||
```javascript | ||
const {OllamaEmbeddingFunction} = require('chromadb'); | ||
const embedder = new OllamaEmbeddingFunction({ | ||
url: "http://127.0.0.1:11434/api/embeddings", | ||
model: "llama2" | ||
}) | ||
|
||
// use directly | ||
const embeddings = embedder.generate(["Here is an article about llamas..."]) | ||
``` |