Skip to content

Commit

Permalink
Merge pull request #75 from tryAGI/Metadata-search-for-opensearch
Browse files Browse the repository at this point in the history
Change to readonly collection and cleanup
  • Loading branch information
robalexclark authored Oct 25, 2024
2 parents e9116ac + 0158148 commit e153878
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/Abstractions/src/IVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Task<VectorSearchResponse> SearchAsync(
/// <param name="filters">The filters to apply to the search request.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task representing the asynchronous operation. The task result contains the search response.</returns>
Task<List<Vector>> SearchByMetadata(
Task<IReadOnlyList<Vector>> SearchByMetadata(
Dictionary<string, object> filters,
CancellationToken cancellationToken = default);

Expand Down
3 changes: 2 additions & 1 deletion src/Abstractions/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ LangChain.Databases.IVectorCollection.Id.get -> string!
LangChain.Databases.IVectorCollection.IsEmptyAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<bool>!
LangChain.Databases.IVectorCollection.Name.get -> string!
LangChain.Databases.IVectorCollection.SearchAsync(LangChain.Databases.VectorSearchRequest! request, LangChain.Databases.VectorSearchSettings? settings = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<LangChain.Databases.VectorSearchResponse!>!
LangChain.Databases.IVectorCollection.SearchByMetadata(System.Collections.Generic.Dictionary<string!, object!>! filters, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<System.Collections.Generic.List<LangChain.Databases.Vector!>!>!
LangChain.Databases.IVectorCollection.SearchByMetadata(System.Collections.Generic.Dictionary<string!, object!>! filters, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<System.Collections.Generic.IReadOnlyList<LangChain.Databases.Vector!>!>!
LangChain.Databases.IVectorDatabase
LangChain.Databases.IVectorDatabase.CreateCollectionAsync(string! collectionName, int dimensions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
LangChain.Databases.IVectorDatabase.DeleteCollectionAsync(string! collectionName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Expand Down Expand Up @@ -90,6 +90,7 @@ static LangChain.Databases.RelevanceScoreFunctions.Cosine(float distance) -> flo
static LangChain.Databases.RelevanceScoreFunctions.Euclidean(float distance) -> float
static LangChain.Databases.RelevanceScoreFunctions.Get(LangChain.Databases.DistanceStrategy distanceStrategy) -> System.Func<float, float>!
static LangChain.Databases.RelevanceScoreFunctions.MaxInnerProduct(float distance) -> float
static LangChain.Databases.VectorCollection.IsValidJsonKey(string! input) -> bool
static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]! embedding) -> LangChain.Databases.VectorSearchRequest!
static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]![]! embeddings) -> LangChain.Databases.VectorSearchRequest!
static LangChain.Databases.VectorSearchRequest.ToVectorSearchRequest(float[]! embedding) -> LangChain.Databases.VectorSearchRequest!
Expand Down
9 changes: 9 additions & 0 deletions src/Abstractions/src/VectorCollection.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Text.RegularExpressions;

namespace LangChain.Databases;

/// <summary>
Expand All @@ -20,4 +22,11 @@ public class VectorCollection(
/// Collection name provided by client.
/// </summary>
public string Name { get; set; } = name;


/// <summary>
/// Determines whether <paramref name="input"/> consists solely of one or more
/// regex word characters (<c>\w</c>).
/// </summary>
/// <param name="input">The candidate key to validate.</param>
/// <returns>
/// <c>true</c> when every character of <paramref name="input"/> is a word character
/// and the string is not empty; otherwise <c>false</c>.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is <c>null</c>.</exception>
protected static bool IsValidJsonKey(string input)
{
    // In .NET, \w covers Unicode letters, decimal digits, combining marks and
    // connector punctuation (which includes the underscore).
    // \z anchors at the absolute end of the string. The "$" anchor would also match
    // just before a trailing '\n', letting a key such as "name\n" slip through.
    return Regex.IsMatch(input, @"^\w+\z");
}
}
2 changes: 1 addition & 1 deletion src/Chroma/src/ChromaVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ private static IDictionary<string, object> DeserializeMetadata(MemoryRecordMetad
?? new Dictionary<string, object>();
}

public Task<List<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
/// <inheritdoc />
/// <exception cref="NotSupportedException">Always thrown; this collection does not offer metadata search.</exception>
public Task<IReadOnlyList<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
    => throw new NotSupportedException("Chroma doesn't support collection metadata");
Expand Down
2 changes: 1 addition & 1 deletion src/Elasticsearch/src/ElasticsearchVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
throw new NotImplementedException();
}

Task<List<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
/// <inheritdoc />
/// <exception cref="NotImplementedException">Always thrown; metadata search is not implemented here.</exception>
Task<IReadOnlyList<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
    => throw new NotImplementedException();
Expand Down
2 changes: 1 addition & 1 deletion src/InMemory/src/InMemoryVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
return Task.FromResult(_vectors.GetValueOrDefault(id));
}

public async Task<List<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken = default)
public async Task<IReadOnlyList<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken = default)
{
filters = filters ?? throw new ArgumentNullException(nameof(filters));
var filteredVectors = await Task.Run(() => _vectors.Values.Where(vector =>
Expand Down
8 changes: 7 additions & 1 deletion src/Mongo/src/MongoVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,36 @@ public class MongoVectorCollection(
{
private readonly IMongoCollection<Vector> _mongoCollection = mongoContext.GetCollection<Vector>(name);

/// <inheritdoc />
/// <exception cref="ArgumentNullException"><paramref name="items"/> is <c>null</c>.</exception>
public async Task<IReadOnlyCollection<string>> AddAsync(IReadOnlyCollection<Vector> items, CancellationToken cancellationToken = default)
{
    items = items ?? throw new ArgumentNullException(nameof(items));

    // The MongoDB driver rejects an empty insert batch with an ArgumentException,
    // so an empty input is treated as a no-op that yields no ids.
    if (items.Count == 0)
    {
        return Array.Empty<string>();
    }

    await _mongoCollection.InsertManyAsync(items, cancellationToken: cancellationToken).ConfigureAwait(false);

    // Echo back the ids of the inserted items, in input order.
    return items.Select(i => i.Id).ToList();
}

/// <inheritdoc />
public async Task<bool> DeleteAsync(IEnumerable<string> ids, CancellationToken cancellationToken = default)
{
    // Match every stored vector whose Id is one of the requested ids.
    var idFilter = Builders<Vector>.Filter.In(vector => vector.Id, ids);

    var deletion = await _mongoCollection
        .DeleteManyAsync(idFilter, cancellationToken)
        .ConfigureAwait(false);

    // Reports whether the delete was acknowledged, not how many documents were removed.
    return deletion.IsAcknowledged;
}

/// <inheritdoc />
public async Task<Vector?> GetAsync(string id, CancellationToken cancellationToken = default)
{
    var filter = Builders<Vector>.Filter.Eq(i => i.Id, id);
    var cursor = await _mongoCollection.FindAsync(filter, cancellationToken: cancellationToken).ConfigureAwait(false);

    // Use the asynchronous cursor extension: the synchronous FirstOrDefault would
    // advance the cursor with a blocking call inside this async method.
    return await cursor.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
{
    // The collection is considered empty when the estimated document count is zero.
    var estimatedCount = await _mongoCollection
        .EstimatedDocumentCountAsync(cancellationToken: cancellationToken)
        .ConfigureAwait(false);

    return estimatedCount == 0;
}

/// <inheritdoc />
public async Task<VectorSearchResponse> SearchAsync(VectorSearchRequest request, VectorSearchSettings? settings = null, CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
Expand Down Expand Up @@ -71,7 +76,8 @@ public async Task<VectorSearchResponse> SearchAsync(VectorSearchRequest request,
};
}

public async Task<List<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken = default)
/// <inheritdoc />
public async Task<IReadOnlyList<Vector>> SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken = default)
{
filters = filters ?? throw new ArgumentNullException(nameof(filters));

Expand Down
2 changes: 1 addition & 1 deletion src/OpenSearch/src/OpenSearchVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
throw new NotImplementedException();
}

Task<List<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
/// <inheritdoc />
/// <exception cref="NotImplementedException">Always thrown; metadata search is not implemented here.</exception>
Task<IReadOnlyList<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
    => throw new NotImplementedException();
Expand Down
28 changes: 21 additions & 7 deletions src/Postgres/src/PostgresVectorCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,33 @@ public Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
throw new NotImplementedException();
}

public async Task<List<Vector>> SearchByMetadata(
/// <inheritdoc />
public async Task<IReadOnlyList<Vector>> SearchByMetadata(
Dictionary<string, object> filters,
CancellationToken cancellationToken = default)
{
filters = filters ?? throw new ArgumentNullException(nameof(filters));

foreach (var kvp in filters)
{
if (string.IsNullOrWhiteSpace(kvp.Key))
{
throw new ArgumentException("Filter key cannot be null or whitespace.", nameof(filters));
}
// Add more validation for allowed characters
if (!IsValidJsonKey(kvp.Key))
{
throw new ArgumentException($"Invalid character in filter key: {kvp.Key}", nameof(filters));
}
}

var records = await client
.GetRecordsByMetadataAsync(
Name,
filters,
withEmbeddings: false,
cancellationToken: cancellationToken)
.ConfigureAwait(false);
.GetRecordsByMetadataAsync(
Name,
filters,
withEmbeddings: false,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

var vectors = records.Select(record => new Vector
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public async Task<VectorSearchResponse> SearchAsync(VectorSearchRequest request,
return new VectorSearchResponse { Items = results.Select(x => new Vector { Text = x.Item1.Metadata.ExternalSourceName }).ToList() };
}

Task<List<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
/// <inheritdoc />
/// <exception cref="NotImplementedException">Always thrown; metadata search is not implemented here.</exception>
Task<IReadOnlyList<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
    => throw new NotImplementedException();
Expand Down
9 changes: 1 addition & 8 deletions src/Sqlite/src/SqLiteVectorCollection.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Microsoft.Data.Sqlite;
using System.Globalization;
using System.Text.Json;
using System.Text.RegularExpressions;

namespace LangChain.Databases.Sqlite;

Expand Down Expand Up @@ -190,7 +189,7 @@ public async Task<VectorSearchResponse> SearchAsync(
}

/// <inheritdoc />
public async Task<List<Vector>> SearchByMetadata(
public async Task<IReadOnlyList<Vector>> SearchByMetadata(
Dictionary<string, object> filters,
CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -243,10 +242,4 @@ public async Task<List<Vector>> SearchByMetadata(

return res;
}

/// <summary>
/// Determines whether <paramref name="input"/> consists solely of one or more
/// regex word characters (<c>\w</c>).
/// </summary>
/// <param name="input">The candidate key to validate.</param>
/// <returns>
/// <c>true</c> when every character of <paramref name="input"/> is a word character
/// and the string is not empty; otherwise <c>false</c>.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is <c>null</c>.</exception>
private static bool IsValidJsonKey(string input)
{
    // In .NET, \w covers Unicode letters, decimal digits, combining marks and
    // connector punctuation (which includes the underscore).
    // \z anchors at the absolute end of the string. The "$" anchor would also match
    // just before a trailing '\n', letting a key such as "name\n" slip through.
    return Regex.IsMatch(input, @"^\w+\z");
}
}

0 comments on commit e153878

Please sign in to comment.