diff --git a/pom.xml b/pom.xml index 5a1b3f44cfb..edad7bcc44e 100644 --- a/pom.xml +++ b/pom.xml @@ -56,6 +56,7 @@ vector-stores/spring-ai-azure vector-stores/spring-ai-weaviate vector-stores/spring-ai-redis + vector-stores/spring-ai-gemfire spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2 spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini vector-stores/spring-ai-qdrant diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 8a807321a62..7457c27982b 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -186,6 +186,12 @@ ${project.version} + + org.springframework.ai + spring-ai-gemfire + ${project.version} + + org.springframework.ai diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 61bb5337157..b8af5752769 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -48,6 +48,7 @@ *** xref:api/vectordbs/redis.adoc[] *** xref:api/vectordbs/pinecone.adoc[] *** xref:api/vectordbs/qdrant.adoc[] +*** xref:api/vectordbs/gemfire.adoc[GemFire] ** xref:api/functions.adoc[Function Calling] ** xref:api/prompt.adoc[] ** xref:api/output-parser.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc new file mode 100644 index 00000000000..dff91d57515 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc @@ -0,0 +1,122 @@ += GemFire Vector Store + +This section walks you through setting up the GemFire VectorStore to store document embeddings and perform similarity searches. + +link:https://tanzu.vmware.com/gemfire[GemFire] is an ultra high speed in-memory data and compute grid, with vector extensions to store and search vectors efficiently. 
+
+link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores and retrieves vectors and performs vector searches through a distributed and resilient infrastructure.
+
+Capabilities:
+
+- Create indexes
+- Store vectors and the associated metadata
+- Perform vector searches based on similarity
+
+== Prerequisites
+
+Access to a GemFire cluster with the link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[GemFire Vector Database] extension installed.
+You can download the GemFire VectorDB extension from the link:https://network.pivotal.io/products/gemfire-vectordb/[VMware Tanzu Network] after signing in.
+
+== Dependencies
+
+Add these dependencies to your project:
+
+- An Embedding Client boot starter, required for calculating embeddings.
+- For example, Transformers Embedding (Local); follow the ONNX Transformers Embedding instructions.
+
+[source,xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-transformers</artifactId>
+</dependency>
+----
+
+- The GemFire VectorDB vector store dependency:
+
+[source,xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-gemfire</artifactId>
+</dependency>
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+
+== Sample Code
+
+- To configure GemFire in your application, use the following setup:
+
+[source,java]
+----
+@Bean
+public GemFireVectorStoreConfig gemFireVectorStoreConfig() {
+    return GemFireVectorStoreConfig.builder()
+        .withHost("localhost")
+        .withPort(8080)
+        .withIndex("spring-ai-test-index")
+        .build();
+}
+----
+
+- Create a GemFireVectorStore instance connected to your GemFire VectorDB and tell it which index to use:
+
+[source,java]
+----
+@Bean
+public VectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingClient embeddingClient) {
+    GemFireVectorStore vectorStore = new GemFireVectorStore(config, embeddingClient);
+    vectorStore.setIndexName("spring-ai-test-index");
+    return vectorStore;
+}
+----
+
+- Create a vector index, which configures the backing GemFire region:
+
+[source,java]
+----
+    public void createIndex() {
+        try {
+            CreateRequest createRequest = new CreateRequest();
+            createRequest.setName(INDEX_NAME);
+            createRequest.setBeamWidth(20);
+            createRequest.setMaxConnections(16);
+            ObjectMapper objectMapper = new ObjectMapper();
+            String index = objectMapper.writeValueAsString(createRequest);
+            client.post()
+                .contentType(MediaType.APPLICATION_JSON)
+                .bodyValue(index)
+                .retrieve()
+                .bodyToMono(Void.class)
+                .block();
+        }
+        catch (Exception e) {
+            logger.warn("An unexpected error occurred while creating the index");
+        }
+    }
+----
+
+- Create some documents:
+
+[source,java]
+----
+    List<Document> documents = List.of(
+        new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+        new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+        new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+----
+
+- Add the documents to GemFire VectorDB:
+
+[source,java]
+----
+vectorStore.add(documents);
+----
+
+- And finally, retrieve documents similar to a query:
+
+[source,java]
+----
+    List<Document> results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5));
+----
+
+If all goes well, you should retrieve the document containing the text "Spring AI rocks!!".
+ diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 807c9ca7c15..de659532398 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -323,7 +323,6 @@ test - - + diff --git a/vector-stores/spring-ai-gemfire/README.md b/vector-stores/spring-ai-gemfire/README.md new file mode 100644 index 00000000000..9adf1686a0e --- /dev/null +++ b/vector-stores/spring-ai-gemfire/README.md @@ -0,0 +1 @@ +[GemFire Vector Store Documentation](https://docs.spring.io/spring-ai/reference/api/vectordbs/gemfire.html) \ No newline at end of file diff --git a/vector-stores/spring-ai-gemfire/pom.xml b/vector-stores/spring-ai-gemfire/pom.xml new file mode 100644 index 00000000000..8b4e6407b70 --- /dev/null +++ b/vector-stores/spring-ai-gemfire/pom.xml @@ -0,0 +1,81 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-gemfire + jar + Spring AI Vector Store - GemFire + Spring AI GemFire Vector Store + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 17 + 17 + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework + spring-webflux + + + + + org.springframework.ai + spring-ai-openai + ${parent.version} + test + + + + org.springframework.ai + spring-ai-test + ${parent.version} + test + + + + org.springframework.ai + spring-ai-transformers + ${parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.awaitility + awaitility + 3.0.0 + test + + + org.apache.logging.log4j + log4j-core + + + + + diff --git a/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java 
b/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
new file mode 100644
index 00000000000..df7e08ec59d
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
@@ -0,0 +1,601 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import static org.springframework.http.HttpStatus.BAD_REQUEST;
+import static org.springframework.http.HttpStatus.NOT_FOUND;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.client.WebClientException;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+import org.springframework.web.util.UriComponentsBuilder;
+import reactor.util.annotation.NonNull;
+
+/**
+ * A {@link VectorStore} implementation backed by GemFire VectorDB. Documents are embedded
+ * with the supplied {@link EmbeddingClient} and exchanged with the GemFire VectorDB REST
+ * endpoint through a {@link WebClient}.
+ * <p>
+ * The target index is taken from {@link GemFireVectorStoreConfig.Builder#withIndex(String)}
+ * and can be overridden with {@link #setIndexName(String)}.
+ *
+ * @author Geet Rawat
+ */
+public class GemFireVectorStore implements VectorStore {
+
+    /** Path segment, relative to an index, of the similarity query endpoint. */
+    public static final String QUERY = "/query";
+
+    /** URI template of the index collection endpoint; expanded with ssl suffix, host and port. */
+    public static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes";
+
+    private static final Logger logger = LoggerFactory.getLogger(GemFireVectorStore.class);
+
+    // ObjectMapper is thread-safe once configured, so a single shared instance is enough.
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
+
+    private static final String EMBEDDINGS = "/embeddings";
+
+    private static final int DEFAULT_PORT = 9090;
+
+    private static final int DEFAULT_TOP_K_PER_BUCKET = 10;
+
+    private static final int DEFAULT_TOP_K = 10;
+
+    private static final String DEFAULT_DOCUMENT_FIELD = "document";
+
+    private final WebClient client;
+
+    private final EmbeddingClient embeddingClient;
+
+    private final int topKPerBucket;
+
+    private final int topK;
+
+    // Metadata key under which the document text is stored alongside the vector.
+    private final String documentField;
+
+    /** Name of the index that all document operations target. */
+    public String indexName;
+
+    /**
+     * Immutable connection and query settings for a {@link GemFireVectorStore}.
+     */
+    public static final class GemFireVectorStoreConfig {
+
+        private final WebClient client;
+
+        private final String index;
+
+        private final int topKPerBucket;
+
+        public final int topK;
+
+        private final String documentField;
+
+        public static Builder builder() {
+            return new Builder();
+        }
+
+        private GemFireVectorStoreConfig(Builder builder) {
+            // Expands DEFAULT_URI, e.g. "http://localhost:9090/gemfire-vectordb/v1/indexes".
+            String base = UriComponentsBuilder.fromUriString(DEFAULT_URI)
+                .build(builder.sslEnabled ? "s" : "", builder.host, builder.port)
+                .toString();
+            this.index = builder.index;
+            this.client = WebClient.create(base);
+            this.topKPerBucket = builder.topKPerBucket;
+            this.topK = builder.topK;
+            this.documentField = builder.documentField;
+        }
+
+        /**
+         * Fluent builder for {@link GemFireVectorStoreConfig}.
+         */
+        public static class Builder {
+
+            private String host;
+
+            private int port = DEFAULT_PORT;
+
+            private boolean sslEnabled;
+
+            // NOTE: the two timeouts are recorded but not yet applied to the WebClient.
+            private long connectionTimeout;
+
+            private long requestTimeout;
+
+            private String index;
+
+            private int topKPerBucket = DEFAULT_TOP_K_PER_BUCKET;
+
+            private int topK = DEFAULT_TOP_K;
+
+            private String documentField = DEFAULT_DOCUMENT_FIELD;
+
+            public Builder withHost(String host) {
+                Assert.hasText(host, "host must have a value");
+                this.host = host;
+                return this;
+            }
+
+            public Builder withPort(int port) {
+                Assert.isTrue(port > 0, "port must be positive");
+                this.port = port;
+                return this;
+            }
+
+            public Builder withSslEnabled(boolean sslEnabled) {
+                this.sslEnabled = sslEnabled;
+                return this;
+            }
+
+            public Builder withConnectionTimeout(long timeout) {
+                Assert.isTrue(timeout >= 0, "timeout must be >= 0");
+                this.connectionTimeout = timeout;
+                return this;
+            }
+
+            public Builder withRequestTimeout(long timeout) {
+                Assert.isTrue(timeout >= 0, "timeout must be >= 0");
+                this.requestTimeout = timeout;
+                return this;
+            }
+
+            public Builder withIndex(String index) {
+                Assert.hasText(index, "index must have a value");
+                this.index = index;
+                return this;
+            }
+
+            public Builder withTopKPerBucket(int topKPerBucket) {
+                Assert.isTrue(topKPerBucket > 0, "topKPerBucket must be positive");
+                this.topKPerBucket = topKPerBucket;
+                return this;
+            }
+
+            public Builder withTopK(int topK) {
+                Assert.isTrue(topK > 0, "topK must be positive");
+                this.topK = topK;
+                return this;
+            }
+
+            public Builder withDocumentField(String documentField) {
+                Assert.hasText(documentField, "documentField must have a value");
+                this.documentField = documentField;
+                return this;
+            }
+
+            public GemFireVectorStoreConfig build() {
+                return new GemFireVectorStoreConfig(this);
+            }
+
+        }
+
+    }
+
+    /**
+     * Creates a store that talks to the endpoint described by {@code config}.
+     * @param config connection and query settings, must not be {@code null}
+     * @param embedding client used to embed documents and queries, must not be {@code null}
+     */
+    public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingClient embedding) {
+        Assert.notNull(config, "GemFireVectorStoreConfig must not be null");
+        Assert.notNull(embedding, "EmbeddingClient must not be null");
+        this.client = config.client;
+        this.embeddingClient = embedding;
+        this.topKPerBucket = config.topKPerBucket;
+        this.topK = config.topK;
+        this.documentField = config.documentField;
+        // Honour Builder#withIndex; previously the configured index was silently ignored.
+        this.indexName = config.index;
+    }
+
+    /**
+     * Overrides the index that document operations target.
+     * @param indexName name of the GemFire VectorDB index
+     */
+    public void setIndexName(String indexName) {
+        this.indexName = indexName;
+    }
+
+    /** JSON payload of the "create index" call. */
+    private static final class CreateRequest {
+
+        @JsonProperty("name")
+        private String name;
+
+        @JsonProperty("beam-width")
+        private int beamWidth = 100;
+
+        @JsonProperty("max-connections")
+        private int maxConnections = 16;
+
+        @JsonProperty("vector-similarity-function")
+        private String vectorSimilarityFunction = "COSINE";
+
+        @JsonProperty("fields")
+        private String[] fields = new String[] { "vector" };
+
+        @JsonProperty("buckets")
+        private int buckets = 0;
+
+        public CreateRequest() {
+        }
+
+        public CreateRequest(String name) {
+            this.name = name;
+        }
+
+        public String getName() {
+            return name;
+        }
+
+        public void setName(String name) {
+            this.name = name;
+        }
+
+        public int getBeamWidth() {
+            return beamWidth;
+        }
+
+        public void setBeamWidth(int beamWidth) {
+            this.beamWidth = beamWidth;
+        }
+
+        public int getMaxConnections() {
+            return maxConnections;
+        }
+
+        public void setMaxConnections(int maxConnections) {
+            this.maxConnections = maxConnections;
+        }
+
+        public String getVectorSimilarityFunction() {
+            return vectorSimilarityFunction;
+        }
+
+        public void setVectorSimilarityFunction(String vectorSimilarityFunction) {
+            this.vectorSimilarityFunction = vectorSimilarityFunction;
+        }
+
+        public String[] getFields() {
+            return fields;
+        }
+
+        public void setFields(String[] fields) {
+            this.fields = fields;
+        }
+
+        public int getBuckets() {
+            return buckets;
+        }
+
+        public void setBuckets(int buckets) {
+            this.buckets = buckets;
+        }
+
+    }
+
+    /** Wrapper around the embeddings sent by {@link #add(List)}. */
+    private static final class UploadRequest {
+
+        private final List<Embedding> embeddings;
+
+        public List<Embedding> getEmbeddings() {
+            return embeddings;
+        }
+
+        @JsonCreator
+        public UploadRequest(@JsonProperty("embeddings") List<Embedding> embeddings) {
+            this.embeddings = embeddings;
+        }
+
+        /** One document: its key, its vector, and its metadata including the text. */
+        private static final class Embedding {
+
+            private final String key;
+
+            private List<Float> vector;
+
+            @JsonInclude(JsonInclude.Include.NON_NULL)
+            private Map<String, Object> metadata;
+
+            public Embedding(@JsonProperty("key") String key, @JsonProperty("vector") List<Float> vector,
+                    String contentName, String content, @JsonProperty("metadata") Map<String, Object> metadata) {
+                this.key = key;
+                this.vector = vector;
+                // Copy so the caller's map is not mutated; tolerate documents without metadata.
+                this.metadata = (metadata != null) ? new HashMap<>(metadata) : new HashMap<>();
+                this.metadata.put(contentName, content);
+            }
+
+            public String getKey() {
+                return key;
+            }
+
+            public List<Float> getVector() {
+                return vector;
+            }
+
+            public Map<String, Object> getMetadata() {
+                return metadata;
+            }
+
+        }
+
+    }
+
+    /** JSON payload of a similarity query. */
+    private static final class QueryRequest {
+
+        @JsonProperty("vector")
+        @NonNull
+        private final List<Float> vector;
+
+        @JsonProperty("top-k")
+        private final int k;
+
+        @JsonProperty("k-per-bucket")
+        private final int kPerBucket;
+
+        @JsonProperty("include-metadata")
+        private final boolean includeMetadata;
+
+        public QueryRequest(List<Float> vector, int k, int kPerBucket, boolean includeMetadata) {
+            this.vector = vector;
+            this.k = k;
+            this.kPerBucket = kPerBucket;
+            this.includeMetadata = includeMetadata;
+        }
+
+        public List<Float> getVector() {
+            return vector;
+        }
+
+        public int getK() {
+            return k;
+        }
+
+        public int getkPerBucket() {
+            return kPerBucket;
+        }
+
+        public boolean isIncludeMetadata() {
+            return includeMetadata;
+        }
+
+    }
+
+    /** One hit returned by the query endpoint. */
+    private static final class QueryResponse {
+
+        private String key;
+
+        private float score;
+
+        private Map<String, Object> metadata;
+
+        private String getContent(String field) {
+            return (String) metadata.get(field);
+        }
+
+        public void setKey(String key) {
+            this.key = key;
+        }
+
+        public void setScore(float score) {
+            this.score = score;
+        }
+
+        public void setMetadata(Map<String, Object> metadata) {
+            this.metadata = metadata;
+        }
+
+    }
+
+    /** JSON payload of the "delete index" call. */
+    private static class DeleteRequest {
+
+        @JsonProperty("delete-data")
+        private boolean deleteData = true;
+
+        public DeleteRequest() {
+        }
+
+        public DeleteRequest(boolean deleteData) {
+            this.deleteData = deleteData;
+        }
+
+        public boolean isDeleteData() {
+            return deleteData;
+        }
+
+        public void setDeleteData(boolean deleteData) {
+            this.deleteData = deleteData;
+        }
+
+    }
+
+    /**
+     * Embeds every document and uploads key, vector and metadata to the index. The
+     * computed embedding is also stored on each {@link Document}.
+     * @param documents documents to store
+     * @throws RuntimeException if serialization or the HTTP call fails
+     */
+    @Override
+    public void add(List<Document> documents) {
+        UploadRequest upload = new UploadRequest(documents.stream().map(document -> {
+            // Compute and assign an embedding to the document.
+            document.setEmbedding(this.embeddingClient.embed(document));
+            List<Float> floatVector = document.getEmbedding().stream().map(Double::floatValue).toList();
+            return new UploadRequest.Embedding(document.getId(), floatVector, documentField, document.getContent(),
+                    document.getMetadata());
+        }).toList());
+
+        String embeddingsJson;
+        try {
+            // Send the bare JSON array of embeddings: serialize the list itself rather than
+            // the wrapper (stripping a prefix from the wrapper left a dangling '}').
+            embeddingsJson = OBJECT_MAPPER.writeValueAsString(upload.getEmbeddings());
+        }
+        catch (JsonProcessingException e) {
+            throw new RuntimeException(String.format("Embedding JSON parsing error: %s", e.getMessage()), e);
+        }
+
+        client.post()
+            .uri("/" + indexName + EMBEDDINGS)
+            .contentType(MediaType.APPLICATION_JSON)
+            .bodyValue(embeddingsJson)
+            .retrieve()
+            .bodyToMono(Void.class)
+            .onErrorMap(WebClientException.class, this::handleHttpClientException)
+            .block();
+    }
+
+    /**
+     * Removes the embeddings with the given document ids from the index.
+     * @param idList ids of the documents to remove
+     * @return {@code true} if the request succeeded, {@code false} if it failed
+     */
+    @Override
+    public Optional<Boolean> delete(List<String> idList) {
+        try {
+            client.method(HttpMethod.DELETE)
+                .uri("/" + indexName + EMBEDDINGS)
+                .body(BodyInserters.fromValue(idList))
+                .retrieve()
+                .bodyToMono(Void.class)
+                .block();
+        }
+        catch (Exception e) {
+            // Failure is reported through the return value; log it with the stack trace.
+            logger.warn("Error removing embedding: {}", e.toString(), e);
+            return Optional.of(false);
+        }
+        return Optional.of(true);
+    }
+
+    /**
+     * Embeds the query and returns the matching documents whose score is at least the
+     * requested similarity threshold. Each result carries a {@code distance} metadata
+     * entry equal to {@code 1 - score}.
+     * @param request query text, top-k and similarity threshold
+     * @return the matching documents
+     * @throws UnsupportedOperationException if the request has a filter expression
+     */
+    @Override
+    public List<Document> similaritySearch(SearchRequest request) {
+        if (request.hasFilterExpression()) {
+            throw new UnsupportedOperationException("Gemfire does not support metadata filter expressions yet.");
+        }
+        List<Double> vector = this.embeddingClient.embed(request.getQuery());
+        List<Float> floatVector = vector.stream().map(Double::floatValue).toList();
+
+        return client.post()
+            .uri("/" + indexName + QUERY)
+            .contentType(MediaType.APPLICATION_JSON)
+            .bodyValue(new QueryRequest(floatVector, request.getTopK(), topKPerBucket, true))
+            .retrieve()
+            .bodyToFlux(QueryResponse.class)
+            .filter(r -> r.score >= request.getSimilarityThreshold())
+            .map(r -> {
+                // Guard against hits returned without metadata.
+                Map<String, Object> metadata = (r.metadata != null) ? r.metadata : new HashMap<>();
+                metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score);
+                String content = (String) metadata.remove(documentField);
+                return new Document(r.key, content, metadata);
+            })
+            .collectList()
+            .onErrorMap(WebClientException.class, this::handleHttpClientException)
+            .block();
+    }
+
+    /**
+     * Creates an index with the default {@code CreateRequest} settings.
+     * @param indexName name of the index to create
+     * @throws JsonProcessingException if the request cannot be serialized
+     */
+    public void createIndex(String indexName) throws JsonProcessingException {
+        CreateRequest createRequest = new CreateRequest(indexName);
+        String index = OBJECT_MAPPER.writeValueAsString(createRequest);
+        client.post()
+            .contentType(MediaType.APPLICATION_JSON)
+            .bodyValue(index)
+            .retrieve()
+            .bodyToMono(Void.class)
+            .onErrorMap(WebClientException.class, this::handleHttpClientException)
+            .block();
+    }
+
+    /**
+     * Deletes an index together with its data.
+     * @param indexName name of the index to delete
+     */
+    public void deleteIndex(String indexName) {
+        DeleteRequest deleteRequest = new DeleteRequest();
+        deleteRequest.setDeleteData(true);
+        client.method(HttpMethod.DELETE)
+            .uri("/" + indexName)
+            .body(BodyInserters.fromValue(deleteRequest))
+            .retrieve()
+            .bodyToMono(Void.class)
+            .onErrorMap(WebClientException.class, this::handleHttpClientException)
+            .block();
+    }
+
+    /**
+     * Translates a {@link WebClientException} into a {@link RuntimeException} with a
+     * status-specific message. The exception is returned, as {@code onErrorMap} expects,
+     * and keeps the original as its cause.
+     */
+    private Throwable handleHttpClientException(Throwable ex) {
+        if (!(ex instanceof WebClientResponseException clientException)) {
+            return new RuntimeException(String.format("Got an unexpected error: %s", ex), ex);
+        }
+
+        if (clientException.getStatusCode().equals(NOT_FOUND)) {
+            return new RuntimeException(String.format("Index %s not found: %s", indexName, ex), ex);
+        }
+        else if (clientException.getStatusCode().equals(BAD_REQUEST)) {
+            return new RuntimeException(String.format("Bad Request: %s", ex), ex);
+        }
+        else {
+            return new RuntimeException(String.format("Got an unexpected HTTP error: %s", ex), ex);
+        }
+    }
+
+}
diff --git a/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
new file mode 100644
index
00000000000..49d03254448
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
@@ -0,0 +1,216 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import static java.util.concurrent.TimeUnit.MINUTES;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.transformers.TransformersEmbeddingClient;
+import org.springframework.ai.vectorstore.GemFireVectorStore.GemFireVectorStoreConfig;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.io.DefaultResourceLoader;
+
+/**
+ * Integration tests for {@link GemFireVectorStore}. They run only when the
+ * {@code GEMFIRE_HOST} environment variable is set; the store itself is configured
+ * with host {@code localhost} and the builder's default port.
+ *
+ * @author Geet Rawat
+ * @since 1.0.0
+ */
+@EnabledIfEnvironmentVariable(named = "GEMFIRE_HOST", matches = ".+")
+public class GemFireVectorStoreIT {
+
+    public static final String INDEX_NAME = "spring-ai-index1";
+
+    List<Document> documents = List.of(
+            new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+            new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+            new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+
+    /** Reads the resource at {@code uri} as UTF-8 text. */
+    public static String getText(String uri) {
+        var resource = new DefaultResourceLoader().getResource(uri);
+        try {
+            return resource.getContentAsString(StandardCharsets.UTF_8);
+        }
+        catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+        .withUserConfiguration(TestApplication.class);
+
+    @BeforeEach
+    public void createIndex() {
+        contextRunner.run(c -> c.getBean(GemFireVectorStore.class).createIndex(INDEX_NAME));
+    }
+
+    @AfterEach
+    public void deleteIndex() {
+        contextRunner.run(c -> c.getBean(GemFireVectorStore.class).deleteIndex(INDEX_NAME));
+    }
+
+    @Test
+    public void addAndDeleteEmbeddingTest() {
+        contextRunner.run(context -> {
+            VectorStore vectorStore = context.getBean(VectorStore.class);
+            vectorStore.add(documents);
+            vectorStore.delete(documents.stream().map(Document::getId).toList());
+            // Poll until the deletion is visible in search results.
+            Awaitility.await().atMost(1, MINUTES).until(() -> {
+                return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3));
+            }, hasSize(0));
+        });
+    }
+
+    @Test
+    public void addAndSearchTest() {
+        contextRunner.run(context -> {
+            VectorStore vectorStore = context.getBean(VectorStore.class);
+            vectorStore.add(documents);
+
+            Awaitility.await().atMost(1, MINUTES).until(() -> {
+                return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
+            }, hasSize(1));
+
+            List<Document> results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5));
+            Document resultDoc = results.get(0);
+            assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+            assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+            assertThat(resultDoc.getMetadata()).hasSize(2);
+            assertThat(resultDoc.getMetadata()).containsKey("meta2");
+            assertThat(resultDoc.getMetadata()).containsKey("distance");
+        });
+    }
+
+    @Test
+    public void documentUpdateTest() {
+        contextRunner.run(context -> {
+            VectorStore vectorStore = context.getBean(VectorStore.class);
+
+            Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
+                    Collections.singletonMap("meta1", "meta1"));
+            vectorStore.add(List.of(document));
+            SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5);
+            Awaitility.await().atMost(1, MINUTES).until(() -> {
+                return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
+            }, hasSize(1));
+            List<Document> results = vectorStore.similaritySearch(springSearchRequest);
+            Document resultDoc = results.get(0);
+            assertThat(resultDoc.getId()).isEqualTo(document.getId());
+            assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
+            assertThat(resultDoc.getMetadata()).containsKey("meta1");
+            assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+            Document sameIdDocument = new Document(document.getId(),
+                    "The World is Big and Salvation Lurks Around the Corner",
+                    Collections.singletonMap("meta2", "meta2"));
+
+            vectorStore.add(List.of(sameIdDocument));
+            SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5);
+            // Poll until the overwritten content is visible instead of asserting right after the write.
+            Awaitility.await().atMost(1, MINUTES).until(() -> {
+                List<Document> updated = vectorStore.similaritySearch(fooBarSearchRequest);
+                return updated.size() == 1 && sameIdDocument.getContent().equals(updated.get(0).getContent());
+            });
+            results = vectorStore.similaritySearch(fooBarSearchRequest);
+
+            assertThat(results).hasSize(1);
+            resultDoc = results.get(0);
+            assertThat(resultDoc.getId()).isEqualTo(document.getId());
+            assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
+            assertThat(resultDoc.getMetadata()).containsKey("meta2");
+            assertThat(resultDoc.getMetadata()).containsKey("distance");
+        });
+    }
+
+    @Test
+    public void searchThresholdTest() {
+
+        contextRunner.run(context -> {
+            VectorStore vectorStore = context.getBean(VectorStore.class);
+            vectorStore.add(documents);
+
+            Awaitility.await().atMost(1, MINUTES).until(() -> {
+                return vectorStore
+                    .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll());
+            }, hasSize(3));
+
+            List<Document> fullResult = vectorStore
+                .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
+
+            List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
+            assertThat(distances).hasSize(3);
+
+            float threshold = (distances.get(0) + distances.get(1)) / 2;
+            List<Document> results = vectorStore
+                .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
+
+            assertThat(results).hasSize(1);
+
+            Document resultDoc = results.get(0);
+            assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+            assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+            assertThat(resultDoc.getMetadata()).containsKey("meta2");
+            assertThat(resultDoc.getMetadata()).containsKey("distance");
+        });
+    }
+
+    @SpringBootConfiguration
+    @EnableAutoConfiguration
+    public static class TestApplication {
+
+        @Bean
+        public GemFireVectorStoreConfig gemfireVectorStoreConfig() {
+            return GemFireVectorStoreConfig.builder().withHost("localhost").build();
+        }
+
+        @Bean
+        public GemFireVectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingClient embeddingClient) {
+            GemFireVectorStore gemFireVectorStore = new GemFireVectorStore(config, embeddingClient);
+            gemFireVectorStore.setIndexName(INDEX_NAME);
+            return gemFireVectorStore;
+        }
+
+        @Bean
+        public EmbeddingClient embeddingClient() {
+            return new TransformersEmbeddingClient();
+        }
+
+    }
+
+}
\ No newline at end of file