Skip to content

Commit

Permalink
Add Bedrock Anthropic Claude 3 models support
Browse files Browse the repository at this point in the history
 - Create a low-level Anthropic Messages API for Claude 3.
   Use the new Messages API.
 - Add Chat Client with tests.
 - Add Bedrock Anthropic 3 docs.
 - Add multimodality support + tests.
 - Add auto-configuration & tests.
 - Rename Anthropic to Anthropic3 in class names to avoid confusion with the previous Bedrock Anthropic 2 impl.
  • Loading branch information
ben-gineer authored and tzolov committed Mar 19, 2024
1 parent 7634c6b commit e004751
Show file tree
Hide file tree
Showing 20 changed files with 1,839 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ public static class Builder {
private Integer topK;// = 10;
private Float topP;
private List<String> stopSequences;
// private String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;
private String anthropicVersion;

private Builder(String prompt) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.anthropic3;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.prompt.ChatOptions;

import java.util.List;

/**
 * Mutable set of options for Bedrock Anthropic Claude 3 chat requests. Every property is
 * optional; properties left {@code null} are omitted from the serialized JSON because of
 * {@link Include#NON_NULL}.
 *
 * <p>
 * Instances are usually created through {@link #builder()}.
 *
 * @author Ben Middleton
 * @since 1.0.0
 */
@JsonInclude(Include.NON_NULL)
public class Anthropic3ChatOptions implements ChatOptions {

    // @formatter:off
    /**
     * Sampling temperature, i.e. how random the output is. Values range over [0.0, 1.0], inclusive: values near
     * 1.0 yield more varied responses, values near 0.0 yield more predictable ones.
     */
    private @JsonProperty("temperature") Float temperature;

    /**
     * Upper bound on the number of tokens in the generated response. The model may stop before reaching it; this
     * is only the absolute maximum. A limit of 4,000 tokens is recommended for optimal performance.
     */
    private @JsonProperty("max_tokens") Integer maxTokens;

    /**
     * Number of token choices the model considers when generating the next token.
     */
    private @JsonProperty("top_k") Integer topK;

    /**
     * Maximum cumulative probability of the tokens considered when sampling (nucleus sampling): the smallest set
     * of tokens whose probabilities sum to at least topP is used. Combined with top-k by the model.
     */
    private @JsonProperty("top_p") Float topP;

    /**
     * Up to four sequences that make the model stop generating. The returned text does not contain the stop
     * sequence itself.
     */
    private @JsonProperty("stop_sequences") List<String> stopSequences;

    /**
     * Anthropic API version to request, e.g. bedrock-2023-05-31. This class itself applies no default.
     */
    private @JsonProperty("anthropic_version") String anthropicVersion;
    // @formatter:on

    /**
     * Starts a fluent builder for a new options instance.
     * @return a fresh {@link Builder}.
     */
    public static Builder builder() {
        return new Builder();
    }

    @Override
    public Float getTemperature() {
        return this.temperature;
    }

    public void setTemperature(Float temperature) {
        this.temperature = temperature;
    }

    public Integer getMaxTokens() {
        return this.maxTokens;
    }

    public void setMaxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
    }

    @Override
    public Integer getTopK() {
        return this.topK;
    }

    public void setTopK(Integer topK) {
        this.topK = topK;
    }

    @Override
    public Float getTopP() {
        return this.topP;
    }

    public void setTopP(Float topP) {
        this.topP = topP;
    }

    public List<String> getStopSequences() {
        return this.stopSequences;
    }

    public void setStopSequences(List<String> stopSequences) {
        this.stopSequences = stopSequences;
    }

    public String getAnthropicVersion() {
        return this.anthropicVersion;
    }

    public void setAnthropicVersion(String anthropicVersion) {
        this.anthropicVersion = anthropicVersion;
    }

    /**
     * Fluent builder for {@link Anthropic3ChatOptions}. The builder populates a single
     * options instance and {@link #build()} hands out that same instance.
     */
    public static class Builder {

        private final Anthropic3ChatOptions options = new Anthropic3ChatOptions();

        public Builder withTemperature(Float temperature) {
            this.options.temperature = temperature;
            return this;
        }

        public Builder withMaxTokens(Integer maxTokens) {
            this.options.maxTokens = maxTokens;
            return this;
        }

        public Builder withTopK(Integer topK) {
            this.options.topK = topK;
            return this;
        }

        public Builder withTopP(Float topP) {
            this.options.topP = topP;
            return this;
        }

        public Builder withStopSequences(List<String> stopSequences) {
            this.options.stopSequences = stopSequences;
            return this;
        }

        public Builder withAnthropicVersion(String anthropicVersion) {
            this.options.anthropicVersion = anthropicVersion;
            return this;
        }

        /**
         * Returns the options instance configured so far.
         * @return the configured {@link Anthropic3ChatOptions}.
         */
        public Anthropic3ChatOptions build() {
            return this.options;
        }

    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.anthropic3;

import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.CollectionUtils;

import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

/**
 * {@link ChatClient} and {@link StreamingChatClient} implementation for the Bedrock
 * Anthropic Claude 3 chat model, built on top of {@link Anthropic3ChatBedrockApi}.
 *
 * <p>
 * System messages of a {@link Prompt} are joined into the request's system context, while
 * user and assistant messages (including any attached media) become the request's message
 * list.
 *
 * @author Ben Middleton
 * @author Christian Tzolov
 * @since 1.0.0
 */
public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient {

    private final Anthropic3ChatBedrockApi anthropicChatApi;

    private final Anthropic3ChatOptions defaultOptions;

    /**
     * Creates a client with built-in default options: temperature 0.8, max tokens 500,
     * top-k 10 and {@link Anthropic3ChatBedrockApi#DEFAULT_ANTHROPIC_VERSION}.
     * @param chatApi The low-level Bedrock Anthropic 3 API used to perform the calls.
     */
    public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi) {
        this(chatApi,
                Anthropic3ChatOptions.builder()
                    .withTemperature(0.8f)
                    .withMaxTokens(500)
                    .withTopK(10)
                    .withAnthropicVersion(Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION)
                    .build());
    }

    /**
     * Creates a client with caller supplied default options.
     * @param chatApi The low-level Bedrock Anthropic 3 API used to perform the calls.
     * @param options Default options merged into every request; may be {@code null}, in
     * which case no defaults are applied.
     */
    public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
        this.anthropicChatApi = chatApi;
        this.defaultOptions = options;
    }

    /**
     * Performs a blocking chat completion.
     * @param prompt The prompt to send.
     * @return A response with a single {@link Generation} holding the text of the first
     * content block of the API response.
     */
    @Override
    public ChatResponse call(Prompt prompt) {

        AnthropicChatRequest request = createRequest(prompt);

        AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);

        // Only the first content block is surfaced to the caller.
        return new ChatResponse(List.of(new Generation(response.content().get(0).text())));
    }

    /**
     * Performs a streaming chat completion. Each streaming event is mapped to one
     * {@link ChatResponse}: content-block deltas carry their text, every other event type
     * yields an empty-text generation. The message-delta event additionally carries
     * generation metadata with the stop reason and token usage.
     * @param prompt The prompt to send.
     * @return The stream of responses.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {

        AnthropicChatRequest request = createRequest(prompt);

        Flux<Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse> fluxResponse = this.anthropicChatApi
            .chatCompletionStream(request);

        // The input token count is read from the message-start event and remembered so
        // that it can be combined with the output token count read from the message-delta
        // event.
        AtomicReference<Integer> inputTokens = new AtomicReference<>(0);
        return fluxResponse.map(response -> {
            if (response.type() == StreamingType.MESSAGE_START) {
                inputTokens.set(response.message().usage().inputTokens());
            }
            String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : "";

            var generation = new Generation(content);

            if (response.type() == StreamingType.MESSAGE_DELTA) {
                generation = generation.withGenerationMetadata(ChatGenerationMetadata
                    .from(response.delta().stopReason(), new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(),
                            response.usage().outputTokens())));
            }

            return new ChatResponse(List.of(generation));
        });
    }

    /**
     * Builds the API request for the given prompt: the messages and system context come
     * from the prompt, then the client's default options (if any) and finally the prompt's
     * runtime options (if any) are merged in via {@link ModelOptionsUtils}.
     *
     * <p>
     * NOTE(review): the precedence between the merged sources is decided by
     * {@code ModelOptionsUtils.merge}; presumably non-null values of its first argument win
     * over the second - confirm against that utility.
     *
     * <p>
     * Accessible for testing.
     * @param prompt The prompt to convert.
     * @return The request to send to the Bedrock Anthropic 3 API.
     * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions}.
     */
    AnthropicChatRequest createRequest(Prompt prompt) {

        AnthropicChatRequest request = AnthropicChatRequest.builder(toAnthropicMessages(prompt))
            .withSystem(toAnthropicSystemContext(prompt))
            .build();

        if (this.defaultOptions != null) {
            request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class);
        }

        if (prompt.getOptions() != null) {
            if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
                // Convert generic chat options into the Anthropic specific options type
                // before merging them into the request.
                Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
                        ChatOptions.class, Anthropic3ChatOptions.class);
                request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);
            }
            else {
                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
                        + prompt.getOptions().getClass().getSimpleName());
            }
        }

        return request;
    }

    /**
     * Extracts the system context from the prompt by joining the content of all system
     * messages with the platform line separator.
     * @param prompt The prompt.
     * @return The system context; an empty string when the prompt has no system messages.
     */
    private String toAnthropicSystemContext(Prompt prompt) {

        return prompt.getInstructions()
            .stream()
            .filter(m -> m.getMessageType() == MessageType.SYSTEM)
            .map(Message::getContent)
            .collect(Collectors.joining(System.lineSeparator()));
    }

    /**
     * Converts the user and assistant messages of the prompt into API messages. Each
     * message yields a text content entry followed by one entry per attached media item.
     * Messages of any other type are ignored here.
     * @param prompt The prompt.
     * @return The list of {@link ChatCompletionMessage}.
     */
    private List<ChatCompletionMessage> toAnthropicMessages(Prompt prompt) {

        return prompt.getInstructions()
            .stream()
            .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
            .map(message -> {
                // Mutable copy, because media entries may be appended below.
                List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(message.getContent())));
                if (!CollectionUtils.isEmpty(message.getMedia())) {
                    List<MediaContent> mediaContent = message.getMedia()
                        .stream()
                        .map(media -> new MediaContent(media.getMimeType().toString(),
                                this.fromMediaData(media.getData())))
                        .toList();
                    contents.addAll(mediaContent);
                }
                // The message type name is mapped onto the API role of the same name.
                return new ChatCompletionMessage(contents, Role.valueOf(message.getMessageType().name()));
            })
            .toList();
    }

    /**
     * Converts a media payload into the string form placed in the request.
     * @param mediaData Either raw bytes, which are Base64 encoded, or a {@link String},
     * which is passed through unchanged.
     * @return The string representation of the media data.
     * @throws IllegalArgumentException if the data is {@code null} or of any other type.
     */
    private String fromMediaData(Object mediaData) {
        if (mediaData instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        else if (mediaData instanceof String text) {
            return text;
        }
        else {
            // Null-safe: a missing payload is reported as an unsupported type instead of
            // failing with a NullPointerException while building the error message.
            String typeName = (mediaData != null) ? mediaData.getClass().getSimpleName() : "null";
            throw new IllegalArgumentException("Unsupported media data type: " + typeName);
        }
    }

}
Loading

0 comments on commit e004751

Please sign in to comment.