From e7bd8d5da113fe541629464eefa08f1c65067055 Mon Sep 17 00:00:00 2001 From: Gareth Evans Date: Thu, 5 Dec 2024 11:00:14 +0000 Subject: [PATCH] feat: allow stream usage to be set for azure openai requests --- .../ai/azure/openai/AzureOpenAiChatModel.java | 39 +++++++++++++++---- .../azure/openai/AzureOpenAiChatOptions.java | 18 +++++++++ pom.xml | 2 +- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 226a9eef06..ae5b355693 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -29,6 +29,7 @@ import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; @@ -206,7 +207,7 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); - options.setStream(false); + ChatCompletionsOptionsAccessHelper.setStream(options, false); ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); ChatResponse chatResponse = toChatResponse(chatCompletions); @@ -230,7 +231,7 @@ public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); - options.setStream(true); + ChatCompletionsOptionsAccessHelper.setStream(options, true); Flux chatCompletionsStream = this.openAIAsyncClient .getChatCompletionsStream(options.getModel(), options); @@ -252,10 +253,14 @@ public Flux stream(Prompt prompt) { final Flux accessibleChatCompletionsFlux = chatCompletionsStream // Note: the first chat completions can be ignored when using Azure OpenAI // service which is a known service bug. - .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())) + // The last element, when using stream_options will contain the usage data + .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()) + || chatCompletions.getUsage() != null) .map(chatCompletions -> { - final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); - isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); + if (!chatCompletions.getChoices().isEmpty()) { + final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); + isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty()); + } return chatCompletions; }) .windowUntil(chatCompletions -> { @@ -493,7 +498,13 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, } ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages()); - mergedAzureOptions.setStream(fromAzureOptions.isStream()); + + ChatCompletionsOptionsAccessHelper.setStream(mergedAzureOptions, + fromAzureOptions.isStream() != null ? fromAzureOptions.isStream() : false); + + ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions, + fromAzureOptions.getStreamOptions() != null ? fromAzureOptions.getStreamOptions() + : toSpringAiOptions.getStreamOptions()); mergedAzureOptions.setMaxTokens((fromAzureOptions.getMaxTokens() != null) ? fromAzureOptions.getMaxTokens() : toSpringAiOptions.getMaxTokens()); @@ -629,6 +640,15 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); } + if (fromSpringAiOptions.getStreamOptions() != null) { + ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions, + fromSpringAiOptions.getStreamOptions()); + } + + if (fromSpringAiOptions.getEnhancements() != null) { + mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); + } + return mergedAzureOptions; } @@ -640,8 +660,13 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages()); - copyOptions.setStream(fromOptions.isStream()); + if (fromOptions.isStream() != null) { + ChatCompletionsOptionsAccessHelper.setStream(copyOptions, fromOptions.isStream()); + } + if (fromOptions.getStreamOptions() != null) { + ChatCompletionsOptionsAccessHelper.setStreamOptions(copyOptions, fromOptions.getStreamOptions()); + } if (fromOptions.getMaxTokens() != null) { copyOptions.setMaxTokens(fromOptions.getMaxTokens()); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index cc09ae1ed7..7b98d4daf1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -23,6 +23,7 @@ import java.util.Set; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -193,6 +194,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions { @JsonIgnore private AzureChatEnhancementConfiguration enhancements; + @JsonProperty("stream_options") + private ChatCompletionStreamOptions streamOptions; + @JsonIgnore private Map toolContext; @@ -219,6 +223,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withTopLogprobs(fromOptions.getTopLogProbs()) .withEnhancements(fromOptions.getEnhancements()) .withToolContext(fromOptions.getToolContext()) + .withStreamOptions(fromOptions.getStreamOptions()) .build(); } @@ -412,6 +417,14 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + public ChatCompletionStreamOptions getStreamOptions() { + return this.streamOptions; + } + + public void setStreamOptions(ChatCompletionStreamOptions streamOptions) { + this.streamOptions = streamOptions; + } + @Override public AzureOpenAiChatOptions copy() { return fromOptions(this); @@ -536,6 +549,11 @@ public Builder withToolContext(Map toolContext) { return this; } + public Builder withStreamOptions(ChatCompletionStreamOptions streamOptions) { + this.options.streamOptions = streamOptions; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/pom.xml b/pom.xml index a138315445..767db56553 100644 --- a/pom.xml +++ b/pom.xml @@ -174,7 +174,7 @@ 3.3.6 4.3.4 - 1.0.0-beta.12 + 1.0.0-beta.13 1.1.0 4.31.1