Skip to content

Commit

Permalink
fix for streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
starkt committed Jul 12, 2024
1 parent 33c197f commit e491842
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
Expand All @@ -59,13 +61,9 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
Expand Down Expand Up @@ -104,17 +102,17 @@ public class AzureOpenAiChatModel extends
/**
 * Creates a chat model for the given client by delegating to the two-argument
 * constructor with options built from {@code DEFAULT_DEPLOYMENT_NAME} and
 * {@code DEFAULT_TEMPERATURE}.
 * @param microsoftOpenAiClient the Azure OpenAI client passed on to the delegated
 * constructor
 */
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
// NOTE(review): the three builder-chain lines below appear twice in a row - confirm
// that only one copy is intended.
this(microsoftOpenAiClient,
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.build());
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.build());
}

/**
 * Creates a chat model for the given client and options by delegating to the
 * three-argument constructor with {@code null} as the function callback context.
 * @param microsoftOpenAiClient the Azure OpenAI client passed on to the delegated
 * constructor
 * @param options the chat options passed on to the delegated constructor
 */
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
this(microsoftOpenAiClient, options, null);
}

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
Expand Down Expand Up @@ -148,9 +146,9 @@ public ChatResponse call(Prompt prompt) {
logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

Expand All @@ -160,51 +158,81 @@ public ChatResponse call(Prompt prompt) {

/**
 * Streams the chat completion for the given prompt, emitting one
 * {@link ChatResponse} for every choice of every chat-completions element produced
 * by {@code streamWithAzureOpenAi}.
 * @param prompt the prompt that is converted into Azure chat-completions options
 * @return a flux of chat responses, each wrapping a single generation
 */
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
	final ChatCompletionsOptions azureOptions = toAzureChatCompletionsOptions(prompt);

	// Tool-call requests are resolved inside streamWithAzureOpenAi, which is entered
	// again for every follow-up request so that its Azure-specific workarounds also
	// apply to that stream; the generic handling from AbstractFunctionCallSupport is
	// therefore not used on this path.
	final Flux<ChatCompletions> completions = streamWithAzureOpenAi(azureOptions);

	return completions.flatMapIterable(ChatCompletions::getChoices).map(chatChoice -> {
		// Prefer the message of the choice; fall back to its delta when there is none.
		final var payload = Optional.ofNullable(chatChoice.getMessage()).orElse(chatChoice.getDelta());
		final var text = payload.getContent();
		final var generation = new Generation(text).withGenerationMetadata(generateChoiceMetadata(chatChoice));
		return new ChatResponse(List.of(generation));
	});
}

/**
 * Opens a streaming chat-completions call for the given options and returns the
 * resulting completions as a flux.
 * <p>
 * The options are switched to streaming mode in place. The first element of the
 * stream is skipped. For every remaining element a flag records whether the delta of
 * its first choice carries tool calls; the elements are then grouped into windows
 * that end either at an element whose first choice finishes with {@code TOOL_CALLS}
 * while the flag is set, or at any element seen while the flag is not set. Each
 * window is folded into a single completion via
 * {@code MergeUtils.mergeChatCompletions} before the tool-call handling step runs.
 * @param options the request options; {@code setStream(true)} is called on them
 * @return the flux built from the pipeline described above
 */
private Flux<ChatCompletions> streamWithAzureOpenAi(ChatCompletionsOptions options) {
// NOTE(review): several statements in this method appear twice back to back, and the
// return expression carries two different tails (handleFunctionCallOrReturnStream and
// handleToolCallRequests) - confirm which copy is intended.
options.setStream(true);

IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
.getChatCompletionsStream(options.getModel(), options);
.getChatCompletionsStream(options.getModel(), options);

Flux<ChatCompletions> chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream);

final var isFunctionCall = new AtomicBoolean(false);
final var accessibleChatCompletionsFlux = chatCompletionsFlux
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
return chatCompletions;
})
.windowUntil(chatCompletions -> {
if (isFunctionCall.get() && chatCompletions.getChoices()
.get(0)
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
isFunctionCall.set(false);
return true;
}
return !isFunctionCall.get();
})
.concatMapIterable(window -> {
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
return List.of(reduce);
})
.flatMap(mono -> mono);
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
return chatCompletions;
})
.windowUntil(chatCompletions -> {
if (isFunctionCall.get() && chatCompletions.getChoices()
.get(0)
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
isFunctionCall.set(false);
return true;
}
return !isFunctionCall.get();
})
.concatMapIterable(window -> {
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
return List.of(reduce);
})
.flatMap(mono -> mono);
return accessibleChatCompletionsFlux
.switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options,
Flux.just(accessibleChatCompletions)))
.flatMapIterable(ChatCompletions::getChoices)
.map(choice -> {
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
});
.switchMap(accessibleChatCompletions -> handleToolCallRequests(options,
Flux.just(accessibleChatCompletions)));

}

