-
Notifications
You must be signed in to change notification settings - Fork 917
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Bedrock Anthropic Claude 3 models support
- Created low-leverl anthropic messages API for Claude 3. Use the new message API. - Add Chat Client with tests. - Add bedrok anthropic 3 docs. - Add multibudality support + tests. - Add auto-configuraiton & tests. - Rename Athropic to Athropic3 in class names to avoid confusion with previous Bedrock Anthropic 2 impl.
- Loading branch information
1 parent
7634c6b
commit e004751
Showing
20 changed files
with
1,839 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
166 changes: 166 additions & 0 deletions
166
...edrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
/* | ||
* Copyright 2023 - 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.bedrock.anthropic3; | ||
|
||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.annotation.JsonInclude.Include; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import org.springframework.ai.chat.prompt.ChatOptions; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* @author Ben Middleton | ||
* @since 1.0.0 | ||
*/ | ||
@JsonInclude(Include.NON_NULL) | ||
public class Anthropic3ChatOptions implements ChatOptions { | ||
|
||
// @formatter:off | ||
/** | ||
* Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will | ||
* produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising | ||
* responses from the generative. This value specifies default to be used by the backend while making the call to | ||
* the generative. | ||
*/ | ||
private @JsonProperty("temperature") Float temperature; | ||
|
||
/** | ||
* Specify the maximum number of tokens to use in the generated response. Note that the models may stop before | ||
* reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We | ||
* recommend a limit of 4,000 tokens for optimal performance. | ||
*/ | ||
private @JsonProperty("max_tokens") Integer maxTokens; | ||
|
||
/** | ||
* Specify the number of token choices the generative uses to generate the next token. | ||
*/ | ||
private @JsonProperty("top_k") Integer topK; | ||
|
||
/** | ||
* The maximum cumulative probability of tokens to consider when sampling. The generative uses combined Top-k and | ||
* nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | ||
*/ | ||
private @JsonProperty("top_p") Float topP; | ||
|
||
/** | ||
* Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops | ||
* generating further tokens. The returned text doesn't contain the stop sequence. | ||
*/ | ||
private @JsonProperty("stop_sequences") List<String> stopSequences; | ||
|
||
/** | ||
* The version of the generative to use. The default value is bedrock-2023-05-31. | ||
*/ | ||
private @JsonProperty("anthropic_version") String anthropicVersion; | ||
// @formatter:on | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final Anthropic3ChatOptions options = new Anthropic3ChatOptions(); | ||
|
||
public Builder withTemperature(Float temperature) { | ||
this.options.setTemperature(temperature); | ||
return this; | ||
} | ||
|
||
public Builder withMaxTokens(Integer maxTokens) { | ||
this.options.setMaxTokens(maxTokens); | ||
return this; | ||
} | ||
|
||
public Builder withTopK(Integer topK) { | ||
this.options.setTopK(topK); | ||
return this; | ||
} | ||
|
||
public Builder withTopP(Float topP) { | ||
this.options.setTopP(topP); | ||
return this; | ||
} | ||
|
||
public Builder withStopSequences(List<String> stopSequences) { | ||
this.options.setStopSequences(stopSequences); | ||
return this; | ||
} | ||
|
||
public Builder withAnthropicVersion(String anthropicVersion) { | ||
this.options.setAnthropicVersion(anthropicVersion); | ||
return this; | ||
} | ||
|
||
public Anthropic3ChatOptions build() { | ||
return this.options; | ||
} | ||
|
||
} | ||
|
||
@Override | ||
public Float getTemperature() { | ||
return this.temperature; | ||
} | ||
|
||
public void setTemperature(Float temperature) { | ||
this.temperature = temperature; | ||
} | ||
|
||
public Integer getMaxTokens() { | ||
return this.maxTokens; | ||
} | ||
|
||
public void setMaxTokens(Integer maxTokens) { | ||
this.maxTokens = maxTokens; | ||
} | ||
|
||
@Override | ||
public Integer getTopK() { | ||
return this.topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
@Override | ||
public Float getTopP() { | ||
return this.topP; | ||
} | ||
|
||
public void setTopP(Float topP) { | ||
this.topP = topP; | ||
} | ||
|
||
public List<String> getStopSequences() { | ||
return this.stopSequences; | ||
} | ||
|
||
public void setStopSequences(List<String> stopSequences) { | ||
this.stopSequences = stopSequences; | ||
} | ||
|
||
public String getAnthropicVersion() { | ||
return this.anthropicVersion; | ||
} | ||
|
||
public void setAnthropicVersion(String anthropicVersion) { | ||
this.anthropicVersion = anthropicVersion; | ||
} | ||
|
||
} |
190 changes: 190 additions & 0 deletions
190
.../src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
/* | ||
* Copyright 2023 - 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.bedrock.anthropic3; | ||
|
||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; | ||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; | ||
import org.springframework.ai.chat.ChatClient; | ||
import org.springframework.ai.chat.ChatResponse; | ||
import org.springframework.ai.chat.Generation; | ||
import org.springframework.ai.chat.StreamingChatClient; | ||
import org.springframework.ai.chat.messages.Message; | ||
import org.springframework.ai.chat.messages.MessageType; | ||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata; | ||
import org.springframework.ai.chat.prompt.ChatOptions; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.model.ModelOptionsUtils; | ||
import org.springframework.util.CollectionUtils; | ||
|
||
import reactor.core.publisher.Flux; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Base64; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
import java.util.stream.Collectors; | ||
|
||
/** | ||
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat | ||
* generative. | ||
* | ||
* @author Ben Middleton | ||
* @author Christian Tzolov | ||
* @since 1.0.0 | ||
*/ | ||
public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient { | ||
|
||
private final Anthropic3ChatBedrockApi anthropicChatApi; | ||
|
||
private final Anthropic3ChatOptions defaultOptions; | ||
|
||
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi) { | ||
this(chatApi, | ||
Anthropic3ChatOptions.builder() | ||
.withTemperature(0.8f) | ||
.withMaxTokens(500) | ||
.withTopK(10) | ||
.withAnthropicVersion(Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) | ||
.build()); | ||
} | ||
|
||
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) { | ||
this.anthropicChatApi = chatApi; | ||
this.defaultOptions = options; | ||
} | ||
|
||
@Override | ||
public ChatResponse call(Prompt prompt) { | ||
|
||
AnthropicChatRequest request = createRequest(prompt); | ||
|
||
AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); | ||
|
||
return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); | ||
} | ||
|
||
@Override | ||
public Flux<ChatResponse> stream(Prompt prompt) { | ||
|
||
AnthropicChatRequest request = createRequest(prompt); | ||
|
||
Flux<Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse> fluxResponse = this.anthropicChatApi | ||
.chatCompletionStream(request); | ||
|
||
AtomicReference<Integer> inputTokens = new AtomicReference<>(0); | ||
return fluxResponse.map(response -> { | ||
if (response.type() == StreamingType.MESSAGE_START) { | ||
inputTokens.set(response.message().usage().inputTokens()); | ||
} | ||
String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : ""; | ||
|
||
var generation = new Generation(content); | ||
|
||
if (response.type() == StreamingType.MESSAGE_DELTA) { | ||
generation = generation.withGenerationMetadata(ChatGenerationMetadata | ||
.from(response.delta().stopReason(), new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), | ||
response.usage().outputTokens()))); | ||
} | ||
|
||
return new ChatResponse(List.of(generation)); | ||
}); | ||
} | ||
|
||
/** | ||
* Accessible for testing. | ||
*/ | ||
AnthropicChatRequest createRequest(Prompt prompt) { | ||
|
||
AnthropicChatRequest request = AnthropicChatRequest.builder(toAnthropicMessages(prompt)) | ||
.withSystem(toAnthropicSystemContext(prompt)) | ||
.build(); | ||
|
||
if (this.defaultOptions != null) { | ||
request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class); | ||
} | ||
|
||
if (prompt.getOptions() != null) { | ||
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { | ||
Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, | ||
ChatOptions.class, Anthropic3ChatOptions.class); | ||
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class); | ||
} | ||
else { | ||
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " | ||
+ prompt.getOptions().getClass().getSimpleName()); | ||
} | ||
} | ||
|
||
return request; | ||
} | ||
|
||
/** | ||
* Extracts system context from prompt. | ||
* @param prompt The prompt. | ||
* @return The system context. | ||
*/ | ||
private String toAnthropicSystemContext(Prompt prompt) { | ||
|
||
return prompt.getInstructions() | ||
.stream() | ||
.filter(m -> m.getMessageType() == MessageType.SYSTEM) | ||
.map(Message::getContent) | ||
.collect(Collectors.joining(System.lineSeparator())); | ||
} | ||
|
||
/** | ||
* Extracts list of messages from prompt. | ||
* @param prompt The prompt. | ||
* @return The list of {@link ChatCompletionMessage}. | ||
*/ | ||
private List<ChatCompletionMessage> toAnthropicMessages(Prompt prompt) { | ||
|
||
return prompt.getInstructions() | ||
.stream() | ||
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) | ||
.map(message -> { | ||
List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(message.getContent()))); | ||
if (!CollectionUtils.isEmpty(message.getMedia())) { | ||
List<MediaContent> mediaContent = message.getMedia() | ||
.stream() | ||
.map(media -> new MediaContent(media.getMimeType().toString(), | ||
this.fromMediaData(media.getData()))) | ||
.toList(); | ||
contents.addAll(mediaContent); | ||
} | ||
return new ChatCompletionMessage(contents, Role.valueOf(message.getMessageType().name())); | ||
}) | ||
.toList(); | ||
} | ||
|
||
private String fromMediaData(Object mediaData) { | ||
if (mediaData instanceof byte[] bytes) { | ||
return Base64.getEncoder().encodeToString(bytes); | ||
} | ||
else if (mediaData instanceof String text) { | ||
return text; | ||
} | ||
else { | ||
throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName()); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.