Skip to content

Commit

Permalink
feat: allow stream usage to be set for azure openai requests
Browse files Browse the repository at this point in the history
  • Loading branch information
garethjevans authored and ilayaperumalg committed Dec 11, 2024
1 parent 863cf38 commit e7bd8d5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
Expand Down Expand Up @@ -206,7 +207,7 @@ public ChatResponse call(Prompt prompt) {
this.observationRegistry)
.observe(() -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(false);
ChatCompletionsOptionsAccessHelper.setStream(options, false);

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse chatResponse = toChatResponse(chatCompletions);
Expand All @@ -230,7 +231,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {

return Flux.deferContextual(contextView -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(true);
ChatCompletionsOptionsAccessHelper.setStream(options, true);

Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
.getChatCompletionsStream(options.getModel(), options);
Expand All @@ -252,10 +253,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
// The last element, when using stream_options will contain the usage data
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())
|| chatCompletions.getUsage() != null)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
if (!chatCompletions.getChoices().isEmpty()) {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
}
return chatCompletions;
})
.windowUntil(chatCompletions -> {
Expand Down Expand Up @@ -493,7 +498,13 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
}

ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages());
mergedAzureOptions.setStream(fromAzureOptions.isStream());

ChatCompletionsOptionsAccessHelper.setStream(mergedAzureOptions,
fromAzureOptions.isStream() != null ? fromAzureOptions.isStream() : false);

ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
fromAzureOptions.getStreamOptions() != null ? fromAzureOptions.getStreamOptions()
: toSpringAiOptions.getStreamOptions());

mergedAzureOptions.setMaxTokens((fromAzureOptions.getMaxTokens() != null) ? fromAzureOptions.getMaxTokens()
: toSpringAiOptions.getMaxTokens());
Expand Down Expand Up @@ -629,6 +640,15 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

if (fromSpringAiOptions.getStreamOptions() != null) {
ChatCompletionsOptionsAccessHelper.setStreamOptions(mergedAzureOptions,
fromSpringAiOptions.getStreamOptions());
}

if (fromSpringAiOptions.getEnhancements() != null) {
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

return mergedAzureOptions;
}

Expand All @@ -640,8 +660,13 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {

ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages());
copyOptions.setStream(fromOptions.isStream());

if (fromOptions.isStream() != null) {
ChatCompletionsOptionsAccessHelper.setStream(copyOptions, fromOptions.isStream());
}
if (fromOptions.getStreamOptions() != null) {
ChatCompletionsOptionsAccessHelper.setStreamOptions(copyOptions, fromOptions.getStreamOptions());
}
if (fromOptions.getMaxTokens() != null) {
copyOptions.setMaxTokens(fromOptions.getMaxTokens());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Set;

import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand Down Expand Up @@ -193,6 +194,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
@JsonIgnore
private AzureChatEnhancementConfiguration enhancements;

@JsonProperty("stream_options")
private ChatCompletionStreamOptions streamOptions;

@JsonIgnore
private Map<String, Object> toolContext;

Expand All @@ -219,6 +223,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.withTopLogprobs(fromOptions.getTopLogProbs())
.withEnhancements(fromOptions.getEnhancements())
.withToolContext(fromOptions.getToolContext())
.withStreamOptions(fromOptions.getStreamOptions())
.build();
}

Expand Down Expand Up @@ -412,6 +417,14 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

public ChatCompletionStreamOptions getStreamOptions() {
return this.streamOptions;
}

public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
this.streamOptions = streamOptions;
}

@Override
public AzureOpenAiChatOptions copy() {
return fromOptions(this);
Expand Down Expand Up @@ -536,6 +549,11 @@ public Builder withToolContext(Map<String, Object> toolContext) {
return this;
}

public Builder withStreamOptions(ChatCompletionStreamOptions streamOptions) {
this.options.streamOptions = streamOptions;
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
<!-- production dependencies -->
<spring-boot.version>3.3.6</spring-boot.version>
<ST4.version>4.3.4</ST4.version>
<azure-open-ai-client.version>1.0.0-beta.12</azure-open-ai-client.version>
<azure-open-ai-client.version>1.0.0-beta.13</azure-open-ai-client.version>
<jtokkit.version>1.1.0</jtokkit.version>
<victools.version>4.31.1</victools.version>

Expand Down

0 comments on commit e7bd8d5

Please sign in to comment.