From 317e0c1f6c0d3df0a1adabb67ddcdba55f8d3d87 Mon Sep 17 00:00:00 2001 From: Miles Zimmerman Date: Sun, 22 Oct 2023 21:20:29 -0700 Subject: [PATCH 1/6] add CloudflareWorkersAIEmbeddingFunction to embedding_functions.py --- chromadb/utils/embedding_functions.py | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index aaef53c01e2..7f2bd4d236f 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -479,6 +479,38 @@ def __call__(self, texts: Documents) -> Embeddings: return embeddings +class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction): + # Follow API Quickstart for Cloudflare Workers AI + # https://developers.cloudflare.com/workers-ai/ + # Information about the text embedding modules in Google Vertex AI + # https://developers.cloudflare.com/workers-ai/models/embedding/ + def __init__( + self, + api_token: str, + account_id: str = None, + model_name: str = "@cf/baai/bge-base-en-v1.5", + gateway_url: str = None, # use Cloudflare AI Gateway instead of the usual endpoint + ): + self._api_base_url = gateway_url ? gateway_url : f"https://api.cloudflare.com/client/v4/accounts/{account_id}}/ai/run/" + self._session = requests.Session() + self._session.headers.update({"Authorization": f"Bearer {api_token}"}) + + def __call__(self, texts: Documents) -> Embeddings: + processed = [] + # Endpoint accepts up to 100 items at a time + for i in range(0, len(texts), 100): + batch = texts[i:i+100] + response = self._session.post( + f"{self._api_base_url}{self._model_name}", json={"text":batch} + ).json() + + if 'result' in response: + if 'data' in response["result"] + embeddings.append(response["result"]["data"]) + + return embeddings + + # List of all classes in this module _classes = [ name From 669cb64e19df67b9917248ccf9318e7c1ed571ab Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Sun, 7 Apr 2024 19:00:04 +0300 Subject: [PATCH 2/6] feat: Updated CF embedding function - Removed batching, the choice and control of batching and any failures arising from is better left to the users and handled outside the EF - Added JS implementation of the EF - Added tests for both python and JS EF --- chromadb/test/ef/test_cloudflare_ef.py | 61 ++++++++++++ chromadb/utils/embedding_functions.py | 76 ++++++++------ .../CloudflareWorkersAIEmbeddingFunction.ts | 82 +++++++++++++++ clients/js/src/index.ts | 4 +- clients/js/test/embeddings/cloudflare.test.ts | 99 +++++++++++++++++++ 5 files changed, 290 insertions(+), 32 deletions(-) create mode 100644 chromadb/test/ef/test_cloudflare_ef.py create mode 100644 clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts create mode 100644 clients/js/test/embeddings/cloudflare.test.ts diff --git a/chromadb/test/ef/test_cloudflare_ef.py b/chromadb/test/ef/test_cloudflare_ef.py new file mode 100644 index 00000000000..c00eabbe9f2 --- /dev/null +++ b/chromadb/test/ef/test_cloudflare_ef.py @@ -0,0 +1,61 @@ +import os + +import pytest + +from chromadb.utils.embedding_functions import CloudflareWorkersAIEmbeddingFunction + + +def test_cf_ef_token_and_account() -> None: + if "CF_API_TOKEN" not in os.environ or "CF_ACCOUNT_ID" not in os.environ: + pytest.skip("CF_API_TOKEN and CF_ACCOUNT_ID not set") + ef = CloudflareWorkersAIEmbeddingFunction( + api_token=os.environ.get("CF_API_TOKEN", ""), + account_id=os.environ.get("CF_ACCOUNT_ID"), + ) + embeddings = ef(["test doc"]) + assert embeddings is not None + assert 
len(embeddings) == 1 + assert len(embeddings[0]) > 0 + + +def test_cf_ef_gateway() -> None: + if "CF_API_TOKEN" not in os.environ or "CF_GATEWAY_ENDPOINT" not in os.environ: + pytest.skip("CF_API_TOKEN and CF_GATEWAY_ENDPOINT not set") + ef = CloudflareWorkersAIEmbeddingFunction( + api_token=os.environ.get("CF_API_TOKEN", ""), + gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"), + ) + embeddings = ef(["test doc"]) + assert embeddings is not None + assert len(embeddings) == 1 + assert len(embeddings[0]) > 0 + + +def test_cf_ef_large_batch() -> None: + if "CF_API_TOKEN" not in os.environ: + pytest.skip("CF_API_TOKEN not set, not going to test Cloudflare EF.") + + ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy") + with pytest.raises(ValueError, match="Batch too large"): + ef(["test doc"] * 101) + + +def test_cf_ef_missing_account_or_gateway() -> None: + if "CF_API_TOKEN" not in os.environ: + pytest.skip("CF_API_TOKEN not set, not going to test Cloudflare EF.") + with pytest.raises( + ValueError, match="Please provide either an account_id or a gateway_url" + ): + CloudflareWorkersAIEmbeddingFunction(api_token="dummy") + + +def test_cf_ef_with_account_or_gateway() -> None: + if "CF_API_TOKEN" not in os.environ: + pytest.skip("CF_API_TOKEN not set, not going to test Cloudflare EF.") + with pytest.raises( + ValueError, + match="Please provide either an account_id or a gateway_url, not both", + ): + CloudflareWorkersAIEmbeddingFunction( + api_token="dummy", account_id="dummy", gateway_url="dummy" + ) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 73da88a406d..c27cc85afc8 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -682,7 +682,6 @@ def __call__(self, input: Documents) -> Embeddings: return embeddings - class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): def __init__( self, @@ -744,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, api_key: str = "", api_url = "https://infer.roboflow.com" - ) -> None: + def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None: """ Create a RoboflowEmbeddingFunction. 
@@ -758,7 +755,7 @@ def __init__( api_key = os.environ.get("ROBOFLOW_API_KEY") self._api_url = api_url - self._api_key = api_key + self._api_key = api_key try: self._PILImage = importlib.import_module("PIL.Image") @@ -790,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: json=infer_clip_payload, ) - result = res.json()['embeddings'] + result = res.json()["embeddings"] embeddings.append(result[0]) - + elif is_document(item): infer_clip_payload = { "text": input, @@ -804,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: json=infer_clip_payload, ) - result = res.json()['embeddings'] + result = res.json()["embeddings"] embeddings.append(result[0]) return embeddings - + class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, @@ -901,8 +898,7 @@ def __call__(self, input: Documents) -> Embeddings: ) - -class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction): +class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]): # Follow API Quickstart for Cloudflare Workers AI # https://developers.cloudflare.com/workers-ai/ # Information about the text embedding modules in Google Vertex AI @@ -910,29 +906,49 @@ class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction): def __init__( self, api_token: str, - account_id: str = None, - model_name: str = "@cf/baai/bge-base-en-v1.5", - gateway_url: str = None, # use Cloudflare AI Gateway instead of the usual endpoint + account_id: Optional[str] = None, + model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5", + gateway_url: Optional[ + str + ] = None, # use Cloudflare AI Gateway instead of the usual endpoint + # right now endpoint schema supports up to 100 docs at a time + # https://developers.cloudflare.com/workers-ai/models/bge-small-en-v1.5/#api-schema (Input JSON Schema) + max_batch_size: Optional[int] = 100, + headers: Optional[Dict[str, str]] = None, ): - self._api_base_url = gateway_url ? gateway_url : f"https://api.cloudflare.com/client/v4/accounts/{account_id}}/ai/run/" + if not gateway_url and not account_id: + raise ValueError("Please provide either an account_id or a gateway_url.") + if gateway_url and account_id: + raise ValueError( + "Please provide either an account_id or a gateway_url, not both." + ) + if gateway_url is not None and not gateway_url.endswith("/"): + gateway_url += "/" + self._api_url = ( + f"{gateway_url}/{model_name}" + if gateway_url is not None + else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}" + ) self._session = requests.Session() + self._session.headers.update(headers or {}) self._session.headers.update({"Authorization": f"Bearer {api_token}"}) + self._max_batch_size = max_batch_size def __call__(self, texts: Documents) -> Embeddings: - processed = [] - # Endpoint accepts up to 100 items at a time - for i in range(0, len(texts), 100): - batch = texts[i:i+100] - response = self._session.post( - f"{self._api_base_url}{self._model_name}", json={"text":batch} - ).json() - - if 'result' in response: - if 'data' in response["result"] - embeddings.append(response["result"]["data"]) - - return embeddings + # Endpoint accepts up to 100 items at a time. We'll reject anything larger. + # It would be up to the user to split the input into smaller batches. + if self._max_batch_size and len(texts) > self._max_batch_size: + raise ValueError( + f"Batch too large {len(texts)} > {self._max_batch_size} (maximum batch size)." 
+ ) + response = self._session.post(f"{self._api_url}", json={"text": texts}) + response.raise_for_status() + _json = response.json() + if "result" in _json and "data" in _json["result"]: + return cast(Embeddings, _json["result"]["data"]) + else: + raise ValueError(f"Error calling Cloudflare Workers AI: {response.text}") def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore @@ -997,7 +1013,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) - + class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). @@ -1053,7 +1069,7 @@ def __call__(self, input: Documents) -> Embeddings: ], ) - + # List of all classes in this module _classes = [ name diff --git a/clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts b/clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts new file mode 100644 index 00000000000..1e810229d3c --- /dev/null +++ b/clients/js/src/embeddings/CloudflareWorkersAIEmbeddingFunction.ts @@ -0,0 +1,82 @@ +import { IEmbeddingFunction } from "./IEmbeddingFunction"; + +export class CloudflareWorkersAIEmbeddingFunction + implements IEmbeddingFunction +{ + private apiUrl: string; + private headers: { [key: string]: string }; + private maxBatchSize: number; + + constructor({ + apiToken, + model, + accountId, + gatewayUrl, + maxBatchSize, + headers, + }: { + apiToken: string; + model?: string; + accountId?: string; + gatewayUrl?: string; + maxBatchSize?: number; + headers?: { [key: string]: string }; + }) { + model = model || "@cf/baai/bge-base-en-v1.5"; + this.maxBatchSize = maxBatchSize || 100; + if (accountId === undefined && gatewayUrl === undefined) { + throw new Error("Please provide either an accountId or a gatewayUrl."); + } + if (accountId !== undefined && gatewayUrl !== undefined) { + throw new Error( + "Please provide either an accountId or a gatewayUrl, not both.", + ); + } + if (gatewayUrl !== undefined && !gatewayUrl.endsWith("/")) { + gatewayUrl += "/" + model; + } + this.apiUrl = + gatewayUrl || + `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}`; + this.headers = headers || {}; + this.headers["Authorization"] = `Bearer ${apiToken}`; + } + + public async generate(texts: string[]) { + if (texts.length === 0) { + return []; + } + if (texts.length > this.maxBatchSize) { + throw new Error( + `Batch too large ${texts.length} > ${this.maxBatchSize} (maximum batch size).`, + ); + } + try { + const response = await fetch(this.apiUrl, { + method: "POST", + headers: this.headers, + body: JSON.stringify({ + text: texts, + }), + }); + + const data = (await response.json()) as { + success?: boolean; + messages: any[]; + errors?: any[]; + result: { shape: any[]; data: number[][] }; + }; + if (data.success === false) { + throw new Error(`${JSON.stringify(data.errors)}`); + } + return data.result.data; + } catch (error) { + console.error(error); + if (error instanceof Error) { + throw new Error(`Error calling CF API: ${error}`); + } else { + throw new Error(`Error calling CF API: ${error}`); + } + } + } +} diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index c925f9e4871..c10d6e7a46b 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -10,8 +10,8 @@ export { DefaultEmbeddingFunction } from 
"./embeddings/DefaultEmbeddingFunction" export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbeddingServerFunction"; export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction"; export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction"; -export { OllamaEmbeddingFunction } from './embeddings/OllamaEmbeddingFunction'; - +export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction"; +export { CloudflareWorkersAIEmbeddingFunction } from "./embeddings/CloudflareWorkersAIEmbeddingFunction"; export { IncludeEnum, diff --git a/clients/js/test/embeddings/cloudflare.test.ts b/clients/js/test/embeddings/cloudflare.test.ts new file mode 100644 index 00000000000..67e880e4a97 --- /dev/null +++ b/clients/js/test/embeddings/cloudflare.test.ts @@ -0,0 +1,99 @@ +import { expect, test } from "@jest/globals"; +import { DOCUMENTS } from "../data"; +import { CloudflareWorkersAIEmbeddingFunction } from "../../src"; + +if (!process.env.CF_API_TOKEN) { + test.skip("it should generate Cloudflare embeddings with apiToken and AccountId", async () => {}); +} else { + test("it should generate Cloudflare embeddings with apiToken and AccountId", async () => { + const embedder = new CloudflareWorkersAIEmbeddingFunction({ + apiToken: process.env.CF_API_TOKEN as string, + accountId: process.env.CF_ACCOUNT_ID, + }); + const embeddings = await embedder.generate(DOCUMENTS); + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(DOCUMENTS.length); + }); +} + +if (!process.env.CF_API_TOKEN) { + test.skip("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => {}); +} else { + test("it should generate Cloudflare embeddings with apiToken and AccountId and model", async () => { + const embedder = new CloudflareWorkersAIEmbeddingFunction({ + apiToken: process.env.CF_API_TOKEN as string, + accountId: process.env.CF_ACCOUNT_ID, + model: "@cf/baai/bge-small-en-v1.5", + }); + const embeddings = await embedder.generate(DOCUMENTS); + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(DOCUMENTS.length); + }); +} + +if (!process.env.CF_API_TOKEN) { + test.skip("it should generate Cloudflare embeddings with apiToken and gateway", async () => {}); +} else { + test("it should generate Cloudflare embeddings with apiToken and gateway", async () => { + const embedder = new CloudflareWorkersAIEmbeddingFunction({ + apiToken: process.env.CF_API_TOKEN as string, + gatewayUrl: process.env.CF_GATEWAY_ENDPOINT, + }); + const embeddings = await embedder.generate(DOCUMENTS); + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(DOCUMENTS.length); + }); +} + +if (!process.env.CF_API_TOKEN) { + test.skip("it should fail when batch too large", async () => {}); +} else { + test("it should fail when batch too large", async () => { + const embedder = new CloudflareWorkersAIEmbeddingFunction({ + apiToken: process.env.CF_API_TOKEN as string, + gatewayUrl: process.env.CF_GATEWAY_ENDPOINT, + }); + const largeBatch = Array(100) + .fill([...DOCUMENTS]) + .flat(); + try { + await embedder.generate(largeBatch); + } catch (e: any) { + expect(e.message).toMatch("Batch too large"); + } + }); +} + +if (!process.env.CF_API_TOKEN) { + test.skip("it should fail when gateway endpoint and account id are both provided", async () => {}); +} else { + test("it should fail when gateway endpoint and account id are both provided", async () => { + try { + new CloudflareWorkersAIEmbeddingFunction({ + apiToken: 
process.env.CF_API_TOKEN as string, + accountId: process.env.CF_ACCOUNT_ID, + gatewayUrl: process.env.CF_GATEWAY_ENDPOINT, + }); + } catch (e: any) { + expect(e.message).toMatch( + "Please provide either an accountId or a gatewayUrl, not both.", + ); + } + }); +} + +if (!process.env.CF_API_TOKEN) { + test.skip("it should fail when neither gateway endpoint nor account id are provided", async () => {}); +} else { + test("it should fail when neither gateway endpoint nor account id are provided", async () => { + try { + new CloudflareWorkersAIEmbeddingFunction({ + apiToken: process.env.CF_API_TOKEN as string, + }); + } catch (e: any) { + expect(e.message).toMatch( + "Please provide either an accountId or a gatewayUrl.", + ); + } + }); +} From c42287255ab75f714b1b4733ad8d58585303f394 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 8 Apr 2024 09:15:34 +0300 Subject: [PATCH 3/6] chore: Test cleanup --- chromadb/test/ef/test_cloudflare_ef.py | 31 +++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/chromadb/test/ef/test_cloudflare_ef.py b/chromadb/test/ef/test_cloudflare_ef.py index c00eabbe9f2..914b9c5ad99 100644 --- a/chromadb/test/ef/test_cloudflare_ef.py +++ b/chromadb/test/ef/test_cloudflare_ef.py @@ -5,9 +5,11 @@ from chromadb.utils.embedding_functions import CloudflareWorkersAIEmbeddingFunction +@pytest.mark.skipif( + "CF_API_TOKEN" not in os.environ, + reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", +) def test_cf_ef_token_and_account() -> None: - if "CF_API_TOKEN" not in os.environ or "CF_ACCOUNT_ID" not in os.environ: - pytest.skip("CF_API_TOKEN and CF_ACCOUNT_ID not set") ef = CloudflareWorkersAIEmbeddingFunction( api_token=os.environ.get("CF_API_TOKEN", ""), account_id=os.environ.get("CF_ACCOUNT_ID"), @@ -18,9 +20,11 @@ def test_cf_ef_token_and_account() -> None: assert len(embeddings[0]) > 0 +@pytest.mark.skipif( + "CF_API_TOKEN" not in os.environ, + reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", +) def test_cf_ef_gateway() -> None: - if "CF_API_TOKEN" not in os.environ or "CF_GATEWAY_ENDPOINT" not in os.environ: - pytest.skip("CF_API_TOKEN and CF_GATEWAY_ENDPOINT not set") ef = CloudflareWorkersAIEmbeddingFunction( api_token=os.environ.get("CF_API_TOKEN", ""), gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"), @@ -31,27 +35,32 @@ def test_cf_ef_gateway() -> None: assert len(embeddings[0]) > 0 +@pytest.mark.skipif( + "CF_API_TOKEN" not in os.environ, + reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", +) def test_cf_ef_large_batch() -> None: - if "CF_API_TOKEN" not in os.environ: - pytest.skip("CF_API_TOKEN not set, not going to test Cloudflare EF.") - ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy") with pytest.raises(ValueError, match="Batch too large"): ef(["test doc"] * 101) +@pytest.mark.skipif( + "CF_API_TOKEN" not in os.environ, + reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", +) def test_cf_ef_missing_account_or_gateway() -> None: - if "CF_API_TOKEN" not in os.environ: - pytest.skip("CF_API_TOKEN not set, not going to test Cloudflare EF.") with pytest.raises( ValueError, match="Please provide either an account_id or a gateway_url" ): CloudflareWorkersAIEmbeddingFunction(api_token="dummy") +@pytest.mark.skipif( + "CF_API_TOKEN" not in os.environ, + reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", +) def test_cf_ef_with_account_or_gateway() -> None: - if "CF_API_TOKEN" not in os.environ: - pytest.skip("CF_API_TOKEN not 
set, not going to test Cloudflare EF.") with pytest.raises( ValueError, match="Please provide either an account_id or a gateway_url, not both", From cb7ea7fe2dba531946b1925c793734747740f63c Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Fri, 21 Jun 2024 18:23:37 +0200 Subject: [PATCH 4/6] feat: Rebase + bug fix of the gateway_url composition --- chromadb/test/ef/test_cloudflare_ef.py | 4 +- ...loudflare_workers_ai_embedding_function.py | 63 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py diff --git a/chromadb/test/ef/test_cloudflare_ef.py b/chromadb/test/ef/test_cloudflare_ef.py index 914b9c5ad99..3845a275c46 100644 --- a/chromadb/test/ef/test_cloudflare_ef.py +++ b/chromadb/test/ef/test_cloudflare_ef.py @@ -2,7 +2,9 @@ import pytest -from chromadb.utils.embedding_functions import CloudflareWorkersAIEmbeddingFunction +from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import ( + CloudflareWorkersAIEmbeddingFunction, +) @pytest.mark.skipif( diff --git a/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py new file mode 100644 index 00000000000..4e001c99b4d --- /dev/null +++ b/chromadb/utils/embedding_functions/cloudflare_workers_ai_embedding_function.py @@ -0,0 +1,63 @@ +import logging +from typing import Optional, Dict, cast + +import httpx + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings + +logger = logging.getLogger(__name__) + + +class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]): + # Follow API Quickstart for Cloudflare Workers AI + # https://developers.cloudflare.com/workers-ai/ + # Information about the text embedding modules in Google Vertex AI + # https://developers.cloudflare.com/workers-ai/models/embedding/ + def __init__( + self, + api_token: str, + account_id: Optional[str] = None, + model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5", + gateway_url: Optional[ + str + ] = None, # use Cloudflare AI Gateway instead of the usual endpoint + # right now endpoint schema supports up to 100 docs at a time + # https://developers.cloudflare.com/workers-ai/models/bge-small-en-v1.5/#api-schema (Input JSON Schema) + max_batch_size: Optional[int] = 100, + headers: Optional[Dict[str, str]] = None, + ): + if not gateway_url and not account_id: + raise ValueError("Please provide either an account_id or a gateway_url.") + if gateway_url and account_id: + raise ValueError( + "Please provide either an account_id or a gateway_url, not both." + ) + if gateway_url is not None and not gateway_url.endswith("/"): + gateway_url += "/" + self._api_url = ( + f"{gateway_url}{model_name}" + if gateway_url is not None + else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}" + ) + self._session = httpx.Client() + self._session.headers.update(headers or {}) + self._session.headers.update({"Authorization": f"Bearer {api_token}"}) + self._max_batch_size = max_batch_size + + def __call__(self, texts: Documents) -> Embeddings: + # Endpoint accepts up to 100 items at a time. We'll reject anything larger. + # It would be up to the user to split the input into smaller batches. + if self._max_batch_size and len(texts) > self._max_batch_size: + raise ValueError( + f"Batch too large {len(texts)} > {self._max_batch_size} (maximum batch size)." 
+            )
+
+        logger.debug("Cloudflare Workers AI embedding endpoint: %s", self._api_url)
+
+        response = self._session.post(f"{self._api_url}", json={"text": texts})
+        response.raise_for_status()
+        _json = response.json()
+        if "result" in _json and "data" in _json["result"]:
+            return cast(Embeddings, _json["result"]["data"])
+        else:
+            raise ValueError(f"Error calling Cloudflare Workers AI: {response.text}")

From 8f21304445748a1a539e1b0c61dcecd1f86c1568 Mon Sep 17 00:00:00 2001
From: Trayan Azarov
Date: Fri, 21 Jun 2024 18:38:52 +0200
Subject: [PATCH 5/6] docs: Added docs from docs repo

Refs: chroma-core/docs#231
---
 .../pages/guides/embeddings.md                | 17 +++--
 .../pages/integrations/_sidenav.js            |  1 +
 .../pages/integrations/cloudflare.md          | 73 +++++++++++++++++++
 3 files changed, 83 insertions(+), 8 deletions(-)
 create mode 100644 docs/docs.trychroma.com/pages/integrations/cloudflare.md

diff --git a/docs/docs.trychroma.com/pages/guides/embeddings.md b/docs/docs.trychroma.com/pages/guides/embeddings.md
index d523c7d3089..500998df3ad 100644
--- a/docs/docs.trychroma.com/pages/guides/embeddings.md
+++ b/docs/docs.trychroma.com/pages/guides/embeddings.md
@@ -9,15 +9,16 @@ Chroma provides lightweight wrappers around popular embedding providers, making
 {% special_table %}
 {% /special_table %}
 
-|              | Python | JS |
-|--------------|-----------|---------------|
-| [OpenAI](/integrations/openai) | ✅ | ✅ |
-| [Google Generative AI](/integrations/google-gemini) | ✅ | ✅ |
-| [Cohere](/integrations/cohere) | ✅ | ✅ |
-| [Hugging Face](/integrations/hugging-face) | ✅ | ➖ |
-| [Instructor](/integrations/instructor) | ✅ | ➖ |
+|                                                                     | Python    | JS            |
+|--------------------------------------------------------------------|-----------|---------------|
+| [OpenAI](/integrations/openai)                                      | ✅ | ✅ |
+| [Google Generative AI](/integrations/google-gemini)                 | ✅ | ✅ |
+| [Cohere](/integrations/cohere)                                      | ✅ | ✅ |
+| [Hugging Face](/integrations/hugging-face)                          | ✅ | ➖ |
+| [Instructor](/integrations/instructor)                              | ✅ | ➖ |
 | [Hugging Face Embedding Server](/integrations/hugging-face-server) | ✅ | ✅ |
-| [Jina AI](/integrations/jinaai) | ✅ | ✅ |
+| [Jina AI](/integrations/jinaai)                                     | ✅ | ✅ |
+| [Cloudflare Workers AI](/integrations/cloudflare)                   | ✅ | ✅ |
 
 We welcome pull requests to add new Embedding Functions to the community.
 
diff --git a/docs/docs.trychroma.com/pages/integrations/_sidenav.js b/docs/docs.trychroma.com/pages/integrations/_sidenav.js
index 5feeda2d186..8586946263c 100644
--- a/docs/docs.trychroma.com/pages/integrations/_sidenav.js
+++ b/docs/docs.trychroma.com/pages/integrations/_sidenav.js
@@ -11,6 +11,7 @@ export const items = [
       { href: '/integrations/jinaai', children: 'JinaAI' },
       { href: '/integrations/roboflow', children: 'Roboflow' },
       { href: '/integrations/ollama', children: 'Ollama Embeddings' },
+      { href: '/integrations/cloudflare', children: 'Cloudflare Workers AI Embeddings' },
     ]
   },
   {
diff --git a/docs/docs.trychroma.com/pages/integrations/cloudflare.md b/docs/docs.trychroma.com/pages/integrations/cloudflare.md
new file mode 100644
index 00000000000..fe842161bd3
--- /dev/null
+++ b/docs/docs.trychroma.com/pages/integrations/cloudflare.md
@@ -0,0 +1,73 @@
+---
+title: Cloudflare Workers AI
+---
+
+Chroma provides a convenient wrapper around the Cloudflare Workers AI REST API. This embedding function runs remotely on Cloudflare Workers AI. It requires an API token and either an account ID or an AI Gateway endpoint. You can get an API token by signing up for an account at [Cloudflare Workers AI](https://cloudflare.com/).
+
+Visit the [Cloudflare Workers AI documentation](https://developers.cloudflare.com/workers-ai/) for more information on getting started.
+
+:::note
+Currently, Cloudflare's embedding endpoints accept at most 100 documents in a single request. The EF enforces the same hard limit of 100 documents per call and raises an error if you try to pass more.
+:::
+
+{% tabs group="code-lang" %}
+{% tab label="Python" %}
+
+```python
+import chromadb.utils.embedding_functions as embedding_functions
+cf_ef = embedding_functions.CloudflareWorkersAIEmbeddingFunction(
+    api_token="YOUR_API_TOKEN",
+    account_id="YOUR_ACCOUNT_ID",  # or gateway_url="YOUR_GATEWAY_URL"
+    model_name="@cf/baai/bge-base-en-v1.5",
+)
+cf_ef(["This is my first text to embed", "This is my second document"])
+```
+
+You can pass in an optional `model_name` argument, which lets you choose which Cloudflare Workers AI [model](https://developers.cloudflare.com/workers-ai/models/#text-embeddings) to use. By default, Chroma uses `@cf/baai/bge-base-en-v1.5`.
+
+{% /tab %}
+{% tab label="Javascript" %}
+
+{% codetabs customHeader="js" %}
+{% codetab label="ESM" %}
+
+```js {% codetab=true %}
+import {CloudflareWorkersAIEmbeddingFunction} from "chromadb";
+const embedder = new CloudflareWorkersAIEmbeddingFunction({
+    apiToken: 'YOUR_API_TOKEN',
+    accountId: "YOUR_ACCOUNT_ID", // or gatewayUrl
+    model: '@cf/baai/bge-base-en-v1.5',
+});
+
+// use directly
+const embeddings = await embedder.generate(['document1', 'document2']);
+
+// pass documents to query for .add and .query
+const collection = await client.createCollection({name: "name", embeddingFunction: embedder})
+const collectionGet = await client.getCollection({name:"name", embeddingFunction: embedder})
+```
+{% /codetab %}
+{% codetab label="CJS" %}
+
+```js {% codetab=true %}
+const {CloudflareWorkersAIEmbeddingFunction} = require('chromadb');
+const embedder = new CloudflareWorkersAIEmbeddingFunction({
+    apiToken: 'YOUR_API_TOKEN',
+    accountId: "YOUR_ACCOUNT_ID", // or gatewayUrl
+    model: '@cf/baai/bge-base-en-v1.5',
+});
+
+// use directly
+const embeddings = await embedder.generate(['document1', 'document2']);
+
+// pass documents to query for .add and .query
+const collection = await client.createCollection({name: "name", embeddingFunction: embedder})
+const collectionGet = await client.getCollection({name:"name", embeddingFunction: embedder})
+```
+
+{% /codetab %}
+{% /codetabs %}
+
+{% /tab %}
+
+{% /tabs %}

From 3b8180cf3624f6ee4b17bf1930d11a7d8e8d92d5 Mon Sep 17 00:00:00 2001
From: Trayan Azarov
Date: Fri, 21 Jun 2024 19:28:06 +0200
Subject: [PATCH 6/6] fix: Fixing EF imports test

---
 chromadb/test/ef/test_ef.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py
index c93502e3fc8..364bf52c9f7 100644
--- a/chromadb/test/ef/test_ef.py
+++ b/chromadb/test/ef/test_ef.py
@@ -30,6 +30,7 @@ def test_get_builtins_holds() -> None:
         "SentenceTransformerEmbeddingFunction",
         "Text2VecEmbeddingFunction",
         "ChromaLangchainEmbeddingFunction",
+        "CloudflareWorkersAIEmbeddingFunction",
     }
     assert expected_builtins == embedding_functions.get_builtins()
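For reference, a minimal end-to-end sketch (not part of the patches above) of how the new Python embedding function plugs into a Chroma collection. It assumes a `chromadb` build that includes these patches, with `CF_API_TOKEN` and `CF_ACCOUNT_ID` set in the environment; the collection calls are the standard Chroma client API.

```python
import os

import chromadb
from chromadb.utils.embedding_functions.cloudflare_workers_ai_embedding_function import (
    CloudflareWorkersAIEmbeddingFunction,
)

# Construct the EF the same way the tests above do: an api_token plus either
# an account_id or a gateway_url (never both).
cf_ef = CloudflareWorkersAIEmbeddingFunction(
    api_token=os.environ["CF_API_TOKEN"],
    account_id=os.environ["CF_ACCOUNT_ID"],
    model_name="@cf/baai/bge-base-en-v1.5",
)

client = chromadb.Client()  # in-memory client; any Chroma client works
collection = client.create_collection(name="cf_docs", embedding_function=cf_ef)

# Remember the 100-document cap per call: split larger inputs yourself.
collection.add(
    ids=["1", "2"],
    documents=["This is my first text to embed", "This is my second document"],
)
results = collection.query(query_texts=["first text"], n_results=1)
print(results["ids"])
```

The JS client mirrors this: construct `new CloudflareWorkersAIEmbeddingFunction({ apiToken, accountId })` and pass it as `embeddingFunction` to `createCollection`, as shown in the docs page added in PATCH 5/6.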