Skip to content

Commit

Permalink
Merge branch 'main' into v891-register-java-time-module
Browse files Browse the repository at this point in the history
  • Loading branch information
v891 authored Jul 15, 2024
2 parents f1fe3fd + ed815d8 commit ee50a68
Show file tree
Hide file tree
Showing 17 changed files with 566 additions and 277 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -34,23 +33,28 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* The {@link ChatModel} implementation for the Anthropic service.
Expand All @@ -60,13 +64,11 @@
* @author Mariusz Bernacki
* @since 1.0.0
*/
public class AnthropicChatModel extends
AbstractFunctionCallSupport<AnthropicApi.AnthropicMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletionResponse>>
implements ChatModel {
public class AnthropicChatModel extends AbstractToolCallSupport<ChatCompletionResponse> implements ChatModel {

private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);

public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue();
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();

public static final Integer DEFAULT_MAX_TOKENS = 500;

Expand Down Expand Up @@ -148,7 +150,14 @@ public ChatResponse call(Prompt prompt) {
ChatCompletionRequest request = createRequest(prompt, false);

return this.retryTemplate.execute(ctx -> {
ResponseEntity<ChatCompletionResponse> completionEntity = this.callWithFunctionSupport(request);
ResponseEntity<ChatCompletionResponse> completionEntity = this.anthropicApi.chatCompletionEntity(request);

if (this.isToolFunctionCall(completionEntity.getBody())) {
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
completionEntity.getBody());
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
}

return toChatResponse(completionEntity.getBody());
});
}
Expand All @@ -162,14 +171,52 @@ public Flux<ChatResponse> stream(Prompt prompt) {

Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);

return response
.switchMap(chatCompletionResponse -> handleFunctionCallOrReturnStream(request,
Flux.just(ResponseEntity.of(Optional.of(chatCompletionResponse)))))
.map(ResponseEntity::getBody)
.map(this::toChatResponse);
return response.switchMap(chatCompletionResponse -> {

if (this.isToolFunctionCall(chatCompletionResponse)) {
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
chatCompletionResponse);
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
}

return Mono.just(chatCompletionResponse).map(this::toChatResponse);
});
});
}

private List<Message> handleToolCallRequests(List<Message> previousMessages,
ChatCompletionResponse chatCompletionResponse) {

AnthropicMessage anthropicAssistantMessage = new AnthropicMessage(chatCompletionResponse.content(),
Role.ASSISTANT);

List<ContentBlock> toolToUseList = anthropicAssistantMessage.content()
.stream()
.filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE)
.toList();

List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();

for (ContentBlock toolToUse : toolToUseList) {

var functionCallId = toolToUse.id();
var functionName = toolToUse.name();
var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input());

toolCalls.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
}

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage);

// History
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
toolCallMessageConversation.add(assistantMessage);
toolCallMessageConversation.add(toolResponseMessage);

return toolCallMessageConversation;
}

private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
if (chatCompletion == null) {
logger.warn("Null chat completion returned");
Expand Down Expand Up @@ -203,18 +250,45 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() != MessageType.SYSTEM)
.map(m -> {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(m.getContent())));
if (!CollectionUtils.isEmpty(m.getMedia())) {
List<ContentBlock> mediaContent = m.getMedia()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getContent())));
if (!CollectionUtils.isEmpty(message.getMedia())) {
List<ContentBlock> mediaContent = message.getMedia()
.stream()
.map(media -> new ContentBlock(media.getMimeType().toString(),
this.fromMediaData(media.getData())))
.toList();
contents.addAll(mediaContent);
}
return new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name()));
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getContent())) {
contentBlocks.add(new ContentBlock(message.getContent()));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
contentBlocks.add(new ContentBlock(ContentBlockType.TOOL_USE, toolCall.id(),
toolCall.name(), ModelOptionsUtils.jsonToMap(toolCall.arguments())));
}
}
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
.stream()
.map(media -> new ContentBlock(media.getMimeType().toString(),
this.fromMediaData(media.getData())))
.map(toolResponse -> new ContentBlock(ContentBlockType.TOOL_RESULT, toolResponse.id(),
toolResponse.responseData()))
.toList();
contents.addAll(mediaContent);
return new AnthropicMessage(toolResponses, Role.USER);
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
return new AnthropicMessage(contents, Role.valueOf(m.getMessageType().name()));
})
.toList();

Expand Down Expand Up @@ -265,74 +339,17 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
}).toList();
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
AnthropicMessage responseMessage, List<AnthropicMessage> conversationHistory) {

List<ContentBlock> toolToUseList = responseMessage.content()
.stream()
.filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE)
.toList();

