Skip to content

Commit

Permalink
Add VertexAI Embedding Model support
Browse files Browse the repository at this point in the history
 - add new spring-ai-vertex-ai-embedding project.
 - add VertexAiTextEmbeddingModel and VertexAiMultimodalEmbeddingMode with related options configuration classes.
 - add ITs
 - add auto-configuraiton and boot starters.
 - register to BOM.
 - add documentation.
 - add multimodal embedding documentation
 - extend the Embedding metdata so that it can keep references to the source document's data, Id, mediatype

 Resolves #1013
 Related to #1009
  • Loading branch information
tzolov authored and markpollack committed Jul 10, 2024
1 parent 40df264 commit c70c20b
Show file tree
Hide file tree
Showing 61 changed files with 3,432 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import java.util.function.Predicate;

import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -126,7 +126,7 @@ public AnthropicApi(String baseUrl, String anthropicApiKey, String anthropicVers
* "https://docs.anthropic.com/claude/docs/models-overview#model-comparison">model
* comparison</a> for additional details and options.
*/
public enum ChatModel implements ModelDescription {
public enum ChatModel implements ChatModelDescription {

// @formatter:off
CLAUDE_3_5_SONNET("claude-3-5-sonnet-20240620"),
Expand All @@ -153,7 +153,7 @@ public String getValue() {
}

@Override
public String getModelName() {
public String getName() {
return this.value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getModelName())
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription(
Expand All @@ -257,7 +257,7 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName())
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription(
Expand All @@ -280,7 +280,7 @@ void streamFunctionCallTest() {

@Test
void validateCallResponseMetadata() {
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getModelName();
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(AnthropicChatOptions.builder().withModel(model).build())
Expand All @@ -295,7 +295,7 @@ void validateCallResponseMetadata() {

@Test
void validateStreamCallResponseMetadata() {
String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName();
String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(AnthropicChatOptions.builder().withModel(model).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -227,7 +227,7 @@ public record AnthropicChatResponse(
/**
* Anthropic models version.
*/
public enum AnthropicChatModel implements ModelDescription {
public enum AnthropicChatModel implements ChatModelDescription {
/**
* anthropic.claude-instant-v1
*/
Expand Down Expand Up @@ -255,7 +255,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
Expand Down Expand Up @@ -437,7 +437,7 @@ public record Delta(@JsonProperty("type") String type, @JsonProperty("text") Str
/**
* Anthropic models version.
*/
public enum AnthropicChatModel implements ModelDescription {
public enum AnthropicChatModel implements ChatModelDescription {

/**
* anthropic.claude-instant-v1
Expand Down Expand Up @@ -482,7 +482,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -367,7 +367,7 @@ public enum FinishReason {
/**
* Cohere models version.
*/
public enum CohereChatModel implements ModelDescription {
public enum CohereChatModel implements ChatModelDescription {

/**
* cohere.command-light-text-v14
Expand All @@ -393,7 +393,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;

import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
Expand Down Expand Up @@ -372,7 +372,7 @@ public record FinishReason(
/**
* Ai21 Jurassic2 models version.
*/
public enum Ai21Jurassic2ChatModel implements ModelDescription {
public enum Ai21Jurassic2ChatModel implements ChatModelDescription {

/**
* ai21.j2-mid-v1
Expand All @@ -398,7 +398,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;

import java.time.Duration;

Expand Down Expand Up @@ -205,7 +205,7 @@ public enum StopReason {
/**
* Llama models version.
*/
public enum LlamaChatModel implements ModelDescription {
public enum LlamaChatModel implements ChatModelDescription {

/**
* meta.llama2-13b-chat-v1
Expand Down Expand Up @@ -241,7 +241,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;

/**
* Java client for the Bedrock Titan chat model.
Expand Down Expand Up @@ -266,7 +266,7 @@ public record TitanChatResponseChunk(
/**
* Titan models version.
*/
public enum TitanChatModel implements ModelDescription {
public enum TitanChatModel implements ChatModelDescription {

/**
* amazon.titan-text-lite-v1
Expand Down Expand Up @@ -297,7 +297,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.util.api.ApiUtils;
Expand Down Expand Up @@ -113,7 +113,7 @@ public MiniMaxApi(String baseUrl, String miniMaxToken, RestClient.Builder restCl
* MiniMax Chat Completion Models:
* <a href="https://www.minimaxi.com/document/algorithm-concept">MiniMax Model</a>.
*/
public enum ChatModel implements ModelDescription {
public enum ChatModel implements ChatModelDescription {
ABAB_6_5_Chat("abab6.5-chat"),
ABAB_6_5_S_Chat("abab6.5s-chat"),
ABAB_6_5_T_Chat("abab6.5t-chat"),
Expand All @@ -135,7 +135,7 @@ public String getValue() {
}

@Override
public String getModelName() {
public String getName() {
return this.value;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public Builder withModel(String model) {
}

public Builder withModel(MistralAiApi.ChatModel chatModel) {
this.options.setModel(chatModel.getModelName());
this.options.setModel(chatModel.getName());
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
Expand Down Expand Up @@ -701,7 +701,7 @@ public record ChunkChoice(
* 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral
* Medium, Mistral Large, and Mistral Embeddings).
*/
public enum ChatModel implements ModelDescription {
public enum ChatModel implements ChatModelDescription {

// @formatter:off
@Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B
Expand All @@ -728,7 +728,7 @@ public String getValue() {
}

@Override
public String getModelName() {
public String getName() {
return this.value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void streamFunctionCallTest() {

@Test
void validateCallResponseMetadata() {
String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getModelName();
String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(model).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
Expand Down Expand Up @@ -479,7 +479,7 @@ public record ChunkChoice(
* <li><b>MOONSHOT_V1_128K</b> - moonshot-v1-128k</li>
* </ul>
*/
public enum ChatModel implements ModelDescription {
public enum ChatModel implements ChatModelDescription {

// @formatter:off
MOONSHOT_V1_8K("moonshot-v1-8k"),
Expand All @@ -498,7 +498,7 @@ public String getValue() {
}

@Override
public String getModelName() {
public String getName() {
return this.value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
*/
package org.springframework.ai.ollama.api;

import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;

/**
* Helper class for common Ollama models.
*
* @author Siarhei Blashuk
* @since 0.8.1
*/
public enum OllamaModel implements ModelDescription {
public enum OllamaModel implements ChatModelDescription {

/**
* Llama 2 is a collection of language models ranging from 7B to 70B parameters.
Expand Down Expand Up @@ -102,7 +102,7 @@ public String id() {
}

@Override
public String getModelName() {
public String getName() {
return this.id;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public Builder withModel(String model) {
}

public Builder withModel(OpenAiApi.ChatModel openAiChatModel) {
this.options.model = openAiChatModel.getModelName();
this.options.model = openAiChatModel.getName();
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.retry.RetryUtils;
Expand Down Expand Up @@ -118,7 +118,7 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
* - <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a>
* - <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
*/
public enum ChatModel implements ModelDescription {
public enum ChatModel implements ChatModelDescription {
/**
* Multimodal flagship model that’s cheaper and faster than GPT-4 Turbo.
* Currently points to gpt-4o-2024-05-13.
Expand Down Expand Up @@ -221,7 +221,7 @@ public String getValue() {
}

@Override
public String getModelName() {
public String getName() {
return this.value;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ void streamingMultiModalityImageUrl() throws IOException {

@Test
void validateCallResponseMetadata() {
String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getModelName();
String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(OpenAiChatOptions.builder().withModel(model).build())
Expand Down
Loading

0 comments on commit c70c20b

Please sign in to comment.