diff --git a/pom.xml b/pom.xml
index 5a1b3f44cfb..edad7bcc44e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -56,6 +56,7 @@
vector-stores/spring-ai-azure
vector-stores/spring-ai-weaviate
vector-stores/spring-ai-redis
+ vector-stores/spring-ai-gemfire
spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2
spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini
vector-stores/spring-ai-qdrant
diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml
index 8a807321a62..7457c27982b 100644
--- a/spring-ai-bom/pom.xml
+++ b/spring-ai-bom/pom.xml
@@ -186,6 +186,12 @@
${project.version}
+
+ org.springframework.ai
+ spring-ai-gemfire
+ ${project.version}
+
+
org.springframework.ai
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
index 61bb5337157..b8af5752769 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
@@ -48,6 +48,7 @@
*** xref:api/vectordbs/redis.adoc[]
*** xref:api/vectordbs/pinecone.adoc[]
*** xref:api/vectordbs/qdrant.adoc[]
+*** xref:api/vectordbs/gemfire.adoc[GemFire]
** xref:api/functions.adoc[Function Calling]
** xref:api/prompt.adoc[]
** xref:api/output-parser.adoc[]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc
new file mode 100644
index 00000000000..dff91d57515
--- /dev/null
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc
@@ -0,0 +1,122 @@
+= GemFire Vector Store
+
+This section walks you through setting up the GemFire VectorStore to store document embeddings and perform similarity searches.
+
+link:https://tanzu.vmware.com/gemfire[GemFire] is an ultra high speed in-memory data and compute grid, with vector extensions to store and search vectors efficiently.
+
+link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores, retrieves, and performs vector searches through a distributed and resilient infrastructure:
+
+Capabilities:
+- Create Indexes
+- Store vectors and the associated metadata
+- Perform vector searches based on similarity
+
+== Prerequisites
+
+Access to a GemFire cluster with the link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[GemFire Vector Database] extension installed.
+You can download the GemFire VectorDB extension from the link:https://network.pivotal.io/products/gemfire-vectordb/[VMware Tanzu Network] after signing in.
+
+== Dependencies
+
+Add these dependencies to your project:
+
+- Embedding Client boot starter, required for calculating embeddings.
+- Transformers Embedding (Local) and follow the ONNX Transformers Embedding instructions.
+
+[source,xml]
+----
+
+ org.springframework.ai
+ spring-ai-transformers
+
+----
+
+- Add the GemFire VectorDB dependencies
+
+[source,xml]
+----
+
+ org.springframework.ai
+ spring-ai-gemfire
+
+----
+
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+
+== Sample Code
+
+- To configure GemFire in your application, use the following setup:
+
+[source,java]
+----
+@Bean
+public GemFireVectorStoreConfig gemFireVectorStoreConfig() {
+ return GemFireVectorStoreConfig.builder()
+ .withUrl("http://localhost:8080")
+ .withIndexName("spring-ai-test-index")
+ .build();
+}
+----
+
+- Create a GemFireVectorStore instance connected to your GemFire VectorDB:
+
+[source,java]
+----
+@Bean
+public VectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingClient embeddingClient) {
+ return new GemFireVectorStore(config, embeddingClient);
+}
+----
+- Create a Vector Index which will configure GemFire region.
+
+[source,java]
+----
+ public void createIndex() {
+ try {
+ CreateRequest createRequest = new CreateRequest();
+ createRequest.setName(INDEX_NAME);
+ createRequest.setBeamWidth(20);
+ createRequest.setMaxConnections(16);
+ ObjectMapper objectMapper = new ObjectMapper();
+ String index = objectMapper.writeValueAsString(createRequest);
+ client.post()
+ .contentType(MediaType.APPLICATION_JSON)
+ .bodyValue(index)
+ .retrieve()
+ .bodyToMono(Void.class)
+ .block();
+ }
+ catch (Exception e) {
+ logger.warn("An unexpected error occurred while creating the index");
+ }
+ }
+----
+
+- Create some documents:
+
+[source,java]
+----
+ List documents = List.of(
+ new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+ new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+ new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+----
+
+- Add the documents to GemFire VectorDB:
+
+[source,java]
+----
+vectorStore.add(List.of(document));
+----
+
+- And finally, retrieve documents similar to a query:
+
+[source,java]
+----
+ List results = vectorStore.similaritySearch("Spring", 5);
+----
+
+If all goes well, you should retrieve the document containing the text "Spring AI rocks!!".
+
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml
index 807c9ca7c15..de659532398 100644
--- a/spring-ai-spring-boot-autoconfigure/pom.xml
+++ b/spring-ai-spring-boot-autoconfigure/pom.xml
@@ -323,7 +323,6 @@
test
-
-
+
diff --git a/vector-stores/spring-ai-gemfire/README.md b/vector-stores/spring-ai-gemfire/README.md
new file mode 100644
index 00000000000..9adf1686a0e
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/README.md
@@ -0,0 +1 @@
+[GemFire Vector Store Documentation](https://docs.spring.io/spring-ai/reference/api/vectordbs/gemfire.html)
\ No newline at end of file
diff --git a/vector-stores/spring-ai-gemfire/pom.xml b/vector-stores/spring-ai-gemfire/pom.xml
new file mode 100644
index 00000000000..8b4e6407b70
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/pom.xml
@@ -0,0 +1,81 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 1.0.0-SNAPSHOT
+ ../../pom.xml
+
+ spring-ai-gemfire
+ jar
+ Spring AI Vector Store - GemFire
+ Spring AI GemFire Vector Store
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+ 17
+ 17
+
+
+
+
+ org.springframework.ai
+ spring-ai-core
+ ${project.parent.version}
+
+
+
+ org.springframework
+ spring-webflux
+
+
+
+
+ org.springframework.ai
+ spring-ai-openai
+ ${parent.version}
+ test
+
+
+
+ org.springframework.ai
+ spring-ai-test
+ ${parent.version}
+ test
+
+
+
+ org.springframework.ai
+ spring-ai-transformers
+ ${parent.version}
+ test
+
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+ org.awaitility
+ awaitility
+ 3.0.0
+ test
+
+
+ org.apache.logging.log4j
+ log4j-core
+
+
+
+
+
diff --git a/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
new file mode 100644
index 00000000000..df7e08ec59d
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
@@ -0,0 +1,530 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import static org.springframework.http.HttpStatus.BAD_REQUEST;
+import static org.springframework.http.HttpStatus.NOT_FOUND;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.client.WebClientException;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+import org.springframework.web.util.UriComponentsBuilder;
+import reactor.util.annotation.NonNull;
+
+/**
+ * A VectorStore implementation backed by GemFire. This store supports creating, updating,
+ * deleting, and similarity searching of documents in a GemFire index.
+ *
+ * @author Geet Rawat
+ */
+public class GemFireVectorStore implements VectorStore {
+
+ public static final String QUERY = "/query";
+
+ private static final Logger logger = LoggerFactory.getLogger(GemFireVectorStore.class);
+
+ private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
+
+ private static final String EMBEDDINGS = "/embeddings";
+
+ private final WebClient client;
+
+ private final EmbeddingClient embeddingClient;
+
+ private final int topKPerBucket;
+
+ private final int topK;
+
+ private final String documentField;
+
+ public static final class GemFireVectorStoreConfig {
+
+ private final WebClient client;
+
+ private final String index;
+
+ private final int topKPerBucket;
+
+ public final int topK;
+
+ private final String documentField;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ private GemFireVectorStoreConfig(Builder builder) {
+ String base = UriComponentsBuilder.fromUriString(DEFAULT_URI)
+ .build(builder.sslEnabled ? "s" : "", builder.host, builder.port)
+ .toString();
+ this.index = builder.index;
+ this.client = WebClient.create(base);
+ this.topKPerBucket = builder.topKPerBucket;
+ this.topK = builder.topK;
+ this.documentField = builder.documentField;
+ }
+
+ public static class Builder {
+
+ private String host;
+
+ private int port = DEFAULT_PORT;
+
+ private boolean sslEnabled;
+
+ private long connectionTimeout;
+
+ private long requestTimeout;
+
+ private String index;
+
+ private int topKPerBucket = DEFAULT_TOP_K_PER_BUCKET;
+
+ private int topK = DEFAULT_TOP_K;
+
+ private String documentField = DEFAULT_DOCUMENT_FIELD;
+
+ public Builder withHost(String host) {
+ Assert.hasText(host, "host must have a value");
+ this.host = host;
+ return this;
+ }
+
+ public Builder withPort(int port) {
+ Assert.isTrue(port > 0, "port must be postive");
+ this.port = port;
+ return this;
+ }
+
+ public Builder withSslEnabled(boolean sslEnabled) {
+ this.sslEnabled = sslEnabled;
+ return this;
+ }
+
+ public Builder withConnectionTimeout(long timeout) {
+ Assert.isTrue(timeout >= 0, "timeout must be >= 0");
+ this.connectionTimeout = timeout;
+ return this;
+ }
+
+ public Builder withRequestTimeout(long timeout) {
+ Assert.isTrue(timeout >= 0, "timeout must be >= 0");
+ this.requestTimeout = timeout;
+ return this;
+ }
+
+ public Builder withIndex(String index) {
+ Assert.hasText(index, "index must have a value");
+ this.index = index;
+ return this;
+ }
+
+ public Builder withTopKPerBucket(int topKPerBucket) {
+ Assert.isTrue(topKPerBucket > 0, "topKPerBucket must be positive");
+ this.topKPerBucket = topKPerBucket;
+ return this;
+ }
+
+ public Builder withTopK(int topK) {
+ Assert.isTrue(topK > 0, "topK must be positive");
+ this.topK = topK;
+ return this;
+ }
+
+ public Builder withDocumentField(String documentField) {
+ Assert.hasText(documentField, "documentField must have a value");
+ this.documentField = documentField;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig build() {
+ return new GemFireVectorStoreConfig(this);
+ }
+
+ }
+
+ }
+
+ private static final int DEFAULT_PORT = 9090;
+
+ public static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes";
+
+ private static final int DEFAULT_TOP_K_PER_BUCKET = 10;
+
+ private static final int DEFAULT_TOP_K = 10;
+
+ private static final String DEFAULT_DOCUMENT_FIELD = "document";
+
+ public String indexName;
+
+ public void setIndexName(String indexName) {
+ this.indexName = indexName;
+ }
+
+ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingClient embedding) {
+ Assert.notNull(config, "GemFireVectorStoreConfig must not be null");
+ Assert.notNull(embedding, "EmbeddingClient must not be null");
+ this.client = config.client;
+ this.embeddingClient = embedding;
+ this.topKPerBucket = config.topKPerBucket;
+ this.topK = config.topK;
+ this.documentField = config.documentField;
+ }
+
+ private static final class CreateRequest {
+
+ @JsonProperty("name")
+ private String name;
+
+ @JsonProperty("beam-width")
+ private int beamWidth = 100;
+
+ @JsonProperty("max-connections")
+ private int maxConnections = 16;
+
+ @JsonProperty("vector-similarity-function")
+ private String vectorSimilarityFunction = "COSINE";
+
+ @JsonProperty("fields")
+ private String[] fields = new String[] { "vector" };
+
+ @JsonProperty("buckets")
+ private int buckets = 0;
+
+ public CreateRequest() {
+ }
+
+ public CreateRequest(String name) {
+ this.name = name;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public int getBeamWidth() {
+ return beamWidth;
+ }
+
+ public void setBeamWidth(int beamWidth) {
+ this.beamWidth = beamWidth;
+ }
+
+ public int getMaxConnections() {
+ return maxConnections;
+ }
+
+ public void setMaxConnections(int maxConnections) {
+ this.maxConnections = maxConnections;
+ }
+
+ public String getVectorSimilarityFunction() {
+ return vectorSimilarityFunction;
+ }
+
+ public void setVectorSimilarityFunction(String vectorSimilarityFunction) {
+ this.vectorSimilarityFunction = vectorSimilarityFunction;
+ }
+
+ public String[] getFields() {
+ return fields;
+ }
+
+ public void setFields(String[] fields) {
+ this.fields = fields;
+ }
+
+ public int getBuckets() {
+ return buckets;
+ }
+
+ public void setBuckets(int buckets) {
+ this.buckets = buckets;
+ }
+
+ }
+
+ private static final class UploadRequest {
+
+ private final List embeddings;
+
+ public List getEmbeddings() {
+ return embeddings;
+ }
+
+ @JsonCreator
+ public UploadRequest(@JsonProperty("embeddings") List embeddings) {
+ this.embeddings = embeddings;
+ }
+
+ private static final class Embedding {
+
+ private final String key;
+
+ private List vector;
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ private Map metadata;
+
+ public Embedding(@JsonProperty("key") String key, @JsonProperty("vector") List vector,
+ String contentName, String content, @JsonProperty("metadata") Map metadata) {
+ this.key = key;
+ this.vector = vector;
+ this.metadata = new HashMap<>(metadata);
+ this.metadata.put(contentName, content);
+ }
+
+ public String getKey() {
+ return key;
+ }
+
+ public List getVector() {
+ return vector;
+ }
+
+ public Map getMetadata() {
+ return metadata;
+ }
+
+ }
+
+ }
+
+ private static final class QueryRequest {
+
+ @JsonProperty("vector")
+ @NonNull
+ private final List vector;
+
+ @JsonProperty("top-k")
+ private final int k;
+
+ @JsonProperty("k-per-bucket")
+ private final int kPerBucket;
+
+ @JsonProperty("include-metadata")
+ private final boolean includeMetadata;
+
+ public QueryRequest(List vector, int k, int kPerBucket, boolean includeMetadata) {
+ this.vector = vector;
+ this.k = k;
+ this.kPerBucket = kPerBucket;
+ this.includeMetadata = includeMetadata;
+ }
+
+ public List getVector() {
+ return vector;
+ }
+
+ public int getK() {
+ return k;
+ }
+
+ public int getkPerBucket() {
+ return kPerBucket;
+ }
+
+ public boolean isIncludeMetadata() {
+ return includeMetadata;
+ }
+
+ }
+
+ private static final class QueryResponse {
+
+ private String key;
+
+ private float score;
+
+ private Map metadata;
+
+ private String getContent(String field) {
+ return (String) metadata.get(field);
+ }
+
+ public void setKey(String key) {
+ this.key = key;
+ }
+
+ public void setScore(float score) {
+ this.score = score;
+ }
+
+ public void setMetadata(Map metadata) {
+ this.metadata = metadata;
+ }
+
+ }
+
+ private static class DeleteRequest {
+
+ @JsonProperty("delete-data")
+ private boolean deleteData = true;
+
+ public DeleteRequest() {
+ }
+
+ public DeleteRequest(boolean deleteData) {
+ this.deleteData = deleteData;
+ }
+
+ public boolean isDeleteData() {
+ return deleteData;
+ }
+
+ public void setDeleteData(boolean deleteData) {
+ this.deleteData = deleteData;
+ }
+
+ }
+
+ @Override
+ public void add(List documents) {
+ UploadRequest upload = new UploadRequest(documents.stream().map(document -> {
+ // Compute and assign an embedding to the document.
+ document.setEmbedding(this.embeddingClient.embed(document));
+ List floatVector = document.getEmbedding().stream().map(Double::floatValue).toList();
+ return new UploadRequest.Embedding(document.getId(), floatVector, documentField, document.getContent(),
+ document.getMetadata());
+ }).toList());
+
+ ObjectMapper objectMapper = new ObjectMapper();
+ String embeddingsJson = null;
+ try {
+ String embeddingString = objectMapper.writeValueAsString(upload);
+ embeddingsJson = embeddingString.substring("{\"embeddings\":".length());
+ }
+ catch (JsonProcessingException e) {
+ throw new RuntimeException(String.format("Embedding JSON parsing error: %s", e.getMessage()));
+ }
+
+ client.post()
+ .uri("/" + indexName + EMBEDDINGS)
+ .contentType(MediaType.APPLICATION_JSON)
+ .bodyValue(embeddingsJson)
+ .retrieve()
+ .bodyToMono(Void.class)
+ .onErrorMap(WebClientException.class, this::handleHttpClientException)
+ .block();
+ }
+
+ @Override
+ public Optional delete(List idList) {
+ try {
+ client.method(HttpMethod.DELETE)
+ .uri("/" + indexName + EMBEDDINGS)
+ .body(BodyInserters.fromValue(idList))
+ .retrieve()
+ .bodyToMono(Void.class)
+ .block();
+ }
+ catch (Exception e) {
+ logger.warn("Error removing embedding: " + e);
+ return Optional.of(false);
+ }
+ return Optional.of(true);
+ }
+
+ @Override
+ public List similaritySearch(SearchRequest request) {
+ if (request.hasFilterExpression()) {
+ throw new UnsupportedOperationException("Gemfire does not support metadata filter expressions yet.");
+ }
+ List vector = this.embeddingClient.embed(request.getQuery());
+ List floatVector = vector.stream().map(Double::floatValue).toList();
+
+ return client.post()
+ .uri("/" + indexName + QUERY)
+ .contentType(MediaType.APPLICATION_JSON)
+ .bodyValue(new QueryRequest(floatVector, request.getTopK(), topKPerBucket, true))
+ .retrieve()
+ .bodyToFlux(QueryResponse.class)
+ .filter(r -> r.score >= request.getSimilarityThreshold())
+ .map(r -> {
+ Map metadata = r.metadata;
+ metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score);
+ String content = (String) metadata.remove(documentField);
+ return new Document(r.key, content, metadata);
+ })
+ .collectList()
+ .onErrorMap(WebClientException.class, this::handleHttpClientException)
+ .block();
+ }
+
+ public void createIndex(String indexName) throws JsonProcessingException {
+ CreateRequest createRequest = new CreateRequest(indexName);
+ ObjectMapper objectMapper = new ObjectMapper();
+ String index = objectMapper.writeValueAsString(createRequest);
+ client.post()
+ .contentType(MediaType.APPLICATION_JSON)
+ .bodyValue(index)
+ .retrieve()
+ .bodyToMono(Void.class)
+ .onErrorMap(WebClientException.class, this::handleHttpClientException)
+ .block();
+ }
+
+ public void deleteIndex(String indexName) {
+ DeleteRequest deleteRequest = new DeleteRequest();
+ deleteRequest.setDeleteData(true);
+ client.method(HttpMethod.DELETE)
+ .uri("/" + indexName)
+ .body(BodyInserters.fromValue(deleteRequest))
+ .retrieve()
+ .bodyToMono(Void.class)
+ .onErrorMap(WebClientException.class, this::handleHttpClientException)
+ .block();
+ }
+
+ private Throwable handleHttpClientException(Throwable ex) {
+ if (!(ex instanceof WebClientResponseException clientException)) {
+ throw new RuntimeException(String.format("Got an unexpected error: %s", ex));
+ }
+
+ if (clientException.getStatusCode().equals(NOT_FOUND)) {
+ throw new RuntimeException(String.format("Index %s not found: %s", indexName, ex));
+ }
+ else if (clientException.getStatusCode().equals(BAD_REQUEST)) {
+ throw new RuntimeException(String.format("Bad Request: %s", ex));
+ }
+ else {
+ throw new RuntimeException(String.format("Got an unexpected HTTP error: %s", ex));
+ }
+ }
+
+}
diff --git a/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
new file mode 100644
index 00000000000..49d03254448
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
@@ -0,0 +1,205 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.vectorstore;
+
+import static java.util.concurrent.TimeUnit.MINUTES;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.transformers.TransformersEmbeddingClient;
+import org.springframework.ai.vectorstore.GemFireVectorStore.GemFireVectorStoreConfig;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.core.io.DefaultResourceLoader;
+
+/**
+ * @author Geet Rawat
+ * @since 1.0.0
+ */
+@EnabledIfEnvironmentVariable(named = "GEMFIRE_HOST", matches = ".+")
+public class GemFireVectorStoreIT {
+
+ public static final String INDEX_NAME = "spring-ai-index1";
+
+ List documents = List.of(
+ new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
+ new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
+ new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+
+ public static String getText(String uri) {
+ var resource = new DefaultResourceLoader().getResource(uri);
+ try {
+ return resource.getContentAsString(StandardCharsets.UTF_8);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withUserConfiguration(TestApplication.class);
+
+ @BeforeEach
+ public void createIndex() {
+ contextRunner.run(c -> c.getBean(GemFireVectorStore.class).createIndex(INDEX_NAME));
+ }
+
+ @AfterEach
+ public void deleteIndex() {
+ contextRunner.run(c -> c.getBean(GemFireVectorStore.class).deleteIndex(INDEX_NAME));
+ }
+
+ @Test
+ public void addAndDeleteEmbeddingTest() {
+ contextRunner.run(context -> {
+ VectorStore vectorStore = context.getBean(VectorStore.class);
+ vectorStore.add(documents);
+ vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());
+ Awaitility.await().atMost(1, MINUTES).until(() -> {
+ return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3));
+ }, hasSize(0));
+ });
+ }
+
+ @Test
+ public void addAndSearchTest() {
+ contextRunner.run(context -> {
+ VectorStore vectorStore = context.getBean(VectorStore.class);
+ vectorStore.add(documents);
+
+ Awaitility.await().atMost(1, MINUTES).until(() -> {
+ return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
+ }, hasSize(1));
+
+ List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5));
+ Document resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+ assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+ assertThat(resultDoc.getMetadata()).hasSize(2);
+ assertThat(resultDoc.getMetadata()).containsKey("meta2");
+ assertThat(resultDoc.getMetadata()).containsKey("distance");
+ });
+ }
+
+ @Test
+ public void documentUpdateTest() {
+ contextRunner.run(context -> {
+ VectorStore vectorStore = context.getBean(VectorStore.class);
+
+ Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
+ Collections.singletonMap("meta1", "meta1"));
+ vectorStore.add(List.of(document));
+ SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5);
+ Awaitility.await().atMost(1, MINUTES).until(() -> {
+ return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
+ }, hasSize(1));
+ List results = vectorStore.similaritySearch(springSearchRequest);
+ Document resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(document.getId());
+ assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!");
+ assertThat(resultDoc.getMetadata()).containsKey("meta1");
+ assertThat(resultDoc.getMetadata()).containsKey("distance");
+
+ Document sameIdDocument = new Document(document.getId(),
+ "The World is Big and Salvation Lurks Around the Corner",
+ Collections.singletonMap("meta2", "meta2"));
+
+ vectorStore.add(List.of(sameIdDocument));
+ SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5);
+ results = vectorStore.similaritySearch(fooBarSearchRequest);
+
+ assertThat(results).hasSize(1);
+ resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(document.getId());
+ assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
+ assertThat(resultDoc.getMetadata()).containsKey("meta2");
+ assertThat(resultDoc.getMetadata()).containsKey("distance");
+ });
+ }
+
+ @Test
+ public void searchThresholdTest() {
+
+ contextRunner.run(context -> {
+ VectorStore vectorStore = context.getBean(VectorStore.class);
+ vectorStore.add(documents);
+
+ Awaitility.await().atMost(1, MINUTES).until(() -> {
+ return vectorStore
+ .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll());
+ }, hasSize(3));
+
+ List fullResult = vectorStore
+ .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
+
+ List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();
+ assertThat(distances).hasSize(3);
+
+ float threshold = (distances.get(0) + distances.get(1)) / 2;
+ List results = vectorStore
+ .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold));
+
+ assertThat(results).hasSize(1);
+
+ Document resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
+ assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+ assertThat(resultDoc.getMetadata()).containsKey("meta2");
+ assertThat(resultDoc.getMetadata()).containsKey("distance");
+ });
+ }
+
+ @SpringBootConfiguration
+ @EnableAutoConfiguration
+ public static class TestApplication {
+
+ @Bean
+ public GemFireVectorStoreConfig gemfireVectorStoreConfig() {
+ return GemFireVectorStoreConfig.builder().withHost("localhost").build();
+ }
+
+ @Bean
+ public GemFireVectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingClient embeddingClient) {
+ GemFireVectorStore gemFireVectorStore = new GemFireVectorStore(config, embeddingClient);
+ gemFireVectorStore.setIndexName(INDEX_NAME);
+ return gemFireVectorStore;
+ }
+
+ @Bean
+ public EmbeddingClient embeddingClient() {
+ return new TransformersEmbeddingClient();
+ }
+
+ }
+
+}
\ No newline at end of file