Skip to content

Commit

Permalink
fix formatting
Browse files — browse the repository at this point in the history
  • Loading branch information
starkt committed Jul 12, 2024
1 parent e491842 commit e8415e0
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,17 @@ public class AzureOpenAiChatModel extends
/**
 * Creates an {@link AzureOpenAiChatModel} configured with the default deployment name
 * and temperature, delegating to the full constructor with no function-callback context.
 * @param microsoftOpenAiClient the Azure OpenAI client to use; must not be null
 */
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
	// NOTE(review): the pasted diff showed this builder chain twice (pre- and
	// post-reformat); only the single post-commit chain is kept here.
	this(microsoftOpenAiClient,
			AzureOpenAiChatOptions.builder()
				.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
				.withTemperature(DEFAULT_TEMPERATURE)
				.build());
}

/**
 * Creates an {@link AzureOpenAiChatModel} with the given client and options and no
 * function-callback context (the third constructor argument is passed as null).
 * @param microsoftOpenAiClient the Azure OpenAI client to use
 * @param options the chat options applied to requests
 */
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
	this(microsoftOpenAiClient, options, null);
}

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
Expand Down Expand Up @@ -146,9 +146,9 @@ public ChatResponse call(Prompt prompt) {
logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

Expand All @@ -162,54 +162,54 @@ public Flux<ChatResponse> stream(Prompt prompt) {

// we have to map with a custom function to handle the tool call requests
// due to the existing bugs in the azure api (see comments in streamWithAzureApi)
// we have to recursively call this specific method for tool calls instead of using the one from the AbstractFunctionCallSupport
return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices)
.map(choice -> {
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
});
// we have to recursively call this specific method for tool calls instead of
// using the one from the AbstractFunctionCallSupport
return streamWithAzureOpenAi(options).flatMapIterable(ChatCompletions::getChoices).map(choice -> {
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
});
}

private Flux<ChatCompletions> streamWithAzureOpenAi(ChatCompletionsOptions options) {
options.setStream(true);

IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
.getChatCompletionsStream(options.getModel(), options);
.getChatCompletionsStream(options.getModel(), options);

Flux<ChatCompletions> chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream);

final var isFunctionCall = new AtomicBoolean(false);
final var accessibleChatCompletionsFlux = chatCompletionsFlux
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
return chatCompletions;
})
.windowUntil(chatCompletions -> {
if (isFunctionCall.get() && chatCompletions.getChoices()
.get(0)
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
isFunctionCall.set(false);
return true;
}
return !isFunctionCall.get();
})
.concatMapIterable(window -> {
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
return List.of(reduce);
})
.flatMap(mono -> mono);
return accessibleChatCompletionsFlux
.switchMap(accessibleChatCompletions -> handleToolCallRequests(options,
Flux.just(accessibleChatCompletions)));
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(chatCompletions -> {
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
return chatCompletions;
})
.windowUntil(chatCompletions -> {
if (isFunctionCall.get() && chatCompletions.getChoices()
.get(0)
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
isFunctionCall.set(false);
return true;
}
return !isFunctionCall.get();
})
.concatMapIterable(window -> {
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
return List.of(reduce);
})
.flatMap(mono -> mono);
return accessibleChatCompletionsFlux.switchMap(
accessibleChatCompletions -> handleToolCallRequests(options, Flux.just(accessibleChatCompletions)));

}

