Skip to content

Commit

Permalink
[ENH] XAI embedding function
Browse files Browse the repository at this point in the history
  • Loading branch information
itaismith committed Dec 1, 2024
1 parent 337fe73 commit 6ffc8eb
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions chromadb/utils/embedding_functions/xai_embedding_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from typing import List, cast, TypedDict

import httpx

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embedding

logger = logging.getLogger(__name__)

class XAIEmbedding(TypedDict):
    """Embedding payload of one response item: the vector wrapped under a "Float" key."""

    # The embedding vector itself.
    Float: List[float]


class XAIResponseItem(TypedDict):
    """One entry of the "data" list in an embeddings response."""

    # Wrapper object, not a bare vector: consumers read item["embedding"]["Float"].
    embedding: XAIEmbedding
    # Index used to sort the returned items before the vectors are extracted.
    index: int
    object: str


class XAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the XAI API.
    It requires an API key and a model name. You can use the "list embedding models" endpoint
    to verify what embeddings models are available for your API key.
    """

    def __init__(self, api_key: str, model_name: str):
        """
        Initialize the XAIEmbeddingFunction.

        Args:
            api_key (str): Your API key for the XAI API.
            model_name (str): The name of the model to use for embeddings.
        """
        self._model_name = model_name
        self._api_url = "https://api.x.ai/v1/embeddings"
        # A single client is created here and reused by every __call__; it
        # carries the bearer token so individual requests do not have to.
        self._session = httpx.Client()
        self._session.headers.update(
            {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts, ordered by the "index"
            field of each response item.

        Raises:
            RuntimeError: If the response body does not contain a "data" field.
                The exception carries the body's "error" value when present,
                otherwise the whole decoded body.
        """
        resp = self._session.post(
            self._api_url,
            json={
                "input": input,
                "model": self._model_name,
                "encoding_format": "float",
            },
        ).json()

        if not isinstance(resp, dict) or "data" not in resp:
            # Surface the server's own error payload when there is one; fall
            # back to the full body so an unexpected response shape is reported
            # as a RuntimeError instead of being masked by a KeyError/TypeError.
            error = resp.get("error", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(error)

        embeddings: List[XAIResponseItem] = resp["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

        # Return just the embeddings (each one is wrapped under a "Float" key)
        return cast(
            Embeddings,
            [result["embedding"]["Float"] for result in sorted_embeddings],
        )

0 comments on commit 6ffc8eb

Please sign in to comment.