Skip to content

Commit

Permalink
Add Bedrock Meta Llama3 AI model support.
Browse files Browse the repository at this point in the history
 - re-enable Llama structured output tests
  • Loading branch information
wmz7year authored and tzolov committed Apr 25, 2024
1 parent b0add71 commit 9e865f0
Show file tree
Hide file tree
Showing 21 changed files with 244 additions and 230 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
* OpenAI
* Azure OpenAI
* Amazon Bedrock (Anthropic, Llama2, Cohere, Titan, Jurassic2)
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
* HuggingFace
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
Expand Down
2 changes: 1 addition & 1 deletion models/spring-ai-bedrock/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- [Anthropic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-anthropic.html)
- [Cohere Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-cohere.html)
- [Cohere Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-cohere-embedding.html)
- [Llama2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama2.html)
- [Llama Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama.html)
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
Expand Down Expand Up @@ -63,9 +63,9 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class))
hints.reflection().registerType(tr, mcs);

for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class))
for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlama2ChatOptions.class))
for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class))
hints.reflection().registerType(tr, mcs);

for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.llama2;
package org.springframework.ai.bedrock.llama;

import java.util.List;

import reactor.core.publisher.Flux;

import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
Expand All @@ -35,26 +35,27 @@
import org.springframework.util.Assert;

/**
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
* generative.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient {
public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {

private final Llama2ChatBedrockApi chatApi;
private final LlamaChatBedrockApi chatApi;

private final BedrockLlama2ChatOptions defaultOptions;
private final BedrockLlamaChatOptions defaultOptions;

public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) {
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
this(chatApi,
BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
}

public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) {
Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null");
Assert.notNull(options, "BedrockLlama2ChatOptions must not be null");
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
Assert.notNull(options, "BedrockLlamaChatOptions must not be null");

this.chatApi = chatApi;
this.defaultOptions = options;
Expand All @@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) {

var request = createRequest(prompt);

Llama2ChatResponse response = this.chatApi.chatCompletion(request);
LlamaChatResponse response = this.chatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata(
ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))));
Expand All @@ -76,7 +77,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {

var request = createRequest(prompt);

Flux<Llama2ChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);
Flux<LlamaChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);

return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason().name() : null;
Expand All @@ -85,7 +86,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

private Usage extractUsage(Llama2ChatResponse response) {
private Usage extractUsage(LlamaChatResponse response) {
return new Usage() {

@Override
Expand All @@ -103,22 +104,22 @@ public Long getGenerationTokens() {
/**
* Accessible for testing.
*/
Llama2ChatRequest createRequest(Prompt prompt) {
LlamaChatRequest createRequest(Prompt prompt) {

final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build();
LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build();

if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class);
request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class);
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, BedrockLlama2ChatOptions.class);
BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, BedrockLlamaChatOptions.class);

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.llama2;
package org.springframework.ai.bedrock.llama;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
Expand All @@ -26,7 +26,7 @@
* @author Christian Tzolov
*/
@JsonInclude(Include.NON_NULL)
public class BedrockLlama2ChatOptions implements ChatOptions {
public class BedrockLlamaChatOptions implements ChatOptions {

/**
* The temperature value controls the randomness of the generated text. Use a lower
Expand All @@ -51,7 +51,7 @@ public static Builder builder() {

public static class Builder {

private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions();
private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions();

public Builder withTemperature(Float temperature) {
this.options.setTemperature(temperature);
Expand All @@ -68,7 +68,7 @@ public Builder withMaxGenLen(Integer maxGenLen) {
return this;
}

public BedrockLlama2ChatOptions build() {
public BedrockLlamaChatOptions build() {
return this.options;
}

Expand Down
Loading

0 comments on commit 9e865f0

Please sign in to comment.