From 7336910a900528f974f6821b35c7c15fca8587e6 Mon Sep 17 00:00:00 2001 From: David Schmitt Date: Mon, 9 Dec 2024 12:49:00 +0100 Subject: [PATCH] Accept azureModelName configuration --- llm/openai.go | 5 ++++- llm/openai_test.go | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llm/openai.go b/llm/openai.go index 1a16678..9f3c847 100644 --- a/llm/openai.go +++ b/llm/openai.go @@ -18,10 +18,13 @@ import ( // Creates a new OpenAI Provider. The user must pass in the API key, the model // to use, and the name. The name is used in the subsequent names of all the // assistants that are created. Set `azureBaseURL` to enable azure mode -func NewOpenAIProvider(apiKey string, azureBaseURL string, model string, name string, jsonMode bool) *openAIProvider { +func NewOpenAIProvider(apiKey string, azureBaseURL string, model, azureModelName string, name string, jsonMode bool) *openAIProvider { var cfg openai.ClientConfig if azureBaseURL != "" { cfg = openai.DefaultAzureConfig(apiKey, azureBaseURL) + cfg.AzureModelMapperFunc = func(m string) string { + return azureModelName + } } else { cfg = openai.DefaultConfig(apiKey) } diff --git a/llm/openai_test.go b/llm/openai_test.go index 1e7e2d4..643c63e 100644 --- a/llm/openai_test.go +++ b/llm/openai_test.go @@ -18,7 +18,7 @@ func TestNewOpenAIProvider(t *testing.T) { t.Skip("OPENAI_API_KEY not set") } - openaiProvider := NewOpenAIProvider(key, "", openai.GPT4oMini, t.Name(), false) + openaiProvider := NewOpenAIProvider(key, "", openai.GPT4oMini, "", t.Name(), false) // Assert that the result matches the provider interface var _ Provider = openaiProvider