Skip to content

Commit

Permalink
Merge pull request #71 from tryAGI/Add-metadata-search-to-mongodb
Browse files Browse the repository at this point in the history
Add metadata search to mongodb
  • Loading branch information
robalexclark authored Oct 23, 2024
2 parents 9d2890e + f73459c commit 66e3ef6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
29 changes: 27 additions & 2 deletions src/IntegrationTests/DatabaseTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using LangChain.DocumentLoaders;
using LangChain.Extensions;
using StackExchange.Redis;
using System;

namespace LangChain.Databases.IntegrationTests;

Expand Down Expand Up @@ -298,10 +300,10 @@ public async Task SimilaritySearchWithScores_Ok(SupportedDatabase database)
}

[TestCase(SupportedDatabase.InMemory)]
//[TestCase(SupportedDatabase.Chroma)]
//[TestCase(SupportedDatabase.OpenSearch)]
//[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
//[TestCase(SupportedDatabase.DuckDb)]
//[TestCase(SupportedDatabase.Weaviate)]
//[TestCase(SupportedDatabase.Elasticsearch)]
Expand Down Expand Up @@ -336,7 +338,30 @@ public async Task MetadataSearch_Ok(SupportedDatabase database)
var items = await vectorCollection.SearchByMetadata(filters);

totalItems.Should().HaveCount(2);

items.Should().HaveCount(1);
var vector = items.SingleOrDefault()?.Metadata?["color"].Should().Be("orange");
var result = items.Single();
result.Text.Should().Be("orange");
result.Metadata.Should().ContainKey("color");
result.Metadata?["color"].Should().Be("orange");
}

[TestCase(SupportedDatabase.InMemory)]
//[TestCase(SupportedDatabase.OpenSearch)]
//[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
//[TestCase(SupportedDatabase.DuckDb)]
//[TestCase(SupportedDatabase.Weaviate)]
//[TestCase(SupportedDatabase.Elasticsearch)]
//[TestCase(SupportedDatabase.Milvus)]
public async Task SearchByMetadata_WithNullFilters_ThrowsArgumentException(SupportedDatabase database)
{
await using var environment = await StartEnvironmentForAsync(database);
var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(environment.CollectionName, dimensions: environment.Dimensions);

// Act & Assert
await vectorCollection.Invoking(v => v.SearchByMetadata(null!))
.Should().ThrowAsync<ArgumentNullException>();
}
}
32 changes: 28 additions & 4 deletions src/Mongo/src/MongoVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class MongoVectorCollection(
string? id = null)
: VectorCollection(name, id), IVectorCollection
{
private IMongoCollection<Vector> _mongoCollection = mongoContext.GetCollection<Vector>(name);
private readonly IMongoCollection<Vector> _mongoCollection = mongoContext.GetCollection<Vector>(name);

public async Task<IReadOnlyCollection<string>> AddAsync(IReadOnlyCollection<Vector> items, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -71,8 +71,32 @@ public async Task<VectorSearchResponse> SearchAsync(VectorSearchRequest request,
};
}

Task<List<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
public async Task<List<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
filters = filters ?? throw new ArgumentNullException(nameof(filters));

var builder = Builders<Vector>.Filter;
var filterDefinitions = new List<FilterDefinition<Vector>>();

foreach (var kvp in filters)
{
if (kvp.Value == null)
{
throw new ArgumentException($"Metadata value for key '{kvp.Key}' cannot be null", nameof(filters));
}

// Assuming your Vector class has a Metadata field of type Dictionary<string, object>
var filter = builder.Eq($"Metadata.{kvp.Key}", kvp.Value);
filterDefinitions.Add(filter);
}

var combinedFilter = builder.And(filterDefinitions);

var results = await _mongoCollection
.Find(combinedFilter)
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

return results;
}
}
}

0 comments on commit 66e3ef6

Please sign in to comment.