From a1b97c06e60c9f321b7831a21af2650822fac8df Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 25 Jan 2024 13:22:21 -0800 Subject: [PATCH] 3-small-512, 3-large-256, 3-large-1024 embedding models, refs #394 --- llm/default_plugins/openai_models.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index cc04c293..565f1657 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -83,6 +83,10 @@ def register_embedding_models(register): ) register(OpenAIEmbeddingModel("3-small", "text-embedding-3-small")) register(OpenAIEmbeddingModel("3-large", "text-embedding-3-large")) + # With varying dimensions + register(OpenAIEmbeddingModel("3-small-512", "text-embedding-3-small", 512)) + register(OpenAIEmbeddingModel("3-large-256", "text-embedding-3-large", 256)) + register(OpenAIEmbeddingModel("3-large-1024", "text-embedding-3-large", 1024)) class OpenAIEmbeddingModel(EmbeddingModel): @@ -90,14 +94,20 @@ class OpenAIEmbeddingModel(EmbeddingModel): key_env_var = "OPENAI_API_KEY" batch_size = 100 - def __init__(self, model_id, openai_model_id): + def __init__(self, model_id, openai_model_id, dimensions=None): self.model_id = model_id self.openai_model_id = openai_model_id + self.dimensions = dimensions def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]: - results = openai.Embedding.create( - input=items, model=self.openai_model_id, api_key=self.get_key() - )["data"] + kwargs = { + "input": items, + "model": self.openai_model_id, + "api_key": self.get_key(), + } + if self.dimensions: + kwargs["dimensions"] = self.dimensions + results = openai.Embedding.create(**kwargs)["data"] return ([float(r) for r in result["embedding"]] for result in results)