/**
 * Resolves tool-call requests found in the streamed responses.
 * <p>
 * A response that {@code isToolFunctionCall} does not report as a tool call is passed
 * through unchanged. For any other response a follow-up request is created from the
 * user messages of the original request plus the tool response message, and a new
 * stream is opened for it through {@code streamWithAzureOpenAi}.
 * <p>
 * The mapping uses {@code switchMap}, so only the publisher created for the most
 * recently received response stays subscribed.
 * @param request the options of the request that produced {@code response}
 * @param response the chat completions to inspect
 * @return the responses, with tool-call responses replaced by their follow-up stream
 */
private Flux<ChatCompletions> handleToolCallRequests(ChatCompletionsOptions request, Flux<ChatCompletions> response) {
	return response.switchMap(completions -> {
		if (!this.isToolFunctionCall(completions)) {
			return Mono.just(completions);
		}

		// The follow-up request needs the whole conversation so far, starting with
		// the initial user messages.
		final List<ChatRequestMessage> history = new ArrayList<>(this.doGetUserMessages(request));

		// The assistant's response is appended to that history.
		final ChatRequestMessage toolResponse = this.doGetToolResponseMessage(completions);
		history.add(toolResponse);

		final ChatCompletionsOptions followUp = this.doCreateToolResponseRequest(request, toolResponse, history);

		// Go back through the Azure streaming path so that its workarounds are
		// applied to the follow-up stream as well.
		return this.streamWithAzureOpenAi(followUp);
	});
}

/**
* Test access.
*/
Expand All @@ -213,9 +241,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
Set<String> functionsForThisRequest = new HashSet<>();

List<ChatRequestMessage> azureMessages = prompt.getInstructions()
.stream()
.map(this::fromSpringAiMessage)
.toList();
.stream()
.map(this::fromSpringAiMessage)
.toList();

ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

Expand Down Expand Up @@ -250,8 +278,8 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<ChatCompletionsFunctionToolDefinition> tools = this.getFunctionTools(functionsForThisRequest);
List<ChatCompletionsToolDefinition> tools2 = tools.stream()
.map(t -> ((ChatCompletionsToolDefinition) t))
.toList();
.map(t -> ((ChatCompletionsToolDefinition) t))
.toList();
options.setTools(tools2);
}

Expand All @@ -264,7 +292,7 @@ private List<ChatCompletionsFunctionToolDefinition> getFunctionTools(Set<String>
FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName());
functionDefinition.setDescription(functionCallback.getDescription());
BinaryData parameters = BinaryData
.fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
.fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
functionDefinition.setParameters(parameters);
return new ChatCompletionsFunctionToolDefinition(functionDefinition);
}).toList();
Expand All @@ -279,10 +307,10 @@ private ChatRequestMessage fromSpringAiMessage(Message message) {
items.add(new ChatMessageTextContentItem(message.getContent()));
if (!CollectionUtils.isEmpty(message.getMedia())) {
items.addAll(message.getMedia()
.stream()
.map(media -> new ChatMessageImageContentItem(
new ChatMessageImageUrl(media.getData().toString())))
.toList());
.stream()
.map(media -> new ChatMessageImageContentItem(
new ChatMessageImageUrl(media.getData().toString())))
.toList());
}
return new ChatRequestUserMessage(items);
case SYSTEM:
Expand All @@ -305,9 +333,9 @@ private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
chatCompletions.getPromptFilterResults());

return PromptMetadata.of(promptFilterResults.stream()
.map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
promptFilterResult.getContentFilterResults()))
.toList());
.map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
promptFilterResult.getContentFilterResults()))
.toList());
}

private <T> List<T> nullSafeList(List<T> list) {
Expand All @@ -320,7 +348,7 @@ private <T> List<T> nullSafeList(List<T> list) {
* {@link ChatCompletionsOptions} instance.
*/
private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
AzureOpenAiChatOptions toSpringAiOptions) {
AzureOpenAiChatOptions toSpringAiOptions) {

if (toSpringAiOptions == null) {
return fromAzureOptions;
Expand All @@ -336,7 +364,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
: toSpringAiOptions.getLogitBias());

mergedAzureOptions
.setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());
.setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());

mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
Expand Down Expand Up @@ -366,7 +394,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN());

mergedAzureOptions
.setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());
.setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());

mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
: toSpringAiOptions.getDeploymentName());
Expand All @@ -383,7 +411,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
* @return a new {@link ChatCompletionsOptions} instance.
*/
private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
ChatCompletionsOptions toAzureOptions) {
ChatCompletionsOptions toAzureOptions) {

if (fromSpringAiOptions == null) {
return toAzureOptions;
Expand Down Expand Up @@ -542,7 +570,7 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {

@Override
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {

// Every tool-call item requires a separate function call and a response (TOOL)
// message.
Expand Down Expand Up @@ -579,7 +607,7 @@ protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions requ
protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
final var accessibleChatChoice = response.getChoices().get(0);
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
.orElse(accessibleChatChoice.getDelta());
.orElse(accessibleChatChoice.getDelta());
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
final var toolCalls = responseMessage.getToolCalls();
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
Expand Down
Loading

0 comments on commit e491842

Please sign in to comment.