List<ContentBlock> toolResults = new ArrayList<>();

for (ContentBlock toolToUse : toolToUseList) {

var functionCallId = toolToUse.id();
var functionName = toolToUse.name();
var functionArguments = toolToUse.input();

if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
}

String functionResponse = this.functionCallbackRegister.get(functionName)
.call(ModelOptionsUtils.toJsonString(functionArguments));

toolResults.add(new ContentBlock(ContentBlockType.TOOL_RESULT, functionCallId, functionResponse));
}

// Add the function response to the conversation.
conversationHistory.add(new AnthropicMessage(toolResults, Role.USER));

// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
}

@Override
protected List<AnthropicMessage> doGetUserMessages(ChatCompletionRequest request) {
return request.messages();
}

@Override
protected AnthropicMessage doGetToolResponseMessage(ResponseEntity<ChatCompletionResponse> response) {
return new AnthropicMessage(response.getBody().content(), Role.ASSISTANT);
}

@Override
protected ResponseEntity<ChatCompletionResponse> doChatCompletion(ChatCompletionRequest request) {
return this.anthropicApi.chatCompletionEntity(request);
}

@SuppressWarnings("null")
@Override
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletionResponse> response) {
if (response == null || response.getBody() == null || CollectionUtils.isEmpty(response.getBody().content())) {
protected boolean isToolFunctionCall(ChatCompletionResponse response) {
if (response == null || CollectionUtils.isEmpty(response.content())) {
return false;
}
return response.getBody()
.content()
return response.content()
.stream()
.anyMatch(content -> content.type() == ContentBlock.ContentBlockType.TOOL_USE);
}

@Override
protected Flux<ResponseEntity<ChatCompletionResponse>> doChatCompletionStream(ChatCompletionRequest request) {

return this.anthropicApi.chatCompletionStream(request).map(Optional::ofNullable).map(ResponseEntity::of);
}

@Override
public ChatOptions getDefaultOptions() {
return AnthropicChatOptions.fromOptions(this.defaultOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
*/
package org.springframework.ai.openai;

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -49,17 +57,10 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
* backed by {@link OpenAiApi}.
Expand Down Expand Up @@ -265,12 +266,12 @@ private List<Message> handleToolCallRequests(List<Message> previousMessages, Cha
AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.content(), Map.of(),
assistantToolCalls);

List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage);
ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage);

// History
List<Message> messages = new ArrayList<>(previousMessages);
messages.add(assistantMessage);
messages.addAll(toolResponseMessages);
messages.add(toolResponseMessage);

return messages;
}
Expand Down Expand Up @@ -320,8 +321,8 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
content = contentList;
}

return new ChatCompletionMessage(content,
ChatCompletionMessage.Role.valueOf(message.getMessageType().name()));
return List.of(new ChatCompletionMessage(content,
ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
var assistantMessage = (AssistantMessage) message;
Expand All @@ -332,18 +333,27 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
return new ToolCall(toolCall.id(), toolCall.type(), function);
}).toList();
}
return new ChatCompletionMessage(assistantMessage.getContent(), ChatCompletionMessage.Role.ASSISTANT,
null, null, toolCalls);
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
return new ChatCompletionMessage(toolMessage.getContent(), ChatCompletionMessage.Role.TOOL,
toolMessage.getName(), toolMessage.getId(), null);

toolMessage.getResponses().forEach(response -> {
Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id");
Assert.isTrue(response.name() != null, "ToolResponseMessage must have a name");
});

return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null))
.toList();
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
}).toList();
}).flatMap(List::stream).toList();

ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = OpenAiChatModel3IT.Config.class)
@SpringBootTest(classes = OpenAiChatModelFunctionCallingIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiChatModel3IT {
class OpenAiChatModelFunctionCallingIT {

private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel3IT.class);
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelFunctionCallingIT.class);

@Autowired
ChatModel chatModel;
Expand All @@ -72,9 +72,7 @@ void functionCallTest() {

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
}

@Test
Expand Down Expand Up @@ -105,9 +103,7 @@ void streamFunctionCallTest() {
.collect(Collectors.joining());
logger.info("Response: {}", content);

assertThat(content).containsAnyOf("30.0", "30");
assertThat(content).containsAnyOf("10.0", "10");
assertThat(content).containsAnyOf("15.0", "15");
assertThat(content).contains("30", "10", "15");
}

@SpringBootConfiguration
Expand Down
Loading

0 comments on commit ee50a68

Please sign in to comment.