diff --git a/src/Abstractions/src/IVectorCollection.cs b/src/Abstractions/src/IVectorCollection.cs index 4877ea6..d3448ef 100644 --- a/src/Abstractions/src/IVectorCollection.cs +++ b/src/Abstractions/src/IVectorCollection.cs @@ -63,7 +63,7 @@ Task SearchAsync( /// The filters to apply to the search request. /// The cancellation token. /// A task representing the asynchronous operation. The task result contains the search response. - Task> SearchByMetadata( + Task> SearchByMetadata( Dictionary filters, CancellationToken cancellationToken = default); diff --git a/src/Abstractions/src/PublicAPI.Unshipped.txt b/src/Abstractions/src/PublicAPI.Unshipped.txt index 724ee6e..3f78cd2 100644 --- a/src/Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Abstractions/src/PublicAPI.Unshipped.txt @@ -14,7 +14,7 @@ LangChain.Databases.IVectorCollection.Id.get -> string! LangChain.Databases.IVectorCollection.IsEmptyAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! LangChain.Databases.IVectorCollection.Name.get -> string! LangChain.Databases.IVectorCollection.SearchAsync(LangChain.Databases.VectorSearchRequest! request, LangChain.Databases.VectorSearchSettings? settings = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! -LangChain.Databases.IVectorCollection.SearchByMetadata(System.Collections.Generic.Dictionary! filters, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!>! +LangChain.Databases.IVectorCollection.SearchByMetadata(System.Collections.Generic.Dictionary! filters, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!>! LangChain.Databases.IVectorDatabase LangChain.Databases.IVectorDatabase.CreateCollectionAsync(string! collectionName, int dimensions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! LangChain.Databases.IVectorDatabase.DeleteCollectionAsync(string! collectionName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! @@ -90,6 +90,7 @@ static LangChain.Databases.RelevanceScoreFunctions.Cosine(float distance) -> flo static LangChain.Databases.RelevanceScoreFunctions.Euclidean(float distance) -> float static LangChain.Databases.RelevanceScoreFunctions.Get(LangChain.Databases.DistanceStrategy distanceStrategy) -> System.Func! static LangChain.Databases.RelevanceScoreFunctions.MaxInnerProduct(float distance) -> float +static LangChain.Databases.VectorCollection.IsValidJsonKey(string! input) -> bool static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]! embedding) -> LangChain.Databases.VectorSearchRequest! static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]![]! embeddings) -> LangChain.Databases.VectorSearchRequest! static LangChain.Databases.VectorSearchRequest.ToVectorSearchRequest(float[]! embedding) -> LangChain.Databases.VectorSearchRequest! diff --git a/src/Abstractions/src/VectorCollection.cs b/src/Abstractions/src/VectorCollection.cs index 5cd8112..45f0b92 100644 --- a/src/Abstractions/src/VectorCollection.cs +++ b/src/Abstractions/src/VectorCollection.cs @@ -1,3 +1,5 @@ +using System.Text.RegularExpressions; + namespace LangChain.Databases; /// @@ -20,4 +22,11 @@ public class VectorCollection( /// Collection name provided by client. /// public string Name { get; set; } = name; + + + protected static bool IsValidJsonKey(string input) + { + // Only allow letters, numbers, and underscores + return Regex.IsMatch(input, @"^\w+$"); + } } \ No newline at end of file diff --git a/src/Chroma/src/ChromaVectorCollection.cs b/src/Chroma/src/ChromaVectorCollection.cs index 69dfcb3..102b569 100644 --- a/src/Chroma/src/ChromaVectorCollection.cs +++ b/src/Chroma/src/ChromaVectorCollection.cs @@ -167,7 +167,7 @@ private static IDictionary DeserializeMetadata(MemoryRecordMetad ?? new Dictionary(); } - public Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) + public Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) { throw new NotSupportedException("Chroma doesn't support collection metadata"); } diff --git a/src/Elasticsearch/src/ElasticsearchVectorCollection.cs b/src/Elasticsearch/src/ElasticsearchVectorCollection.cs index 4091baf..2294058 100644 --- a/src/Elasticsearch/src/ElasticsearchVectorCollection.cs +++ b/src/Elasticsearch/src/ElasticsearchVectorCollection.cs @@ -104,7 +104,7 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default) throw new NotImplementedException(); } - Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) + Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/InMemory/src/InMemoryVectorCollection.cs b/src/InMemory/src/InMemoryVectorCollection.cs index 2420d0e..93d19f4 100644 --- a/src/InMemory/src/InMemoryVectorCollection.cs +++ b/src/InMemory/src/InMemoryVectorCollection.cs @@ -92,7 +92,7 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default) return Task.FromResult(_vectors.GetValueOrDefault(id)); } - public async Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken = default) + public async Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken = default) { filters = filters ?? throw new ArgumentNullException(nameof(filters)); var filteredVectors = await Task.Run(() => _vectors.Values.Where(vector => diff --git a/src/Mongo/src/MongoVectorCollection.cs b/src/Mongo/src/MongoVectorCollection.cs index 0da1a83..0549048 100644 --- a/src/Mongo/src/MongoVectorCollection.cs +++ b/src/Mongo/src/MongoVectorCollection.cs @@ -13,12 +13,14 @@ public class MongoVectorCollection( { private readonly IMongoCollection _mongoCollection = mongoContext.GetCollection(name); + /// public async Task> AddAsync(IReadOnlyCollection items, CancellationToken cancellationToken = default) { await _mongoCollection.InsertManyAsync(items, cancellationToken: cancellationToken).ConfigureAwait(false); return items.Select(i => i.Id).ToList(); } + /// public async Task DeleteAsync(IEnumerable ids, CancellationToken cancellationToken = default) { var filter = Builders.Filter.In(i => i.Id, ids); @@ -26,6 +28,7 @@ public async Task DeleteAsync(IEnumerable ids, CancellationToken c return result.IsAcknowledged; } + /// public async Task GetAsync(string id, CancellationToken cancellationToken = default) { var filter = Builders.Filter.Eq(i => i.Id, id); @@ -33,11 +36,13 @@ public async Task DeleteAsync(IEnumerable ids, CancellationToken c return result.FirstOrDefault(cancellationToken: cancellationToken); } + /// public async Task IsEmptyAsync(CancellationToken cancellationToken = default) { return await _mongoCollection.EstimatedDocumentCountAsync(cancellationToken: cancellationToken).ConfigureAwait(false) == 0; } + /// public async Task SearchAsync(VectorSearchRequest request, VectorSearchSettings? settings = null, CancellationToken cancellationToken = default) { request = request ?? throw new ArgumentNullException(nameof(request)); @@ -71,7 +76,8 @@ public async Task SearchAsync(VectorSearchRequest request, }; } - public async Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken = default) + /// + public async Task> SearchByMetadata(Dictionary filters, CancellationToken cancellationToken = default) { filters = filters ?? throw new ArgumentNullException(nameof(filters)); diff --git a/src/OpenSearch/src/OpenSearchVectorCollection.cs b/src/OpenSearch/src/OpenSearchVectorCollection.cs index 2ea134c..eb9ab13 100644 --- a/src/OpenSearch/src/OpenSearchVectorCollection.cs +++ b/src/OpenSearch/src/OpenSearchVectorCollection.cs @@ -134,7 +134,7 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default) throw new NotImplementedException(); } - Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) + Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/Postgres/src/PostgresVectorCollection.cs b/src/Postgres/src/PostgresVectorCollection.cs index 8a69a1c..e0b1ba2 100644 --- a/src/Postgres/src/PostgresVectorCollection.cs +++ b/src/Postgres/src/PostgresVectorCollection.cs @@ -124,19 +124,33 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default) throw new NotImplementedException(); } - public async Task> SearchByMetadata( + /// + public async Task> SearchByMetadata( Dictionary filters, CancellationToken cancellationToken = default) { filters = filters ?? throw new ArgumentNullException(nameof(filters)); + foreach (var kvp in filters) + { + if (string.IsNullOrWhiteSpace(kvp.Key)) + { + throw new ArgumentException("Filter key cannot be null or whitespace.", nameof(filters)); + } + // Add more validation for allowed characters + if (!IsValidJsonKey(kvp.Key)) + { + throw new ArgumentException($"Invalid character in filter key: {kvp.Key}", nameof(filters)); + } + } + var records = await client - .GetRecordsByMetadataAsync( - Name, - filters, - withEmbeddings: false, - cancellationToken: cancellationToken) - .ConfigureAwait(false); + .GetRecordsByMetadataAsync( + Name, + filters, + withEmbeddings: false, + cancellationToken: cancellationToken) + .ConfigureAwait(false); var vectors = records.Select(record => new Vector { diff --git a/src/SemanticKernel/src/SemanticKernelMemoryStoreCollection.cs b/src/SemanticKernel/src/SemanticKernelMemoryStoreCollection.cs index b3d190f..a89ac69 100644 --- a/src/SemanticKernel/src/SemanticKernelMemoryStoreCollection.cs +++ b/src/SemanticKernel/src/SemanticKernelMemoryStoreCollection.cs @@ -76,7 +76,7 @@ public async Task SearchAsync(VectorSearchRequest request, return new VectorSearchResponse { Items = results.Select(x => new Vector { Text = x.Item1.Metadata.ExternalSourceName }).ToList() }; } - Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) + Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/Sqlite/src/SqLiteVectorCollection.cs b/src/Sqlite/src/SqLiteVectorCollection.cs index b31ce34..f2bfc06 100644 --- a/src/Sqlite/src/SqLiteVectorCollection.cs +++ b/src/Sqlite/src/SqLiteVectorCollection.cs @@ -1,7 +1,6 @@ using Microsoft.Data.Sqlite; using System.Globalization; using System.Text.Json; -using System.Text.RegularExpressions; namespace LangChain.Databases.Sqlite; @@ -190,7 +189,7 @@ public async Task SearchAsync( } /// - public async Task> SearchByMetadata( + public async Task> SearchByMetadata( Dictionary filters, CancellationToken cancellationToken = default) { @@ -243,10 +242,4 @@ public async Task> SearchByMetadata( return res; } - - private static bool IsValidJsonKey(string input) - { - // Only allow letters, numbers, and underscores - return Regex.IsMatch(input, @"^\w+$"); - } } \ No newline at end of file