From 6499ddac30b3cfb6d656d5c5e4e61a2c420bcc4f Mon Sep 17 00:00:00 2001 From: Robin Date: Sun, 20 Oct 2024 18:05:32 +0100 Subject: [PATCH] Add inmemory searching for metadata and also add tests --- src/InMemory/src/InMemoryVectorCollection.cs | 25 ++++++++++- src/IntegrationTests/DatabaseTests.cs | 45 +++++++++++++++++++- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/InMemory/src/InMemoryVectorCollection.cs b/src/InMemory/src/InMemoryVectorCollection.cs index e30cfef..3040339 100644 --- a/src/InMemory/src/InMemoryVectorCollection.cs +++ b/src/InMemory/src/InMemoryVectorCollection.cs @@ -92,8 +92,29 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default) return Task.FromResult(_vectors.GetValueOrDefault(id)); } - Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) + public Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + filters = filters ?? throw new ArgumentNullException(nameof(filters)); + + var filteredVectors = _vectors.Values.Where(vector => + { + // Check if all filters match + foreach (var filter in filters) + { + object? 
metadataValue = null; + if (vector.Metadata != null && !vector.Metadata.TryGetValue(filter.Key, out metadataValue) || metadataValue == null) + { + return false; + } + else if (!metadataValue.Equals(filter.Value)) // Compare the stored metadata value with the filter value via Equals (no string conversion) + { + return false; + } + } + + return true; + }).ToList(); + + return Task.FromResult(filteredVectors); } } \ No newline at end of file diff --git a/src/IntegrationTests/DatabaseTests.cs b/src/IntegrationTests/DatabaseTests.cs index ba3b19b..6b9e4e0 100644 --- a/src/IntegrationTests/DatabaseTests.cs +++ b/src/IntegrationTests/DatabaseTests.cs @@ -1,5 +1,5 @@ -using LangChain.Extensions; using LangChain.DocumentLoaders; +using LangChain.Extensions; namespace LangChain.Databases.IntegrationTests; @@ -296,4 +296,47 @@ public async Task SimilaritySearchWithScores_Ok(SupportedDatabase database) first.Distance.Should().BeGreaterOrEqualTo(1f); } } + + [TestCase(SupportedDatabase.InMemory)] + //[TestCase(SupportedDatabase.Chroma)] + //[TestCase(SupportedDatabase.OpenSearch)] + //[TestCase(SupportedDatabase.Postgres)] + [TestCase(SupportedDatabase.SqLite)] + //[TestCase(SupportedDatabase.DuckDb)] + //[TestCase(SupportedDatabase.Weaviate)] + //[TestCase(SupportedDatabase.Elasticsearch)] + //[TestCase(SupportedDatabase.Milvus)] + public async Task MetadataSearch_Ok(SupportedDatabase database) + { + await using var environment = await StartEnvironmentForAsync(database); + var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(environment.CollectionName, dimensions: environment.Dimensions); + + var texts = new[] { "apple", "orange" }; + + var metadatas = new Dictionary[2]; + metadatas[0] = new Dictionary + { + ["color"] = "red" + }; + + metadatas[1] = new Dictionary + { + ["color"] = "orange" + }; + + var totalItems = await vectorCollection.AddTextsAsync(environment.EmbeddingModel, texts, metadatas); + + // Define the filters to get the orange entry + var filters = new Dictionary + { + { 
"color", "orange" } + }; + + + var items = await vectorCollection.SearchByMetadata(filters); + + totalItems.Should().HaveCount(2); + items.Should().HaveCount(1); + var vector = items.SingleOrDefault()?.Metadata?["color"].Should().Be("orange"); + } } \ No newline at end of file