diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index f8d861bd3cb..8a668f465c7 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -24,7 +24,7 @@ import java.util.function.Predicate; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.HttpHeaders; @@ -126,7 +126,7 @@ public AnthropicApi(String baseUrl, String anthropicApiKey, String anthropicVers * "https://docs.anthropic.com/claude/docs/models-overview#model-comparison">model * comparison for additional details and options. */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { // @formatter:off CLAUDE_3_5_SONNET("claude-3-5-sonnet-20240620"), @@ -153,7 +153,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 130d42a6ed4..fc817d4d3bb 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -230,7 +230,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AnthropicChatOptions.builder() - .withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getModelName()) + .withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription( @@ -257,7 +257,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = AnthropicChatOptions.builder() - .withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName()) + .withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription( @@ -280,7 +280,7 @@ void streamFunctionCallTest() { @Test void validateCallResponseMetadata() { - String model = AnthropicApi.ChatModel.CLAUDE_2_1.getModelName(); + String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName(); // @formatter:off ChatResponse response = ChatClient.create(chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) @@ -295,7 +295,7 @@ void validateCallResponseMetadata() { @Test void validateStreamCallResponseMetadata() { - String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName(); + String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName(); // @formatter:off ChatResponse response = ChatClient.create(chatModel).prompt() .options(AnthropicChatOptions.builder().withModel(model).build()) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index cd85fb7a5e3..b6a738b5132 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -29,7 +29,7 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.util.Assert; /** @@ -227,7 +227,7 @@ public record AnthropicChatResponse( /** * Anthropic models version. */ - public enum AnthropicChatModel implements ModelDescription { + public enum AnthropicChatModel implements ChatModelDescription { /** * anthropic.claude-instant-v1 */ @@ -255,7 +255,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index e584353c206..b8407debed5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -23,7 +23,7 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse; import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.util.Assert; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; @@ -437,7 +437,7 @@ public record Delta(@JsonProperty("type") String type, @JsonProperty("text") Str /** * Anthropic models version. */ - public enum AnthropicChatModel implements ModelDescription { + public enum AnthropicChatModel implements ChatModelDescription { /** * anthropic.claude-instant-v1 @@ -482,7 +482,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index 766271b87c9..0d69de07b57 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -30,7 +30,7 @@ import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.util.Assert; /** @@ -367,7 +367,7 @@ public enum FinishReason { /** * Cohere models version. */ - public enum CohereChatModel implements ModelDescription { + public enum CohereChatModel implements ChatModelDescription { /** * cohere.command-light-text-v14 @@ -393,7 +393,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index fecf70fa4e6..06f5216f71e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -27,7 +27,7 @@ import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -372,7 +372,7 @@ public record FinishReason( /** * Ai21 Jurassic2 models version. */ - public enum Ai21Jurassic2ChatModel implements ModelDescription { + public enum Ai21Jurassic2ChatModel implements ChatModelDescription { /** * ai21.j2-mid-v1 @@ -398,7 +398,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java index 16af9735edf..2531e6c7d8e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java @@ -26,7 +26,7 @@ import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import java.time.Duration; @@ -205,7 +205,7 @@ public enum StopReason { /** * Llama models version. */ - public enum LlamaChatModel implements ModelDescription { + public enum LlamaChatModel implements ChatModelDescription { /** * meta.llama2-13b-chat-v1 @@ -241,7 +241,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index ce1842adf32..f7516ddc378 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -31,7 +31,7 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; /** * Java client for the Bedrock Titan chat model. @@ -266,7 +266,7 @@ public record TitanChatResponseChunk( /** * Titan models version. */ - public enum TitanChatModel implements ModelDescription { + public enum TitanChatModel implements ChatModelDescription { /** * amazon.titan-text-lite-v1 @@ -297,7 +297,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java index 3b7c98f0c3e..725bc96202d 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java @@ -19,7 +19,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.util.api.ApiUtils; @@ -113,7 +113,7 @@ public MiniMaxApi(String baseUrl, String miniMaxToken, RestClient.Builder restCl * MiniMax Chat Completion Models: * MiniMax Model. */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { ABAB_6_5_Chat("abab6.5-chat"), ABAB_6_5_S_Chat("abab6.5s-chat"), ABAB_6_5_T_Chat("abab6.5t-chat"), @@ -135,7 +135,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 68ae47349c0..10bda6c9535 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -141,7 +141,7 @@ public Builder withModel(String model) { } public Builder withModel(MistralAiApi.ChatModel chatModel) { - this.options.setModel(chatModel.getModelName()); + this.options.setModel(chatModel.getName()); return this; } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index d4c660e4d54..c193e856bd7 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -27,7 +27,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; @@ -701,7 +701,7 @@ public record ChunkChoice( * 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral * Medium, Mistral Large, and Mistral Embeddings). */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { // @formatter:off @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B @@ -728,7 +728,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 23cbb15c37b..8a526272a7c 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -251,7 +251,7 @@ void streamFunctionCallTest() { @Test void validateCallResponseMetadata() { - String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getModelName(); + String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName(); // @formatter:off ChatResponse response = ChatClient.create(chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(model).build()) diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index f0118b718c5..e1ecb4addd0 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -18,7 +18,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; @@ -479,7 +479,7 @@ public record ChunkChoice( *
  • MOONSHOT_V1_128K - moonshot-v1-128k
  • * */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { // @formatter:off MOONSHOT_V1_8K("moonshot-v1-8k"), @@ -498,7 +498,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 449bab64781..d9275066476 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -15,7 +15,7 @@ */ package org.springframework.ai.ollama.api; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; /** * Helper class for common Ollama models. @@ -23,7 +23,7 @@ * @author Siarhei Blashuk * @since 0.8.1 */ -public enum OllamaModel implements ModelDescription { +public enum OllamaModel implements ChatModelDescription { /** * Llama 2 is a collection of language models ranging from 7B to 70B parameters. @@ -102,7 +102,7 @@ public String id() { } @Override - public String getModelName() { + public String getName() { return this.id; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 76ab2af67e6..0fc7d958256 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -188,7 +188,7 @@ public Builder withModel(String model) { } public Builder withModel(OpenAiApi.ChatModel openAiChatModel) { - this.options.model = openAiChatModel.getModelName(); + this.options.model = openAiChatModel.getName(); return this; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 59199086732..386533d4116 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -26,7 +26,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; @@ -118,7 +118,7 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie * - GPT-4 and GPT-4 Turbo * - GPT-3.5 Turbo. */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { /** * Multimodal flagship model that’s cheaper and faster than GPT-4 Turbo. * Currently points to gpt-4o-2024-05-13. @@ -221,7 +221,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 39fabed0dae..5cedfc4329b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -346,7 +346,7 @@ void streamingMultiModalityImageUrl() throws IOException { @Test void validateCallResponseMetadata() { - String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getModelName(); + String model = OpenAiApi.ChatModel.GPT_3_5_TURBO.getName(); // @formatter:off ChatResponse response = ChatClient.create(chatModel).prompt() .options(OpenAiChatOptions.builder().withModel(model).build()) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 46852496308..dcb5e3fb4da 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -198,7 +198,7 @@ public OpenAiApi chatCompletionApi() { public OpenAiChatModel openAiClient(OpenAiApi openAiApi, FunctionCallbackContext functionCallbackContext) { return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder() - .withModel(ChatModel.GPT_4_TURBO.getModelName()) + .withModel(ChatModel.GPT_4_TURBO.getName()) .withTemperature(0.1f) .build(), functionCallbackContext, RetryUtils.DEFAULT_RETRY_TEMPLATE); diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml new file mode 100644 index 00000000000..8fdbc8afc08 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -0,0 +1,83 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-vertex-ai-embedding + jar + Spring AI Model - Vertex AI Embedding + Vertex AI Embedding models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + com.google.cloud + libraries-bom + ${com.google.cloud.version} + pom + import + + + + + + + + com.google.cloud + google-cloud-aiplatform + + + commons-logging + commons-logging + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework + spring-web + ${spring-framework.version} + + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java new file mode 100644 index 00000000000..a4653f248e9 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddigConnectionDetails.java @@ -0,0 +1,163 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding; + +import java.io.IOException; + +import org.springframework.util.StringUtils; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; + +/** + * VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex + * AI embedding service. It provides methods to access the project ID, location, + * publisher, and PredictionServiceSettings. + */ +public class VertexAiEmbeddigConnectionDetails { + + private static final String DEFAULT_LOCATION = "us-central1"; + + public static final String DEFAULT_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_PUBLISHER = "google"; + + private PredictionServiceSettings predictionServiceSettings; + + /** + * Your project ID. + */ + private final String projectId; + + /** + * A location is a region + * you can specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private final String location; + + private final String publisher; + + public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) { + this.projectId = projectId; + this.location = location; + this.publisher = publisher; + + try { + this.predictionServiceSettings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + /** + * The Vertex AI embedding endpoint. + */ + private String endpoint; + + /** + * Your project ID. + */ + private String projectId; + + /** + * A location is a + * region you can + * specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private String location; + + /** + * + */ + private String publisher; + + public Builder withApiEndpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public Builder withProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder withLocation(String location) { + this.location = location; + return this; + } + + public Builder withPublisher(String publisher) { + this.publisher = publisher; + return this; + } + + public VertexAiEmbeddigConnectionDetails build() { + if (!StringUtils.hasText(this.endpoint)) { + if (!StringUtils.hasText(this.location)) { + this.endpoint = DEFAULT_ENDPOINT; + this.location = DEFAULT_LOCATION; + } + else { + this.endpoint = this.location + DEFAULT_ENDPOINT_SUFFIX; + } + } + + if (!StringUtils.hasText(this.publisher)) { + this.publisher = DEFAULT_PUBLISHER; + } + + return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher); + } + + } + + public String getProjectId() { + return this.projectId; + } + + public String getLocation() { + return this.location; + } + + public String getPublisher() { + return this.publisher; + } + + public EndpointName getEndpointName(String modelName) { + return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, + modelName); + } + + public PredictionServiceSettings getPredictionServiceSettings() { + return this.predictionServiceSettings; + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java new file mode 100644 index 00000000000..b2122f7bc73 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java @@ -0,0 +1,439 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.vertexai.embedding; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; + +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.StringUtils; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; + +/** + * Utility class for constructing parameter objects for Vertex AI embedding requests. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public abstract class VertexAiEmbeddingUtils { + + ////////////////////////////////////////////////////// + // Text Only + ////////////////////////////////////////////////////// + public static class TextParametersBuilder { + + public Integer outputDimensionality; + + public Boolean autoTruncate; + + public static TextParametersBuilder of() { + return new TextParametersBuilder(); + } + + public TextParametersBuilder withOutputDimensionality(Integer outputDimensionality) { + Assert.notNull(outputDimensionality, "Output dimensionality must not be null"); + this.outputDimensionality = outputDimensionality; + return this; + } + + public TextParametersBuilder withAutoTruncate(Boolean autoTruncate) { + Assert.notNull(autoTruncate, "Auto truncate must not be null"); + this.autoTruncate = autoTruncate; + return this; + } + + public Struct build() { + Struct.Builder textParametersBuilder = Struct.newBuilder(); + + if (this.outputDimensionality != null) { + textParametersBuilder.putFields("outputDimensionality", valueOf(this.outputDimensionality)); + } + if (this.autoTruncate != null) { + textParametersBuilder.putFields("autoTruncate", valueOf(this.autoTruncate)); + } + return textParametersBuilder.build(); + } + + } + + public static class TextInstanceBuilder { + + public String content; + + public String taskType; + + public String title; + + public static TextInstanceBuilder of(String content) { + Assert.hasText(content, "Content must not be empty"); + var builder = new TextInstanceBuilder(); + builder.content = content; + return builder; + } + + public TextInstanceBuilder withTaskType(String taskType) { + Assert.hasText(taskType, "Task type must not be empty"); + this.taskType = taskType; + return this; + } + + public TextInstanceBuilder withTitle(String title) { + Assert.hasText(title, "Title must not be empty"); + this.title = title; + return this; + } + + public Struct build() { + Struct.Builder textBuilder = Struct.newBuilder(); + textBuilder.putFields("content", valueOf(this.content)); + if (StringUtils.hasText(this.taskType)) { + textBuilder.putFields("taskType", valueOf(this.taskType)); + } + if (StringUtils.hasText(this.title)) { + textBuilder.putFields("title", valueOf(this.title)); + } + return textBuilder.build(); + } + + } + + ////////////////////////////////////////////////////// + // Multimodality + ////////////////////////////////////////////////////// + public static class MultimodalInstanceBuilder { + + /** + * The text to generate embeddings for. + */ + private String text; + + /** + * The dimension of the embedding, included in the response. Only applies to text + * and image input. Accepted values: 128, 256, 512, or 1408. + */ + private Integer dimension; + + /** + * The image to generate embeddings for. + */ + private Struct image; + + /** + * The video segment to generate embeddings for. + */ + private Struct video; + + public static MultimodalInstanceBuilder of() { + return new MultimodalInstanceBuilder(); + } + + public MultimodalInstanceBuilder withText(String text) { + Assert.hasText(text, "Text must not be empty"); + this.text = text; + return this; + } + + public MultimodalInstanceBuilder withDimension(Integer dimension) { + Assert.isTrue(dimension == 128 || dimension == 256 || dimension == 512 || dimension == 1408, + "Invalid dimension value: " + dimension + ". Accepted values: 128, 256, 512, or 1408."); + this.dimension = dimension; + return this; + } + + public MultimodalInstanceBuilder withImage(Struct image) { + Assert.notNull(image, "Image must not be null"); + this.image = image; + return this; + } + + public MultimodalInstanceBuilder withVideo(Struct video) { + Assert.notNull(video, "Video must not be null"); + this.video = video; + return this; + } + + public Struct build() { + Struct.Builder builder = Struct.newBuilder(); + + if (this.text != null) { + builder.putFields("text", valueOf(this.text)); + } + if (this.dimension != null) { + Struct.Builder dimensionBuilder = Struct.newBuilder(); + dimensionBuilder.putFields("dimension", valueOf(this.dimension)); + builder.putFields("parameters", Value.newBuilder().setStructValue(dimensionBuilder.build()).build()); + } + if (this.image != null) { + builder.putFields("image", Value.newBuilder().setStructValue(this.image).build()); + } + if (this.video != null) { + builder.putFields("video", Value.newBuilder().setStructValue(this.video).build()); + } + + Assert.isTrue(builder.getFieldsCount() > 0, "At least one of the text, image or video must be set"); + + return builder.build(); + } + + } + + public static class ImageBuilder { + + /** + * Image bytes to be encoded in a base64 string. + */ + public byte[] imageBytes; + + /** + * The Cloud Storage location of the image to perform the embedding. One of + * bytesBase64Encoded or gcsUri. + */ + public String gcsUri; + + /** + * The MIME type of the content of the image. Supported values: image/jpeg and + * image/png. + */ + public MimeType mimeType; + + public static ImageBuilder of(MimeType mimeType) { + Assert.notNull(mimeType, "MimeType must not be null"); + var builder = new ImageBuilder(); + builder.mimeType = mimeType; + return builder; + } + + public ImageBuilder withImageData(Object imageData) { + Assert.notNull(imageData, "Image data must not be null"); + if (imageData instanceof byte[] bytes) { + return withImageBytes(bytes); + } + else if (imageData instanceof String uri) { + return withGcsUri(uri); + } + else { + throw new IllegalArgumentException("Unsupported image data type: " + imageData.getClass()); + } + } + + public ImageBuilder withImageBytes(byte[] imageBytes) { + Assert.notNull(imageBytes, "Image bytes must not be null"); + this.imageBytes = imageBytes; + return this; + } + + public ImageBuilder withGcsUri(String gcsUri) { + Assert.hasText(gcsUri, "GCS URI must not be empty"); + this.gcsUri = gcsUri; + return this; + } + + public Struct build() { + + Struct.Builder imageBuilder = Struct.newBuilder(); + + if (this.imageBytes != null) { + byte[] imageData = Base64.getEncoder().encode(this.imageBytes); + String encodedImage = new String(imageData, StandardCharsets.UTF_8); + imageBuilder.putFields("bytesBase64Encoded", valueOf(encodedImage)); + } + else if (this.gcsUri != null) { + imageBuilder.putFields("gcsUri", valueOf(this.gcsUri)); + } + if (this.mimeType != null) { + imageBuilder.putFields("mimeType", valueOf(this.mimeType.toString())); + } + + Assert.isTrue(imageBuilder.getFieldsCount() > 0, "At least one of the imageBytes or gcsUri must be set"); + + return imageBuilder.build(); + } + + } + + public static class VideoBuilder { + + /** + * Video bytes to be encoded in base64 string. One of videoBytes or gcsUri. + */ + public byte[] videoBytes; + + /** + * The Cloud Storage location of the video on which to perform the embedding. One + * of videoBytes or gcsUri. + */ + public String gcsUri; + + /** + * + */ + public MimeType mimeType; + + /** + * The start offset of the video segment in seconds. If not specified, it's + * calculated with max(0, endOffsetSec - 120). + */ + public Integer startOffsetSec; + + /** + * The end offset of the video segment in seconds. If not specified, it's + * calculated with min(video length, startOffSec + 120). If both startOffSec and + * endOffSec are specified, endOffsetSec is adjusted to min(startOffsetSec+120, + * endOffsetSec). + */ + public Integer endOffsetSec; + + /** + * The interval of the video the embedding will be generated. The minimum value + * for interval_sec is 4. If the interval is less than 4, an InvalidArgumentError + * is returned. There are no limitations on the maximum value of the interval. + * However, if the interval is larger than min(video length, 120s), it impacts the + * quality of the generated embeddings. Default value: 16. + */ + public Integer intervalSec; + + public static VideoBuilder of(MimeType mimeType) { + Assert.notNull(mimeType, "MimeType must not be null"); + var builder = new VideoBuilder(); + builder.mimeType = mimeType; + return builder; + } + + public VideoBuilder withVideoData(Object imageData) { + Assert.notNull(imageData, "Video data must not be null"); + if (imageData instanceof byte[] imageBytes) { + return withVideoBytes(imageBytes); + } + else if (imageData instanceof String uri) { + return withGcsUri(uri); + } + else { + throw new IllegalArgumentException("Unsupported image data type: " + imageData.getClass()); + } + } + + public VideoBuilder withVideoBytes(byte[] imageBytes) { + Assert.notNull(imageBytes, "Video bytes must not be null"); + this.videoBytes = imageBytes; + return this; + } + + public VideoBuilder withGcsUri(String gcsUri) { + Assert.hasText(gcsUri, "GCS URI must not be empty"); + this.gcsUri = gcsUri; + return this; + } + + public VideoBuilder withStartOffsetSec(Integer startOffsetSec) { + if (startOffsetSec != null) { + this.startOffsetSec = startOffsetSec; + } + return this; + } + + public VideoBuilder withEndOffsetSec(Integer endOffsetSec) { + if (endOffsetSec != null) { + this.endOffsetSec = endOffsetSec; + } + return this; + + } + + public VideoBuilder withIntervalSec(Integer intervalSec) { + if (intervalSec != null) { + this.intervalSec = intervalSec; + } + return this; + } + + public Struct build() { + + Struct.Builder videoBuilder = Struct.newBuilder(); + + if (this.videoBytes != null) { + byte[] imageData = Base64.getEncoder().encode(this.videoBytes); + String encodedImage = new String(imageData, StandardCharsets.UTF_8); + videoBuilder.putFields("bytesBase64Encoded", valueOf(encodedImage)); + } + else if (this.gcsUri != null) { + videoBuilder.putFields("gcsUri", valueOf(this.gcsUri)); + } + if (this.mimeType != null) { + videoBuilder.putFields("mimeType", valueOf(this.mimeType.toString())); + } + + Struct.Builder videoConfigBuilder = Struct.newBuilder(); + + if (this.startOffsetSec != null) { + videoConfigBuilder.putFields("startOffsetSec", valueOf(this.startOffsetSec)); + } + if (this.endOffsetSec != null) { + videoConfigBuilder.putFields("endOffsetSec", valueOf(this.endOffsetSec)); + } + if (this.intervalSec != null) { + videoConfigBuilder.putFields("intervalSec", valueOf(this.intervalSec)); + } + if (videoConfigBuilder.getFieldsCount() > 0) { + videoBuilder.putFields("videoSegmentConfig", + Value.newBuilder().setStructValue(videoConfigBuilder.build()).build()); + } + + Assert.isTrue(videoBuilder.getFieldsCount() > 0, "At least one of the videoBytes or gcsUri must be set"); + + return videoBuilder.build(); + } + + } + + public static Value valueOf(boolean n) { + return Value.newBuilder().setBoolValue(n).build(); + } + + public static Value valueOf(String s) { + return Value.newBuilder().setStringValue(s).build(); + } + + public static Value valueOf(int n) { + return Value.newBuilder().setNumberValue(n).build(); + } + + public static Value valueOf(Struct struct) { + return Value.newBuilder().setStructValue(struct).build(); + } + + // Convert a Json string to a protobuf.Value + public static Value jsonToValue(String json) throws InvalidProtocolBufferException { + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + public static List toVector(Value value) { + return value.getListValue() + .getValuesList() + .stream() + .map(Value::getNumberValue) + // .map(Double::floatValue) + .toList(); + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java new file mode 100644 index 00000000000..e4c3132ffa4 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.multimodal; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Value; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Media; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingModel; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.MultimodalInstanceBuilder; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.VideoBuilder; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Implementation of the Vertex AI Multimodal Embedding Model. Note: This implementation + * is not yet fully functional and is subject to change. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(VertexAiMultimodalEmbeddingModel.class); + + public final VertexAiMultimodalEmbeddingOptions defaultOptions; + + private static final MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); + + private static final MimeType IMAGE_MIME_TYPE = MimeTypeUtils.parseMimeType("image/*"); + + private static final MimeType VIDEO_MIME_TYPE = MimeTypeUtils.parseMimeType("video/*"); + + private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, + MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp")); + + private final VertexAiEmbeddigConnectionDetails connectionDetails; + + public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails, + VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) { + + Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null"); + this.defaultOptions = defaultEmbeddingOptions; + this.connectionDetails = connectionDetails; + } + + @Override + public EmbeddingResponse call(DocumentEmbeddingRequest request) { + + EmbeddingResponse finalResponse = new EmbeddingResponse(List.of()); + + // merge the runtime and default vertex ai options. + VertexAiMultimodalEmbeddingOptions mergedOptions = this.defaultOptions; + + if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { + var defaultOptionsCopy = VertexAiMultimodalEmbeddingOptions.builder().from(this.defaultOptions).build(); + mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, + VertexAiMultimodalEmbeddingOptions.class); + } + + // Create the Vertex AI Prediction Service client. + try (PredictionServiceClient client = PredictionServiceClient + .create(this.connectionDetails.getPredictionServiceSettings())) { + + EndpointName endpointName = this.connectionDetails.getEndpointName(mergedOptions.getModel()); + + for (Document document : request.getInstructions()) { + EmbeddingResponse singleDocResponse = this.doSingleDocumentPrediction(client, endpointName, document, + mergedOptions); + var mergedEmbeddings = new ArrayList<>(finalResponse.getResults()); + mergedEmbeddings.addAll(singleDocResponse.getResults()); + finalResponse = new EmbeddingResponse(mergedEmbeddings, singleDocResponse.getMetadata()); + } + + } + catch (Exception e) { + throw new RuntimeException(e); + } + + return finalResponse; + } + + record DocumentMetadata(String documentId, MimeType mimeType, Object data) { + } + + private EmbeddingResponse doSingleDocumentPrediction(PredictionServiceClient client, EndpointName endpointName, + Document document, VertexAiMultimodalEmbeddingOptions mergedOptions) throws InvalidProtocolBufferException { + + var instanceBuilder = MultimodalInstanceBuilder.of(); + + Map documentMetadata = new EnumMap<>(ModalityType.class); + + // optional dimensions parameter + if (mergedOptions.getDimensions() != null) { + instanceBuilder.withDimension(mergedOptions.getDimensions()); + } + + // optional text parameter + if (StringUtils.hasText(document.getContent())) { + instanceBuilder.withText(document.getContent()); + documentMetadata.put(ModalityType.TEXT, + new DocumentMetadata(document.getId(), MimeTypeUtils.TEXT_PLAIN, document.getContent())); + } + + if (!CollectionUtils.isEmpty(document.getMedia())) { + + for (Media media : document.getMedia()) { + if (media.getMimeType().isCompatibleWith(TEXT_MIME_TYPE)) { + instanceBuilder.withText(media.getData().toString()); + documentMetadata.put(ModalityType.TEXT, + new DocumentMetadata(document.getId(), MimeTypeUtils.TEXT_PLAIN, media.getData())); + if (StringUtils.hasText(document.getContent())) { + logger.warn("Media type String overrides the Document text content!"); + } + } + else if (media.getMimeType().isCompatibleWith(IMAGE_MIME_TYPE)) { + if (SUPPORTED_IMAGE_MIME_SUB_TYPES.contains(media.getMimeType())) { + instanceBuilder + .withImage(ImageBuilder.of(media.getMimeType()).withImageData(media.getData()).build()); + documentMetadata.put(ModalityType.IMAGE, + new DocumentMetadata(document.getId(), media.getMimeType(), media.getData())); + } + else { + logger.warn("Unsupported image mime type: {}", media.getMimeType()); + throw new IllegalArgumentException("Unsupported image mime type: " + media.getMimeType()); + } + } + else if (media.getMimeType().isCompatibleWith(VIDEO_MIME_TYPE)) { + instanceBuilder.withVideo(VideoBuilder.of(media.getMimeType()) + .withVideoData(media.getData()) + .withStartOffsetSec(mergedOptions.getVideoStartOffsetSec()) + .withEndOffsetSec(mergedOptions.getVideoEndOffsetSec()) + .withIntervalSec(mergedOptions.getVideoIntervalSec()) + .build()); + documentMetadata.put(ModalityType.VIDEO, + new DocumentMetadata(document.getId(), media.getMimeType(), media.getData())); + } + else { + logger.warn("Unsupported media type: {}", media.getMimeType()); + throw new IllegalArgumentException("Unsupported media type: " + media.getMimeType()); + } + } + } + + List instances = List.of(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build())); + + PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .setParameters(VertexAiEmbeddingUtils.jsonToValue(ModelOptionsUtils.toJsonString(Map.of()))) + .addAllInstances(instances); + + PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build()); + + int index = 0; + List embeddingList = new ArrayList<>(); + for (Value prediction : embeddingResponse.getPredictionsList()) { + if (prediction.getStructValue().containsFields("textEmbedding")) { + Value textEmbedding = prediction.getStructValue().getFieldsOrThrow("textEmbedding"); + List textVector = VertexAiEmbeddingUtils.toVector(textEmbedding); + + var docMetadata = documentMetadata.get(ModalityType.TEXT); + embeddingList.add(new Embedding(textVector, index++, new EmbeddingResultMetadata(docMetadata.documentId, + ModalityType.TEXT, docMetadata.mimeType, docMetadata.data))); + } + if (prediction.getStructValue().containsFields("imageEmbedding")) { + Value imageEmbedding = prediction.getStructValue().getFieldsOrThrow("imageEmbedding"); + List imageVector = VertexAiEmbeddingUtils.toVector(imageEmbedding); + + var docMetadata = documentMetadata.get(ModalityType.IMAGE); + embeddingList + .add(new Embedding(imageVector, index++, new EmbeddingResultMetadata(docMetadata.documentId, + ModalityType.IMAGE, docMetadata.mimeType, docMetadata.data))); + } + if (prediction.getStructValue().containsFields("videoEmbeddings")) { + Value videoEmbeddings = prediction.getStructValue().getFieldsOrThrow("videoEmbeddings"); + if (videoEmbeddings.getListValue().getValues(0).getStructValue().containsFields("embedding")) { + Value embeddings = videoEmbeddings.getListValue() + .getValues(0) + .getStructValue() + .getFieldsOrThrow("embedding"); + List videoVector = VertexAiEmbeddingUtils.toVector(embeddings); + + var docMetadata = documentMetadata.get(ModalityType.VIDEO); + embeddingList + .add(new Embedding(videoVector, index++, new EmbeddingResultMetadata(docMetadata.documentId, + ModalityType.VIDEO, docMetadata.mimeType, docMetadata.data))); + } + } + } + + String deploymentModelId = embeddingResponse.getDeployedModelId(); + + EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), -1); + + responseMetadata.put("deployment-model-id", + StringUtils.hasText(deploymentModelId) ? deploymentModelId : "unknown"); + + return new EmbeddingResponse(embeddingList, generateResponseMetadata(mergedOptions.getModel(), 0)); + + } + + private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("model", model); + metadata.put("total-tokens", tokenCount); + return metadata; + } + + @Override + public int dimensions() { + return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), 768); + } + + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiMultimodalEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, + VertexAiMultimodalEmbeddingModelName::getDimensions)); + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java new file mode 100644 index 00000000000..5dc546b5f3c --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.multimodal; + +import org.springframework.ai.model.EmbeddingModelDescription; + +/** + * VertexAI Embedding Models: - Text + * embeddings - Multimodal + * embeddings + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public enum VertexAiMultimodalEmbeddingModelName implements EmbeddingModelDescription { + + /** + * Multimodal model.Expires on May 14, 2025. + */ + MULTIMODAL_EMBEDDING_001("multimodalembedding@001", "001", 1408, "Multimodal model"); + + private final String modelVersion; + + private final String modelName; + + private final String description; + + private final int dimensions; + + VertexAiMultimodalEmbeddingModelName(String value, String modelVersion, int dimensions, String description) { + this.modelName = value; + this.modelVersion = modelVersion; + this.dimensions = dimensions; + this.description = description; + } + + @Override + public String getName() { + return this.modelName; + } + + @Override + public String getVersion() { + return this.modelVersion; + } + + @Override + public int getDimensions() { + return this.dimensions; + } + + @Override + public String getDescription() { + return this.description; + } + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java new file mode 100644 index 00000000000..1e6bc8802c9 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java @@ -0,0 +1,211 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.multimodal; + +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Class representing the options for Vertex AI Multimodal Embedding. + * + *

    + * The options include the embedding model name, the number of dimensions of the resulting + * output, the start and end offset of the video segment, and the interval of the video + * for embedding generation. + *

    + * + *

    + * The supported embedding models are text-embedding-004, text-multilingual-embedding-002, + * and multimodalembedding@001. + *

    + * + *

    + * The number of dimensions is used to specify the size of the resulting output + * embeddings. This can be useful for storage optimization purposes. Supported for model + * version 004 and later. + *

    + * + *

    + * The video start offset and end offset specify the segment of the video to be used for + * embedding generation. If not specified, the default values are calculated based on the + * video length and are adjusted to ensure a minimum segment of 120 seconds. + *

    + * + *

    + * The video interval specifies the period of the video over which embeddings will be + * generated. The minimum value is 4, and if it is lower, an InvalidArgumentError is + * returned. There is no maximum limit for the interval value, but if it exceeds the video + * length or 120 seconds, it may impact the quality of the generated embeddings. The + * default value is 16. + *

    + * + * @author Christian Tzolov + * @since 1.0.0 + */ +@JsonInclude(Include.NON_NULL) +public class VertexAiMultimodalEmbeddingOptions implements EmbeddingOptions { + + public static final String DEFAULT_MODEL_NAME = VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001 + .getName(); + + // @formatter:off + /** + * The embedding model name to use. Supported models are: + * text-embedding-004, text-multilingual-embedding-002 and multimodalembedding@001. + */ + private @JsonProperty("model") String model; + + /** + * The number of dimensions the resulting output embeddings should have. + * Supported for model version 004 and later. You can use this parameter to reduce the + * embedding size, for example, for storage optimization. + */ + private @JsonProperty("dimensions") Integer dimensions; + + /** + * The start offset of the video segment in seconds. If not specified, it's calculated with max(0, endOffsetSec - 120). + */ + private @JsonProperty("videoStartOffsetSec") Integer videoStartOffsetSec; + + + /** + * The end offset of the video segment in seconds. If not specified, it's calculated with min(video length, startOffSec + 120). + * If both startOffSec and endOffSec are specified, endOffsetSec is adjusted to min(startOffsetSec+120, endOffsetSec). + */ + private @JsonProperty("videoEndOffsetSec") Integer videoEndOffsetSec; + + /** + * The interval of the video the embedding will be generated. The minimum value for interval_sec is 4. + * If the interval is less than 4, an InvalidArgumentError is returned. There are no limitations on the maximum value + * of the interval. However, if the interval is larger than min(video length, 120s), it impacts the quality of the + * generated embeddings. Default value: 16. + */ + private @JsonProperty("videoIntervalSec") Integer videoIntervalSec; + + + // @formatter:on + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected VertexAiMultimodalEmbeddingOptions options; + + public Builder() { + this.options = new VertexAiMultimodalEmbeddingOptions(); + } + + public Builder from(VertexAiMultimodalEmbeddingOptions fromOptions) { + if (fromOptions.getDimensions() != null) { + this.options.setDimensions(fromOptions.getDimensions()); + } + if (StringUtils.hasText(fromOptions.getModel())) { + this.options.setModel(fromOptions.getModel()); + } + if (fromOptions.getVideoStartOffsetSec() != null) { + this.options.setVideoStartOffsetSec(fromOptions.getVideoStartOffsetSec()); + } + if (fromOptions.getVideoEndOffsetSec() != null) { + this.options.setVideoEndOffsetSec(fromOptions.getVideoEndOffsetSec()); + } + if (fromOptions.getVideoIntervalSec() != null) { + this.options.setVideoIntervalSec(fromOptions.getVideoIntervalSec()); + } + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withModel(VertexAiMultimodalEmbeddingModelName model) { + this.options.setModel(model.getName()); + return this; + } + + public Builder withDimensions(Integer dimensions) { + this.options.setDimensions(dimensions); + return this; + } + + public Builder withVideoStartOffsetSec(Integer videoStartOffsetSec) { + this.options.setVideoStartOffsetSec(videoStartOffsetSec); + return this; + } + + public Builder withVideoEndOffsetSec(Integer videoEndOffsetSec) { + this.options.setVideoEndOffsetSec(videoEndOffsetSec); + return this; + } + + public Builder withVideoIntervalSec(Integer videoIntervalSec) { + this.options.setVideoIntervalSec(videoIntervalSec); + return this; + } + + public VertexAiMultimodalEmbeddingOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public Integer getVideoStartOffsetSec() { + return this.videoStartOffsetSec; + } + + public void setVideoStartOffsetSec(Integer videoStartOffsetSec) { + this.videoStartOffsetSec = videoStartOffsetSec; + } + + public Integer getVideoEndOffsetSec() { + return this.videoEndOffsetSec; + } + + public void setVideoEndOffsetSec(Integer videoEndOffsetSec) { + this.videoEndOffsetSec = videoEndOffsetSec; + } + + public Integer getVideoIntervalSec() { + return this.videoIntervalSec; + } + + public void setVideoIntervalSec(Integer videoIntervalSec) { + this.videoIntervalSec = videoIntervalSec; + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java new file mode 100644 index 00000000000..70364a4394c --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -0,0 +1,155 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.text; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.AbstractEmbeddingModel; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A class representing a Vertex AI Text Embedding Model. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { + + public final VertexAiTextEmbeddingOptions defaultOptions; + + private final VertexAiEmbeddigConnectionDetails connectionDetails; + + public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails, + VertexAiTextEmbeddingOptions defaultEmbeddingOptions) { + + Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null"); + + this.defaultOptions = defaultEmbeddingOptions.initializeDefaults(); + + this.connectionDetails = connectionDetails; + } + + @Override + public List embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent()); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions; + + if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) { + var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); + finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, + VertexAiTextEmbeddingOptions.class); + } + + try (PredictionServiceClient client = PredictionServiceClient + .create(this.connectionDetails.getPredictionServiceSettings())) { + + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + + PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()); + + TextParametersBuilder parametersBuilder = TextParametersBuilder.of(); + + if (finalOptions.getAutoTruncate() != null) { + parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate()); + } + + if (finalOptions.getDimensions() != null) { + parametersBuilder.withOutputDimensionality(finalOptions.getDimensions()); + } + + predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build())); + + for (int i = 0; i < request.getInstructions().size(); i++) { + + TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i)) + .withTaskType(finalOptions.getTaskType().name()); + if (StringUtils.hasText(finalOptions.getTitle())) { + instanceBuilder.withTitle(finalOptions.getTitle()); + } + predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build())); + } + + PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build()); + + int index = 0; + int totalTokenCount = 0; + List embeddingList = new ArrayList<>(); + for (Value prediction : embeddingResponse.getPredictionsList()) { + Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); + Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); + Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); + totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); + + Value values = embeddings.getStructValue().getFieldsOrThrow("values"); + + List vectorValues = VertexAiEmbeddingUtils.toVector(values); + + embeddingList.add(new Embedding(vectorValues, index++)); + } + return new EmbeddingResponse(embeddingList, + generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("model", model); + metadata.put("total-tokens", tokenCount); + return metadata; + } + + @Override + public int dimensions() { + return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); + } + + private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream + .of(VertexAiTextEmbeddingModelName.values()) + .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, + VertexAiTextEmbeddingModelName::getDimensions)); + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java new file mode 100644 index 00000000000..c49471d061c --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.text; + +import org.springframework.ai.model.EmbeddingModelDescription; + +/** + * VertexAI Embedding Models: - Text + * embeddings - Multimodal + * embeddings + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public enum VertexAiTextEmbeddingModelName implements EmbeddingModelDescription { + + /** + * English model. Expires on May 14, 2025. + */ + TEXT_EMBEDDING_004("text-embedding-004", "004", 768, "English text model"), + + /** + * Multilingual model. Expires on May 14, 2025. + */ + TEXT_MULTILINGUAL_EMBEDDING_002("text-multilingual-embedding-002", "002", 768, "Multilingual text model"); + + private final String modelVersion; + + private final String modelName; + + private final String description; + + private final int dimensions; + + VertexAiTextEmbeddingModelName(String value, String modelVersion, int dimensions, String description) { + this.modelName = value; + this.modelVersion = modelVersion; + this.dimensions = dimensions; + this.description = description; + } + + @Override + public String getName() { + return this.modelName; + } + + @Override + public String getVersion() { + return this.modelVersion; + } + + @Override + public int getDimensions() { + return this.dimensions; + } + + @Override + public String getDescription() { + return this.description; + } + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java new file mode 100644 index 00000000000..a7b2e14020c --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java @@ -0,0 +1,226 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.text; + +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.util.StringUtils; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +@JsonInclude(Include.NON_NULL) +public class VertexAiTextEmbeddingOptions implements EmbeddingOptions { + + public static final String DEFAULT_MODEL_NAME = VertexAiTextEmbeddingModelName.TEXT_EMBEDDING_004.getName(); + + public enum TaskType { + + /** + * Specifies the given text is a document in a search/retrieval setting. + */ + RETRIEVAL_QUERY, + + /** + * Specifies the given text is a query in a search/retrieval setting. + */ + RETRIEVAL_DOCUMENT, + + /** + * Specifies the given text will be used for semantic textual similarity (STS). + */ + SEMANTIC_SIMILARITY, + + /** + * Specifies that the embeddings will be used for classification. + */ + CLASSIFICATION, + + /** + * Specifies that the embeddings will be used for clustering. + */ + CLUSTERING, + + /** + * Specifies that the query embedding is used for answering questions. Use + * RETRIEVAL_DOCUMENT for the document side. + */ + QUESTION_ANSWERING, + + /** + * Specifies that the query embedding is used for fact verification. + */ + FACT_VERIFICATION + + } + + // @formatter:off + /** + * The embedding model name to use. Supported models are: + * text-embedding-004, text-multilingual-embedding-002 and multimodalembedding@001. + */ + private @JsonProperty("model") String model; + + /** + * The intended downstream application to help the model produce better quality embeddings. + * Not all model versions support all task types. + */ + private @JsonProperty("task") TaskType taskType; + + /** + * The number of dimensions the resulting output embeddings should have. + * Supported for model version 004 and later. You can use this parameter to reduce the + * embedding size, for example, for storage optimization. + */ + private @JsonProperty("dimensions") Integer dimensions; + + /** + * Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. + */ + private @JsonProperty("title") String title; + + + /** + * When set to true, input text will be truncated. When set to false, an error is returned + * if the input text is longer than the maximum length supported by the model. Defaults to true. + */ + private @JsonProperty("autoTruncate") Boolean autoTruncate; + + + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected VertexAiTextEmbeddingOptions options; + + public Builder() { + this.options = new VertexAiTextEmbeddingOptions(); + } + + public Builder from(VertexAiTextEmbeddingOptions fromOptions) { + if (fromOptions.getDimensions() != null) { + this.options.setDimensions(fromOptions.getDimensions()); + } + if (StringUtils.hasText(fromOptions.getModel())) { + this.options.setModel(fromOptions.getModel()); + } + if (fromOptions.getTaskType() != null) { + this.options.setTaskType(fromOptions.getTaskType()); + } + if (StringUtils.hasText(fromOptions.getTitle())) { + this.options.setTitle(fromOptions.getTitle()); + } + return this; + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withModel(VertexAiTextEmbeddingModelName model) { + this.options.setModel(model.getName()); + return this; + } + + public Builder withTaskType(TaskType taskType) { + this.options.setTaskType(taskType); + return this; + } + + public Builder withDimensions(Integer dimensions) { + this.options.dimensions = dimensions; + return this; + } + + public Builder withTitle(String user) { + this.options.setTitle(user); + return this; + } + + public Builder withAutoTruncate(Boolean autoTruncate) { + this.options.setAutoTruncate(autoTruncate); + return this; + } + + public VertexAiTextEmbeddingOptions build() { + return this.options; + } + + } + + public VertexAiTextEmbeddingOptions initializeDefaults() { + + if (this.getTaskType() == null) { + this.setTaskType(TaskType.RETRIEVAL_DOCUMENT); + } + + if (StringUtils.hasText(this.getTitle()) && this.getTaskType() != TaskType.RETRIEVAL_DOCUMENT) { + throw new IllegalArgumentException("Title is only valid with task_type=RETRIEVAL_DOCUMENT"); + } + + return this; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public TaskType getTaskType() { + return this.taskType; + } + + public void setTaskType(TaskType taskType) { + this.taskType = taskType; + } + + public Integer getDimensions() { + return this.dimensions; + } + + public void setDimensions(Integer dimensions) { + this.dimensions = dimensions; + } + + public String getTitle() { + return this.title; + } + + public void setTitle(String user) { + this.title = user; + } + + public Boolean getAutoTruncate() { + return this.autoTruncate; + } + + public void setAutoTruncate(Boolean autoTruncate) { + this.autoTruncate = autoTruncate; + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java new file mode 100644 index 00000000000..652f8abee21 --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodelEmbeddingModelIT.java @@ -0,0 +1,231 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.multimodal; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.messages.Media; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.ClassPathResource; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +@SpringBootTest(classes = VertexAiMultimodelEmbeddingModelIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +class VertexAiMultimodelEmbeddingModelIT { + + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api + + @Autowired + private VertexAiMultimodalEmbeddingModel multiModelEmbeddingModel; + + @Test + void multipleInstancesEmbedding() { + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(new Document("Hello World"), + new Document("Hello World2")); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(2); + + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.TEXT_PLAIN); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getDocumentId()) + .isEqualTo(embeddingRequest.getInstructions().get(0).getId()); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getResults().get(1).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(1).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.TEXT_PLAIN); + assertThat(embeddingResponse.getResults().get(1).getMetadata().getDocumentId()) + .isEqualTo(embeddingRequest.getInstructions().get(1).getId()); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @Test + void textContentEmbedding() { + + var document = new Document("Hello World"); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.TEXT_PLAIN); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @Test + void textMediaEmbedding() { + assertThat(multiModelEmbeddingModel).isNotNull(); + + var document = Document.builder().withMedia(new Media(MimeTypeUtils.TEXT_PLAIN, "Hello World")).build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.TEXT_PLAIN); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @Test + void imageEmbedding() { + + var document = Document.builder() + .withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png"))) + .build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.IMAGE); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(MimeTypeUtils.IMAGE_PNG); + + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @Test + void videoEmbedding() { + + var document = Document.builder() + .withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4"))) + .build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.VIDEO); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getMimeType()) + .isEqualTo(new MimeType("video", "mp4")); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @Test + void textImageAndVideoEmbedding() { + + var document = Document.builder() + .withContent("Hello World") + .withMedia(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png"))) + .withMedia(new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4"))) + .build(); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(document); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(3); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getResults().get(1)).isNotNull(); + assertThat(embeddingResponse.getResults().get(1).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.IMAGE); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getResults().get(2)).isNotNull(); + assertThat(embeddingResponse.getResults().get(2).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.VIDEO); + assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public VertexAiEmbeddigConnectionDetails connectionDetails() { + return VertexAiEmbeddigConnectionDetails.builder() + .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")) + .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION")) + .build(); + } + + @Bean + public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel( + VertexAiEmbeddigConnectionDetails connectionDetails) { + + VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder() + .withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001) + .build(); + + return new VertexAiMultimodalEmbeddingModel(connectionDetails, options); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java new file mode 100644 index 00000000000..43a6c86cc3b --- /dev/null +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.embedding.text; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +class VertexAiTextEmbeddingModelIT { + + // https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/textembedding-gecko?project=gen-lang-client-0587361272 + + @Autowired + private VertexAiTextEmbeddingModel embeddingModel; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "text-embedding-004", "text-multilingual-embedding-002" }) + void defaultEmbedding(String modelName) { + assertThat(embeddingModel).isNotNull(); + + var options = VertexAiTextEmbeddingOptions.builder().withModel(modelName).build(); + + EmbeddingResponse embeddingResponse = embeddingModel + .call(new EmbeddingRequest(List.of("Hello World", "World is Big"), options)); + + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768); + assertThat(embeddingResponse.getMetadata()).containsEntry("model", modelName); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 5); + + assertThat(embeddingModel.dimensions()).isEqualTo(768); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public VertexAiEmbeddigConnectionDetails connectionDetails() { + return VertexAiEmbeddigConnectionDetails.builder() + .withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")) + .withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION")) + .build(); + } + + @Bean + public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) { + + VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiTextEmbeddingModel(connectionDetails, options); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-embedding/src/test/resources/test.image.png b/models/spring-ai-vertex-ai-embedding/src/test/resources/test.image.png new file mode 100644 index 00000000000..8abb4c81aea Binary files /dev/null and b/models/spring-ai-vertex-ai-embedding/src/test/resources/test.image.png differ diff --git a/models/spring-ai-vertex-ai-embedding/src/test/resources/test.video.mp4 b/models/spring-ai-vertex-ai-embedding/src/test/resources/test.video.mp4 new file mode 100644 index 00000000000..543d1ab2846 Binary files /dev/null and b/models/spring-ai-vertex-ai-embedding/src/test/resources/test.video.mp4 differ diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 83d639ad43a..8cc1d732346 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -43,7 +43,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -98,7 +98,7 @@ public String getValue() { } - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { GEMINI_PRO_VISION("gemini-pro-vision"), @@ -119,7 +119,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } diff --git a/models/spring-ai-vertex-ai-palm2/README.md b/models/spring-ai-vertex-ai-palm2/README.md index 59a1315fb37..c414ff68cf3 100644 --- a/models/spring-ai-vertex-ai-palm2/README.md +++ b/models/spring-ai-vertex-ai-palm2/README.md @@ -1,4 +1,4 @@ [VertexAI PaLM2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/vertexai-palm2-chat.html) -[VertexAI PaLM2 Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/vertexai-embeddings.html) +[VertexAI PaLM2 Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/vertexai-embeddings-palm2.html) diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 143024e3a2a..d29e4e12f96 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -18,7 +18,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.util.api.ApiUtils; @@ -112,7 +112,7 @@ public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restCl * ZhiPuAI Chat Completion Models: * ZhiPuAI Model. */ - public enum ChatModel implements ModelDescription { + public enum ChatModel implements ChatModelDescription { GLM_4("GLM-4"), GLM_4V("glm-4v"), GLM_4_Air("glm-4-air"), @@ -131,7 +131,7 @@ public String getValue() { } @Override - public String getModelName() { + public String getName() { return this.value; } } diff --git a/pom.xml b/pom.xml index effd402053e..ce1bc39a47f 100644 --- a/pom.xml +++ b/pom.xml @@ -73,6 +73,7 @@ models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-gemini + models/spring-ai-vertex-ai-embedding models/spring-ai-vertex-ai-palm2 models/spring-ai-watsonx-ai models/spring-ai-zhipuai @@ -90,6 +91,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-qianfan spring-ai-spring-boot-starters/spring-ai-starter-stability-ai spring-ai-spring-boot-starters/spring-ai-starter-transformers + spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2 spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 281963bdfcc..c31495befef 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -105,6 +105,12 @@ spring-ai-vertex-ai-palm2 ${project.version} + + + org.springframework.ai + spring-ai-vertex-ai-embedding + ${project.version} + org.springframework.ai @@ -403,6 +409,12 @@ ${project.version} + + org.springframework.ai + spring-ai-vertex-ai-embedding-spring-boot-starter + ${project.version} + + org.springframework.ai spring-ai-vertex-ai-gemini-spring-boot-starter diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 268f7c2f4dd..391685d3e11 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -30,6 +30,7 @@ import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Content; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * A document is a container for the content and metadata of a document. It also contains @@ -40,6 +41,8 @@ public class Document implements Content { public final static ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig(); + public final static String EMPTY_TEXT = ""; + /** * Unique ID */ @@ -93,7 +96,7 @@ public Document(String id, String content, Map metadata) { public Document(String id, String content, List media, Map metadata) { Assert.hasText(id, "id must not be null or empty"); - Assert.hasText(content, "content must not be null or empty"); + Assert.notNull(content, "content must not be null"); Assert.notNull(metadata, "metadata must not be null"); this.id = id; @@ -102,6 +105,74 @@ public Document(String id, String content, List media, Map media = new ArrayList<>(); + + private Map metadata = new HashMap<>(); + + private IdGenerator idGenerator = new RandomIdGenerator(); + + public Builder withIdGenerator(IdGenerator idGenerator) { + Assert.notNull(idGenerator, "idGenerator must not be null"); + this.idGenerator = idGenerator; + return this; + } + + public Builder withId(String id) { + Assert.hasText(id, "id must not be null or empty"); + this.id = id; + return this; + } + + public Builder withContent(String content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + public Builder withMedia(List media) { + Assert.notNull(media, "media must not be null"); + this.media = media; + return this; + } + + public Builder withMedia(Media media) { + Assert.notNull(media, "media must not be null"); + this.media.add(media); + return this; + } + + public Builder withMetadata(Map metadata) { + Assert.notNull(metadata, "metadata must not be null"); + this.metadata = metadata; + return this; + } + + public Builder withMetadata(String key, Object value) { + Assert.notNull(key, "key must not be null"); + Assert.notNull(value, "value must not be null"); + this.metadata.put(key, value); + return this; + } + + public Document build() { + if (!StringUtils.hasText(this.id)) { + this.id = this.idGenerator.generateId(content, metadata); + } + return new Document(id, content, media, metadata); + } + + } + public String getId() { return id; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java new file mode 100644 index 00000000000..eb4a8354004 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingModel.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.embedding; + +import org.springframework.ai.model.Model; + +/** + * EmbeddingModel is a generic interface for embedding models. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public interface DocumentEmbeddingModel extends Model { + + @Override + EmbeddingResponse call(DocumentEmbeddingRequest request); + + int dimensions(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java new file mode 100644 index 00000000000..c68a418bed6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/DocumentEmbeddingRequest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.embedding; + +import java.util.Arrays; +import java.util.List; + +import org.springframework.ai.document.Document; +import org.springframework.ai.model.ModelRequest; + +/** + * Represents a request to embed a list of documents. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class DocumentEmbeddingRequest implements ModelRequest> { + + private final List inputs; + + private final EmbeddingOptions options; + + public DocumentEmbeddingRequest(Document... inputs) { + this(Arrays.asList(inputs), EmbeddingOptions.EMPTY); + } + + public DocumentEmbeddingRequest(List inputs) { + this(inputs, EmbeddingOptions.EMPTY); + } + + public DocumentEmbeddingRequest(List inputs, EmbeddingOptions options) { + this.inputs = inputs; + this.options = options; + } + + @Override + public List getInstructions() { + return this.inputs; + } + + @Override + public EmbeddingOptions getOptions() { + return this.options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java index 7a929bfa7b0..292a8906d8b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java @@ -37,8 +37,19 @@ public class Embedding implements ModelResult> { * @param index the embedding index in a list of embeddings. */ public Embedding(List embedding, Integer index) { + this(embedding, index, EmbeddingResultMetadata.EMPTY); + } + + /** + * Creates a new {@link Embedding} instance. + * @param embedding the embedding vector values. + * @param index the embedding index in a list of embeddings. + * @param metadata the metadata associated with the embedding. + */ + public Embedding(List embedding, Integer index, EmbeddingResultMetadata metadata) { this.embedding = embedding; this.index = index; + this.metadata = metadata; } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java index 005c0be7fe0..9b7df810b39 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java @@ -16,10 +16,104 @@ package org.springframework.ai.embedding; import org.springframework.ai.model.ResultMetadata; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; /** * @author Christian Tzolov */ public class EmbeddingResultMetadata implements ResultMetadata { + public static EmbeddingResultMetadata EMPTY = new EmbeddingResultMetadata(); + + public enum ModalityType { + + TEXT, IMAGE, AUDIO, VIDEO; + + } + + /** + * The {@link MimeType} of the source data used to generate the embedding. + */ + private final ModalityType modalityType; + + private final String documentId; + + private final MimeType mimeType; + + private final Object documentData; + + public EmbeddingResultMetadata() { + this("", ModalityType.TEXT, MimeTypeUtils.TEXT_PLAIN, null); + } + + public EmbeddingResultMetadata(String documentId, ModalityType modalityType, MimeType mimeType, + Object documentData) { + Assert.notNull(modalityType, "ModalityType must not be null"); + Assert.notNull(mimeType, "MimeType must not be null"); + + this.documentId = documentId; + this.modalityType = modalityType; + this.mimeType = mimeType; + this.documentData = documentData; + } + + public ModalityType getModalityType() { + return this.modalityType; + } + + public MimeType getMimeType() { + return this.mimeType; + } + + public String getDocumentId() { + return this.documentId; + } + + public Object getDocumentData() { + return this.documentData; + } + + public static class ModalityUtils { + + private static MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); + + private static MimeType IMAGE_MIME_TYPE = MimeTypeUtils.parseMimeType("text/*"); + + private static MimeType VIDEO_MIME_TYPE = MimeTypeUtils.parseMimeType("video/*"); + + private static MimeType AUDIO_MIME_TYPE = MimeTypeUtils.parseMimeType("audio/*"); + + /** + * Infers the {@link ModalityType} of the source data used to generate the + * embedding using the source data {@link MimeType}. + * @param mimeType the {@link MimeType} of the source data. + * @return Returns the {@link ModalityType} of the source data used to generate + * the embedding. + */ + public static ModalityType getModalityType(MimeType mimeType) { + + if (mimeType == null) { + return ModalityType.TEXT; + } + + if (mimeType.isCompatibleWith(IMAGE_MIME_TYPE)) { + return ModalityType.IMAGE; + } + else if (mimeType.isCompatibleWith(AUDIO_MIME_TYPE)) { + return ModalityType.AUDIO; + } + else if (mimeType.isCompatibleWith(VIDEO_MIME_TYPE)) { + return ModalityType.VIDEO; + } + else if (mimeType.isCompatibleWith(TEXT_MIME_TYPE)) { + return ModalityType.TEXT; + } + + throw new IllegalArgumentException("Unsupported MimeType: " + mimeType); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ChatModelDescription.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ChatModelDescription.java new file mode 100644 index 00000000000..67aded900a3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ChatModelDescription.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * Marker interface, to be used to store info on the model such as the current context + * length. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public interface ChatModelDescription extends ModelDescription { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java b/spring-ai-core/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java new file mode 100644 index 00000000000..c1d003b275e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +/** + * @author Christian Tzolov + */ +public interface EmbeddingModelDescription extends ModelDescription { + + default int getDimensions() { + return -1; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java index 57b083fabaf..0335f341eba 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelDescription.java @@ -17,11 +17,15 @@ package org.springframework.ai.model; /** + * Describes an AI model's basic characteristics. Provides methods to retrieve the model's + * name, description, and version. + * * @author Christian Tzolov + * @since 1.0.0 */ public interface ModelDescription { - String getModelName(); + String getName(); default String getDescription() { return ""; @@ -31,8 +35,4 @@ default String getVersion() { return ""; } - default int getContextLength() { - return -1; - } - } diff --git a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties index 849a8ceacf4..b6fb61c96f3 100644 --- a/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties +++ b/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties @@ -1,5 +1,4 @@ # Map of embedding generative names and their dimensions -# OpenAI text-embedding-ada-002=1536 text-similarity-ada-001=1024 text-similarity-babbage-001=2048 @@ -17,4 +16,7 @@ code-search-ada-code-001=1024 code-search-ada-text-001=1024 code-search-babbage-code-001=2048 code-search-babbage-text-001=2048 -sentence-transformers/all-MiniLM-L6-v2=384 \ No newline at end of file +sentence-transformers/all-MiniLM-L6-v2=384 +text-embedding-004=768 +text-multilingual-embedding-002=768 +multimodalembedding@001=768 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java new file mode 100644 index 00000000000..2916bbb70fd --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.document; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Media; +import org.springframework.ai.document.id.IdGenerator; +import org.springframework.util.MimeTypeUtils; + +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class DocumentBuilderTests { + + private Document.Builder builder; + + @BeforeEach + void setUp() { + builder = Document.builder(); + } + + @Test + void testWithIdGenerator() { + IdGenerator mockGenerator = new IdGenerator() { + @Override + public String generateId(Object... contents) { + return "mockedId"; + } + }; + + Document.Builder result = builder.withIdGenerator(mockGenerator); + + assertThat(result).isSameAs(builder); + + Document document = result.withContent("Test content").withMetadata("key", "value").build(); + + assertThat(document.getId()).isEqualTo("mockedId"); + } + + @Test + void testWithIdGeneratorNull() { + assertThatThrownBy(() -> builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("idGenerator must not be null"); + } + + @Test + void testWithId() { + Document.Builder result = builder.withId("testId"); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getId()).isEqualTo("testId"); + } + + @Test + void testWithIdNullOrEmpty() { + assertThatThrownBy(() -> builder.withId(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id must not be null or empty"); + + assertThatThrownBy(() -> builder.withId("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id must not be null or empty"); + } + + @Test + void testWithContent() { + Document.Builder result = builder.withContent("Test content"); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getContent()).isEqualTo("Test content"); + } + + @Test + void testWithContentNull() { + assertThatThrownBy(() -> builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("content must not be null"); + } + + @Test + void testWithMediaList() { + List mediaList = getMediaList(); + Document.Builder result = builder.withMedia(mediaList); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getMedia()).isEqualTo(mediaList); + } + + @Test + void testWithMediaListNull() { + assertThatThrownBy(() -> builder.withMedia((List) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media must not be null"); + } + + @Test + void testWithMediaSingle() throws MalformedURLException { + URL mediaUrl = new URL("http://test"); + Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); + + Document.Builder result = builder.withMedia(media); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getMedia()).contains(media); + } + + @Test + void testWithMediaSingleNull() { + assertThatThrownBy(() -> builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media must not be null"); + } + + @Test + void testWithMetadataMap() { + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", 2); + Document.Builder result = builder.withMetadata(metadata); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getMetadata()).isEqualTo(metadata); + } + + @Test + void testWithMetadataMapNull() { + assertThatThrownBy(() -> builder.withMetadata((Map) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata must not be null"); + } + + @Test + void testWithMetadataKeyValue() { + Document.Builder result = builder.withMetadata("key", "value"); + + assertThat(result).isSameAs(builder); + assertThat(result.build().getMetadata()).containsEntry("key", "value"); + } + + @Test + void testWithMetadataKeyValueNull() { + assertThatThrownBy(() -> builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("key must not be null"); + + assertThatThrownBy(() -> builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("value must not be null"); + } + + @Test + void testBuildWithoutId() { + Document document = builder.withContent("Test content").build(); + + assertThat(document.getId()).isNotNull().isNotEmpty(); + assertThat(document.getContent()).isEqualTo("Test content"); + } + + @Test + void testBuildWithAllProperties() throws MalformedURLException { + + List mediaList = getMediaList(); + Map metadata = new HashMap<>(); + metadata.put("key", "value"); + + Document document = builder.withId("customId") + .withContent("Test content") + .withMedia(mediaList) + .withMetadata(metadata) + .build(); + + assertThat(document.getId()).isEqualTo("customId"); + assertThat(document.getContent()).isEqualTo("Test content"); + assertThat(document.getMedia()).isEqualTo(mediaList); + assertThat(document.getMetadata()).isEqualTo(metadata); + } + + private static List getMediaList() { + try { + URL mediaUrl1 = new URL("http://type1"); + URL mediaUrl2 = new URL("http://type2"); + Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1); + Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2); + List mediaList = List.of(media1, media2); + return mediaList; + } + catch (MalformedURLException e) { + throw new RuntimeException(e); + } + + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 8197d43db12..b8ce948358f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -35,11 +35,14 @@ **** xref:api/chat/functions/zhipuai-chat-functions.adoc[Function Calling] *** xref:api/chat/watsonx-ai-chat.adoc[Watsonx.AI] ** xref:api/embeddings.adoc[] +*** xref:api/embeddings/openai-embeddings.adoc[OpenAI] +*** xref:api/embeddings/ollama-embeddings.adoc[Ollama] +*** xref:api/embeddings/azure-openai-embeddings.adoc[Azure OpenAI] +*** xref:api/embeddings/postgresml-embeddings.adoc[PostgresML] *** xref:api/bedrock.adoc[Amazon Bedrock] **** xref:api/embeddings/bedrock-cohere-embedding.adoc[Cohere] **** xref:api/embeddings/bedrock-titan-embedding.adoc[Titan] *** xref:api/embeddings/azure-openai-embeddings.adoc[Azure OpenAI] -*** xref:api/embeddings/vertexai-embeddings.adoc[Google VertexAI PaLM2] *** xref:api/embeddings/mistralai-embeddings.adoc[Mistral AI] *** xref:api/embeddings/minimax-embeddings.adoc[MiniMax] *** xref:api/embeddings/ollama-embeddings.adoc[Ollama] @@ -47,6 +50,10 @@ *** xref:api/embeddings/openai-embeddings.adoc[OpenAI] *** xref:api/embeddings/postgresml-embeddings.adoc[PostgresML] *** xref:api/embeddings/qianfan-embeddings.adoc[QianFan] +*** VertexAI +**** xref:api/embeddings/vertexai-embeddings-text.adoc[Text Embedding] +**** xref:api/embeddings/vertexai-embeddings-multimodal.adoc[Multimodal Embedding] +**** xref:api/embeddings/vertexai-embeddings-palm2.adoc[PaLM2 Embedding] *** xref:api/embeddings/zhipuai-embeddings.adoc[ZhiPu AI] ** xref:api/imageclient.adoc[] *** xref:api/image/azure-openai-image.adoc[Azure OpenAI] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings.adoc index 3f84b9c42c6..c98c246f87a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings.adoc @@ -1,6 +1,14 @@ [[EmbeddingModel]] = Embeddings Model API +Embeddings are numerical representations of text, images, or videos that capture relationships between inputs. + +Embeddings work by converting text, image, and video into arrays of floating point numbers, called vectors. +These vectors are designed to capture the meaning of the text, images, and videos. +The length of the embedding array is called the vector's dimensionality. + +By calculating the numerical distance between the vector representations of two pieces of text, an application can determine the similarity between the objects used to generate the embedding vectors. + The `EmbeddingModel` interface is designed for straightforward integration with embedding models in AI and machine learning. Its primary function is to convert text into numerical vectors, commonly referred to as embeddings. These embeddings are crucial for various tasks such as semantic analysis and text classification. @@ -157,5 +165,6 @@ Internally the various `EmbeddingModel` implementations use different low-level * xref:api/embeddings/postgresml-embeddings.adoc[Spring AI PostgresML Embeddings] * xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings] * xref:api/embeddings/bedrock-titan-embedding.adoc[Spring AI Bedrock Titan Embeddings] -* xref:api/embeddings/vertexai-embeddings.adoc[Spring AI VertexAI PaLM2 Embeddings] +* xref:api/embeddings/vertexai-embeddings-text.adoc[Spring AI VertexAI Embeddings] +* xref:api/embeddings/vertexai-embeddings-palm2.adoc[Spring AI VertexAI PaLM2 Embeddings] * xref:api/embeddings/mistralai-embeddings.adoc[Spring AI Mistral AI Embeddings] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc new file mode 100644 index 00000000000..854b16b6599 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-multimodal.adoc @@ -0,0 +1,147 @@ += Google VertexAI Multimodal Embeddings + +NOTE: EXPERIMENTAL. Used for experimental purposes only. Not compatible yet with the `VectorStores`. + +Vertex AI supports two types of embeddings models, text and multimodal. +This document describes how to create a multimodal embedding using the Vertex AI link:https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings[Multimodal embeddings API]. + +The multimodal embeddings model generates 1408-dimension vectors based on the input you provide, which can include a combination of image, text, and video data. +The embedding vectors can then be used for subsequent tasks like image classification or video content moderation. + +The image embedding vector and text embedding vector are in the same semantic space with the same dimensionality. +Consequently, these vectors can be used interchangeably for use cases like searching image by text, or searching video by image. + +NOTE: The VertexAI Multimodal API imposes the link:https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings#api-limits[following limits]. + +TIP: For text-only embedding use cases, we recommend using the xref:api/embeddings/vertexai-embeddings-text.adoc[Vertex AI text-embeddings model] instead. + +== Prerequisites + +- Install the link:https://cloud.google.com/sdk/docs/install[gcloud] CLI, appropriate for you OS. +- Authenticate by running the following command. +Replace `PROJECT_ID` with your Google Cloud project ID and `ACCOUNT` with your Google Cloud username. + +[source] +---- +gcloud config set project && +gcloud auth application-default login +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the VertexAI Embedding Model. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-vertex-ai-embedding-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-vertex-ai-embedding-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +The prefix `spring.ai.vertex.ai.embedding` is used as the property prefix that lets you connect to VertexAI Embedding API. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.vertex.ai.embedding.project-id | Google Cloud Platform project ID | - +| spring.ai.vertex.ai.embedding.location | Region | - +| spring.ai.vertex.ai.embedding.apiEndpoint | Vertex AI Embedding API endpoint. | - + +|==== + +The prefix `spring.ai.vertex.ai.embedding.multimodal` is the property prefix that lets you configure the embedding model implementation for VertexAI Multimodal Embedding. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.vertex.ai.embedding.multimodal.enabled | Enable Vertex AI Embedding API model. | true +| spring.ai.vertex.ai.embedding.multimodal.options.model | You can get multimodal embeddings by using the following model: | multimodalembedding@001 +| spring.ai.vertex.ai.embedding.multimodal.options.dimensions | Specify lower-dimension embeddings. By default, an embedding request returns a 1408 float vector for a data type. You can also specify lower-dimension embeddings (128, 256, or 512 float vectors) for text and image data. | 1408 +| spring.ai.vertex.ai.embedding.multimodal.options.video-start-offset-sec | The start offset of the video segment in seconds. If not specified, it's calculated with max(0, endOffsetSec - 120). | - +| spring.ai.vertex.ai.embedding.multimodal.options.video-end-offset-sec | The end offset of the video segment in seconds. If not specified, it's calculated with min(video length, startOffSec + 120). If both startOffSec and endOffSec are specified, endOffsetSec is adjusted to min(startOffsetSec+120, endOffsetSec). | - +| spring.ai.vertex.ai.embedding.multimodal.options.video-interval-sec | The interval of the video the embedding will be generated. The minimum value for interval_sec is 4. +If the interval is less than 4, an InvalidArgumentError is returned. There are no limitations on the maximum value +of the interval. However, if the interval is larger than min(video length, 120s), it impacts the quality of the generated embeddings. Default value: 16. | - +|==== + +== Manual Configuration + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiMultimodalEmbeddingModel.java[VertexAiMultimodalEmbeddingModel] implements the `DocumentEmbeddingModel`. + +Add the `spring-ai-vertex-ai-embedding` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-vertex-ai-embedding + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-vertex-ai-embedding' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a `VertexAiMultimodalEmbeddingModel` and use it for embeddings generations: + +[source,java] +---- +VertexAiEmbeddigConnectionDetails connectionDetails = + VertexAiEmbeddigConnectionDetails.builder() + .withProjectId(System.getenv()) + .withLocation(System.getenv()) + .build(); + +VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder() + .withModel(VertexAiMultimodalEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + +var embeddingModel = new VertexAiMultimodalEmbeddingModel(connectionDetails, options); + +Media imageMedial = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.image.png")); +Media videoMedial = new Media(new MimeType("video", "mp4"), new ClassPathResource("/test.video.mp4")); + +var document = new Document("Explain what do you see on this video?", List.of(imageMedial, videoMedial), Map.of()); + +EmbeddingResponse embeddingResponse = embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); + +DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document), + EmbeddingOptions.EMPTY); + +EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + +assertThat(embeddingResponse.getResults()).hasSize(3); +---- + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc similarity index 89% rename from spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings.adoc rename to spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc index 36b59bf3bbb..425baa1dca4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-palm2.adoc @@ -1,6 +1,12 @@ -= VertexAI Embeddings += Google VertexAI PaLM2 Embeddings -The link:https://developers.generativeai.google/api/rest/generativelanguage[Generative Language] PaLM API allows developers to build generative AI applications using the PaLM model. Large Language Models (LLMs) are a powerful, versatile type of machine learning model that enables computers to comprehend and generate natural language through a series of prompts. The PaLM API is based on Google's next generation LLM, PaLM. It excels at a variety of different tasks like code generation, reasoning, and writing. You can use the PaLM API to build generative AI applications for use cases like content generation, dialogue agents, summarization and classification systems, and more. +NOTE: For text-only embedding use cases, we recommend using the xref:api/embeddings/vertexai-embeddings-text.adoc[Vertex AI text-embeddings model] instead. + +The link:https://developers.generativeai.google/api/rest/generativelanguage[Generative Language] PaLM API allows developers to build generative AI applications using the PaLM model. +Large Language Models (LLMs) are a powerful, versatile type of machine learning model that enables computers to comprehend and generate natural language through a series of prompts. +The PaLM API is based on Google's next generation LLM, PaLM. +It excels at a variety of different tasks like code generation, reasoning, and writing. +You can use the PaLM API to build generative AI applications for use cases like content generation, dialogue agents, summarization and classification systems, and more. Based on the link:https://developers.generativeai.google/api/rest/generativelanguage/models[Models REST API]. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc new file mode 100644 index 00000000000..25394088479 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/vertexai-embeddings-text.adoc @@ -0,0 +1,162 @@ += Google VertexAI Text Embeddings + +Vertex AI supports two types of embeddings models, text and multimodal. +This document describes how to create a text embedding using the Vertex AI link:https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api[Text embeddings API]. + +Vertex AI text embeddings API uses dense vector representations. +Unlike sparse vectors, which tend to directly map words to numbers, dense vectors are designed to better represent the meaning of a piece of text. +The benefit of using dense vector embeddings in generative AI is that instead of searching for direct word or syntax matches, you can better search for passages that align to the meaning of the query, even if the passages don't use the same language. + +== Prerequisites + +- Install the link:https://cloud.google.com/sdk/docs/install[gcloud] CLI, appropriate for you OS. +- Authenticate by running the following command. +Replace `PROJECT_ID` with your Google Cloud project ID and `ACCOUNT` with your Google Cloud username. + +[source] +---- +gcloud config set project && +gcloud auth application-default login +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the VertexAI Embedding Model. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-vertex-ai-embedding-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-vertex-ai-embedding-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +The prefix `spring.ai.vertex.ai.embedding` is used as the property prefix that lets you connect to VertexAI Embedding API. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.vertex.ai.embedding.project-id | Google Cloud Platform project ID | - +| spring.ai.vertex.ai.embedding.location | Region | - +| spring.ai.vertex.ai.embedding.apiEndpoint | Vertex AI Embedding API endpoint. | - + +|==== + +The prefix `spring.ai.vertex.ai.embedding.text` is the property prefix that lets you configure the embedding model implementation for VertexAI Text Embedding. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.vertex.ai.embedding.text.enabled | Enable Vertex AI Embedding API model. | true +| spring.ai.vertex.ai.embedding.text.options.model | This is the link:https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models[Vertex Text Embedding model] to use | text-embedding-004 +| spring.ai.vertex.ai.embedding.text.options.task-type | The intended downstream application to help the model produce better quality embeddings. Available link:https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#request_body[task-types] | `RETRIEVAL_DOCUMENT` +| spring.ai.vertex.ai.embedding.text.options.title | Optional title, only valid with task_type=RETRIEVAL_DOCUMENT. | - +| spring.ai.vertex.ai.embedding.text.options.dimensions | The number of dimensions the resulting output embeddings should have. Supported for model version 004 and later. You can use this parameter to reduce the embedding size, for example, for storage optimization. | - +| spring.ai.vertex.ai.embedding.text.options.auto-truncate | When set to true, input text will be truncated. When set to false, an error is returned if the input text is longer than the maximum length supported by the model. | true +|==== + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-vertex-ai-embedding-spring-boot-starter` to your pom (or gradle) dependencies. + +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the VertexAi chat model: + +[source,application.properties] +---- +spring.ai.vertex.ai.embedding.project-id= +spring.ai.vertex.ai.embedding.location= +spring.ai.vertex.ai.embedding.text.options.model=text-embedding-004 +---- + + +This will create a `VertexAiTextEmbeddingModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the embedding model for embeddings generations. + +[source,java] +---- +@RestController +public class EmbeddingController { + + private final EmbeddingModel embeddingModel; + + @Autowired + public EmbeddingController(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @GetMapping("/ai/embedding") + public Map embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(message)); + return Map.of("embedding", embeddingResponse); + } +} +---- + +== Manual Configuration + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiTextEmbeddingModel.java[VertexAiTextEmbeddingModel] implements the `EmbeddingModel`. + +Add the `spring-ai-vertex-ai-embedding` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-vertex-ai-embedding + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-vertex-ai-embedding' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a `VertexAiTextEmbeddingModel` and use it for text generations: + +[source,java] +---- +VertexAiEmbeddigConnectionDetails connectionDetails = + VertexAiEmbeddigConnectionDetails.builder() + .withProjectId(System.getenv()) + .withLocation(System.getenv()) + .build(); + +VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + +var embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, options); + +EmbeddingResponse embeddingResponse = embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); +---- + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc index 712bad741bd..25ee96900d1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc @@ -70,8 +70,13 @@ Initially starting as simple strings, prompts have evolved to include multiple m == Embeddings -Embeddings transform text into numerical arrays or vectors, enabling AI models to process and interpret language data. -This transformation from text to numbers is a key element in how AI interacts with and understands human language. +Embeddings are numerical representations of text, images, or videos that capture relationships between inputs. + +Embeddings work by converting text, image, and video into arrays of floating point numbers, called vectors. +These vectors are designed to capture the meaning of the text, images, and videos. +The length of the embedding array is called the vector's dimensionality. + +By calculating the numerical distance between the vector representations of two pieces of text, an application can determine the similarity between the objects used to generate the embedding vectors. image::spring-ai-embeddings.jpg[Embeddings, width=900, align="center"] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc index fe1bd9b065f..8c521ca2786 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc @@ -136,7 +136,8 @@ Each of the following sections in the documentation shows which dependencies you ** xref:api/embeddings/postgresml-embeddings.adoc[Spring AI PostgresML Embeddings] ** xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings] ** xref:api/embeddings/bedrock-titan-embedding.adoc[Spring AI Bedrock Titan Embeddings] -** xref:api/embeddings/vertexai-embeddings.adoc[Spring AI VertexAI Embeddings] +** xref:api/embeddings/vertexai-embeddings-text.adoc[Spring AI VertexAI Embeddings] +** xref:api/embeddings/vertexai-embeddings-palm2.adoc[Spring AI VertexAI PaLM2 Embeddings] ** xref:api/embeddings/mistralai-embeddings.adoc[Spring AI MistralAI Embeddings] === Chat Models diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 00794303de1..5fb52ecd501 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -1,6 +1,7 @@ + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> 4.0.0 org.springframework.ai @@ -116,7 +117,7 @@ ${oracle.version} true - + com.oracle.database.jdbc ucp ${oracle.version} @@ -214,6 +215,14 @@ true + + + org.springframework.ai + spring-ai-vertex-ai-embedding + ${project.parent.version} + true + + org.springframework.ai @@ -500,6 +509,6 @@ test - + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java new file mode 100644 index 00000000000..70aa7ec58f4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingAutoConfiguration.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.embedding; + +import java.io.IOException; + +import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails; +import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel; +import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import com.google.cloud.vertexai.VertexAI; + +/** + * Auto-configuration for Vertex AI Gemini Chat. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +@ConditionalOnClass({ VertexAI.class, VertexAiGeminiChatModel.class }) +@EnableConfigurationProperties({ VertexAiEmbeddingConnectionProperties.class, VertexAiTextEmbeddingProperties.class, + VertexAiMultimodalEmbeddingProperties.class, }) +public class VertexAiEmbeddingAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public VertexAiEmbeddigConnectionDetails connectionDetails( + VertexAiEmbeddingConnectionProperties connectionProperties) { + + Assert.hasText(connectionProperties.getProjectId(), "Vertex AI project-id must be set!"); + Assert.hasText(connectionProperties.getLocation(), "Vertex AI location must be set!"); + + var connectionBuilder = VertexAiEmbeddigConnectionDetails.builder() + .withProjectId(connectionProperties.getProjectId()) + .withLocation(connectionProperties.getLocation()); + + if (StringUtils.hasText(connectionProperties.getApiEndpoint())) { + connectionBuilder.withApiEndpoint(connectionProperties.getApiEndpoint()); + } + + return connectionBuilder.build(); + + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = VertexAiTextEmbeddingProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public VertexAiTextEmbeddingModel textEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails, + VertexAiTextEmbeddingProperties textEmbeddingProperties) throws IOException { + + return new VertexAiTextEmbeddingModel(connectionDetails, textEmbeddingProperties.getOptions()); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = VertexAiMultimodalEmbeddingProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public VertexAiMultimodalEmbeddingModel multimodalEmbedding(VertexAiEmbeddigConnectionDetails connectionDetails, + VertexAiMultimodalEmbeddingProperties multimodalEmbeddingProperties) throws IOException { + + return new VertexAiMultimodalEmbeddingModel(connectionDetails, multimodalEmbeddingProperties.getOptions()); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java new file mode 100644 index 00000000000..0073f569046 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiEmbeddingConnectionProperties.java @@ -0,0 +1,84 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.embedding; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.core.io.Resource; + +/** + * Configuration properties for Vertex AI Embedding. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +@ConfigurationProperties(VertexAiEmbeddingConnectionProperties.CONFIG_PREFIX) +public class VertexAiEmbeddingConnectionProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.embedding"; + + /** + * Vertex AI Gemini project ID. + */ + private String projectId; + + /** + * Vertex AI Gemini location. + */ + private String location; + + /** + * URI to Vertex AI Gemini credentials (optional) + */ + private Resource credentialsUri; + + /** + * Vertex AI Gemini API endpoint. + */ + private String apiEndpoint; + + public String getProjectId() { + return this.projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getLocation() { + return this.location; + } + + public void setLocation(String location) { + this.location = location; + } + + public Resource getCredentialsUri() { + return this.credentialsUri; + } + + public void setCredentialsUri(Resource credentialsUri) { + this.credentialsUri = credentialsUri; + } + + public String getApiEndpoint() { + return this.apiEndpoint; + } + + public void setApiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java new file mode 100644 index 00000000000..6d08403f56e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiMultimodalEmbeddingProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.embedding; + +import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for Vertex AI Gemini Chat. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +@ConfigurationProperties(VertexAiMultimodalEmbeddingProperties.CONFIG_PREFIX) +public class VertexAiMultimodalEmbeddingProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.embedding.multimodal"; + + private boolean enabled = true; + + /** + * Vertex AI Text Embedding API options. + */ + private VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder() + .withModel(VertexAiMultimodalEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + + public VertexAiMultimodalEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(VertexAiMultimodalEmbeddingOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java new file mode 100644 index 00000000000..102548521d0 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingProperties.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.embedding; + +import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for Vertex AI Gemini Chat. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +@ConfigurationProperties(VertexAiTextEmbeddingProperties.CONFIG_PREFIX) +public class VertexAiTextEmbeddingProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.embedding.text"; + + private boolean enabled = true; + + /** + * Vertex AI Text Embedding API options. + */ + private VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .withTaskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT) + .withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME) + .build(); + + public VertexAiTextEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(VertexAiTextEmbeddingOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 297bbfc7913..c6c8d7a2882 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -41,3 +41,4 @@ org.springframework.ai.autoconfigure.vectorstore.opensearch.OpenSearchVectorStor org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration +org.springframework.ai.autoconfigure.vertexai.embedding.VertexAiEmbeddingAutoConfiguration \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java new file mode 100644 index 00000000000..ca030ab9e9f --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/embedding/VertexAiTextEmbeddingModelAutoConfigurationIT.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.embedding; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.io.TempDir; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.DocumentEmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResultMetadata; +import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModel; +import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +/** + * @author Christian Tzolov + */ +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +public class VertexAiTextEmbeddingModelAutoConfigurationIT { + + @TempDir + File tempDir; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.vertex.ai.embedding.project-id=" + System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"), + "spring.ai.vertex.ai.embedding.location=" + System.getenv("VERTEX_AI_GEMINI_LOCATION")) + .withConfiguration(AutoConfigurations.of(VertexAiEmbeddingAutoConfiguration.class)); + + @Test + public void textEmbedding() { + contextRunner.run(context -> { + var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); + var textEmbeddingProperties = context.getBean(VertexAiTextEmbeddingProperties.class); + + assertThat(conntectionProperties).isNotNull(); + assertThat(textEmbeddingProperties.isEnabled()).isTrue(); + + VertexAiTextEmbeddingModel embeddingModel = context.getBean(VertexAiTextEmbeddingModel.class); + assertThat(embeddingModel).isInstanceOf(VertexAiTextEmbeddingModel.class); + + List> embeddings = embeddingModel.embed(List.of("Spring Framework", "Spring AI")); + + assertThat(embeddings.size()).isEqualTo(2); // batch size + assertThat(embeddings.get(0).size()).isEqualTo(embeddingModel.dimensions()); + }); + } + + @Test + void textEmbeddingActivation() { + contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=false").run(context -> { + assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isEmpty(); + }); + + contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.text.enabled=true").run(context -> { + assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); + }); + + contextRunner.run(context -> { + assertThat(context.getBeansOfType(VertexAiTextEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiTextEmbeddingModel.class)).isNotEmpty(); + }); + + } + + @Test + public void multimodalEmbedding() { + contextRunner.run(context -> { + var conntectionProperties = context.getBean(VertexAiEmbeddingConnectionProperties.class); + var multimodalEmbeddingProperties = context.getBean(VertexAiMultimodalEmbeddingProperties.class); + + assertThat(conntectionProperties).isNotNull(); + assertThat(multimodalEmbeddingProperties.isEnabled()).isTrue(); + + VertexAiMultimodalEmbeddingModel multiModelEmbeddingModel = context + .getBean(VertexAiMultimodalEmbeddingModel.class); + + assertThat(multiModelEmbeddingModel).isNotNull(); + + var document = new Document("Hello World"); + + DocumentEmbeddingRequest embeddingRequest = new DocumentEmbeddingRequest(List.of(document), + EmbeddingOptions.EMPTY); + + EmbeddingResponse embeddingResponse = multiModelEmbeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); + assertThat(embeddingResponse.getResults().get(0).getMetadata().getModalityType()) + .isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408); + + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001"); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0); + + assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408); + + }); + } + + @Test + void multimodalEmbeddingActivation() { + contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=false").run(context -> { + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isEmpty(); + }); + + contextRunner.withPropertyValues("spring.ai.vertex.ai.embedding.multimodal.enabled=true").run(context -> { + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); + }); + + contextRunner.run(context -> { + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiMultimodalEmbeddingModel.class)).isNotEmpty(); + }); + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml new file mode 100644 index 00000000000..f59c3533000 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-vertex-ai-embedding-spring-boot-starter + jar + Spring AI Starter - VertexAI Embedding + Spring AI Vertex Embedding AI Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-vertex-ai-embedding + ${project.parent.version} + + + +