private Flux<ChatCompletions> handleToolCallRequests(ChatCompletionsOptions request, Flux<ChatCompletions> response) {
private Flux<ChatCompletions> handleToolCallRequests(ChatCompletionsOptions request,
Flux<ChatCompletions> response) {
return response.switchMap(resp -> {
if (!this.isToolFunctionCall(resp)) {
return Mono.just(resp);
Expand All @@ -226,9 +226,11 @@ private Flux<ChatCompletions> handleToolCallRequests(ChatCompletionsOptions requ
// Add the assistant response to the message conversation history.
conversationHistory.add(responseMessage);

ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
ChatCompletionsOptions newRequest = this.doCreateToolResponseRequest(request, responseMessage,
conversationHistory);

// recursively go backwards and call our stream again (including all bug fixes / workarounds for the azure api)
// recursively go backwards and call our stream again (including all bug fixes
// / workarounds for the azure api)
return this.streamWithAzureOpenAi(newRequest);
});
}
Expand All @@ -241,9 +243,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
Set<String> functionsForThisRequest = new HashSet<>();

List<ChatRequestMessage> azureMessages = prompt.getInstructions()
.stream()
.map(this::fromSpringAiMessage)
.toList();
.stream()
.map(this::fromSpringAiMessage)
.toList();

ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

Expand Down Expand Up @@ -278,8 +280,8 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<ChatCompletionsFunctionToolDefinition> tools = this.getFunctionTools(functionsForThisRequest);
List<ChatCompletionsToolDefinition> tools2 = tools.stream()
.map(t -> ((ChatCompletionsToolDefinition) t))
.toList();
.map(t -> ((ChatCompletionsToolDefinition) t))
.toList();
options.setTools(tools2);
}

Expand All @@ -292,7 +294,7 @@ private List<ChatCompletionsFunctionToolDefinition> getFunctionTools(Set<String>
FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName());
functionDefinition.setDescription(functionCallback.getDescription());
BinaryData parameters = BinaryData
.fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
.fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
functionDefinition.setParameters(parameters);
return new ChatCompletionsFunctionToolDefinition(functionDefinition);
}).toList();
Expand All @@ -307,10 +309,10 @@ private ChatRequestMessage fromSpringAiMessage(Message message) {
items.add(new ChatMessageTextContentItem(message.getContent()));
if (!CollectionUtils.isEmpty(message.getMedia())) {
items.addAll(message.getMedia()
.stream()
.map(media -> new ChatMessageImageContentItem(
new ChatMessageImageUrl(media.getData().toString())))
.toList());
.stream()
.map(media -> new ChatMessageImageContentItem(
new ChatMessageImageUrl(media.getData().toString())))
.toList());
}
return new ChatRequestUserMessage(items);
case SYSTEM:
Expand All @@ -333,9 +335,9 @@ private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
chatCompletions.getPromptFilterResults());

return PromptMetadata.of(promptFilterResults.stream()
.map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
promptFilterResult.getContentFilterResults()))
.toList());
.map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
promptFilterResult.getContentFilterResults()))
.toList());
}

private <T> List<T> nullSafeList(List<T> list) {
Expand All @@ -348,7 +350,7 @@ private <T> List<T> nullSafeList(List<T> list) {
* {@link ChatCompletionsOptions} instance.
*/
private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
AzureOpenAiChatOptions toSpringAiOptions) {
AzureOpenAiChatOptions toSpringAiOptions) {

if (toSpringAiOptions == null) {
return fromAzureOptions;
Expand All @@ -364,7 +366,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
: toSpringAiOptions.getLogitBias());

mergedAzureOptions
.setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());
.setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());

mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
Expand Down Expand Up @@ -394,7 +396,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN());

mergedAzureOptions
.setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());
.setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());

mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
: toSpringAiOptions.getDeploymentName());
Expand All @@ -411,7 +413,7 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
* @return a new {@link ChatCompletionsOptions} instance.
*/
private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
ChatCompletionsOptions toAzureOptions) {
ChatCompletionsOptions toAzureOptions) {

if (fromSpringAiOptions == null) {
return toAzureOptions;
Expand Down Expand Up @@ -570,7 +572,7 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {

@Override
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {

// Every tool-call item requires a separate function call and a response (TOOL)
// message.
Expand Down Expand Up @@ -607,7 +609,7 @@ protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions requ
protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
final var accessibleChatChoice = response.getChoices().get(0);
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
.orElse(accessibleChatChoice.getDelta());
.orElse(accessibleChatChoice.getDelta());
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
final var toolCalls = responseMessage.getToolCalls();
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));

Expand All @@ -95,13 +95,13 @@ void functionCallSequentialTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));

Expand All @@ -121,27 +121,27 @@ void functionCallSequentialAndStreamTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

var response = chatModel.stream(new Prompt(messages, promptOptions));

final var counter = new AtomicInteger();
String content = response.doOnEach(listSignal -> counter.getAndIncrement())
.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.filter(Objects::nonNull)
.collect(Collectors.joining());
.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.filter(Objects::nonNull)
.collect(Collectors.joining());

logger.info("Response: {}", response);

Expand All @@ -157,26 +157,26 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();

Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, promptOptions));

final var counter = new AtomicInteger();
String content = response.doOnEach(listSignal -> counter.getAndIncrement())
.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());
.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());
logger.info("Response: {}", content);

assertThat(counter.get()).isGreaterThan(30).as("The response should be chunked in more than 30 messages");
Expand All @@ -192,8 +192,8 @@ public static class TestConfiguration {
@Bean
public OpenAIClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
}

@Bean
Expand Down

0 comments on commit e8415e0

Please sign in to comment.