Skip to content

Commit

Permalink
3-small-512, 3-large-256, 3-large-1024 embedding models, refs #394
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jan 25, 2024
1 parent 0446893 commit a1b97c0
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions llm/default_plugins/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,31 @@ def register_embedding_models(register):
)
register(OpenAIEmbeddingModel("3-small", "text-embedding-3-small"))
register(OpenAIEmbeddingModel("3-large", "text-embedding-3-large"))
# With varying dimensions
register(OpenAIEmbeddingModel("3-small-512", "text-embedding-3-small", 512))
register(OpenAIEmbeddingModel("3-large-256", "text-embedding-3-large", 256))
register(OpenAIEmbeddingModel("3-large-1024", "text-embedding-3-large", 1024))


class OpenAIEmbeddingModel(EmbeddingModel):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
batch_size = 100

def __init__(self, model_id, openai_model_id):
def __init__(self, model_id, openai_model_id, dimensions=None):
self.model_id = model_id
self.openai_model_id = openai_model_id
self.dimensions = dimensions

def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
results = openai.Embedding.create(
input=items, model=self.openai_model_id, api_key=self.get_key()
)["data"]
kwargs = {
"input": items,
"model": self.openai_model_id,
"api_key": self.get_key(),
}
if self.dimensions:
kwargs["dimensions"] = self.dimensions
results = openai.Embedding.create(**kwargs)["data"]
return ([float(r) for r in result["embedding"]] for result in results)


Expand Down

0 comments on commit a1b97c0

Please sign in to comment.