diff --git a/src/IntegrationTests/DatabaseTests.cs b/src/IntegrationTests/DatabaseTests.cs
index 37a19c5..09a3d90 100644
--- a/src/IntegrationTests/DatabaseTests.cs
+++ b/src/IntegrationTests/DatabaseTests.cs
@@ -301,7 +301,7 @@ public async Task SimilaritySearchWithScores_Ok(SupportedDatabase database)
 
     [TestCase(SupportedDatabase.InMemory)]
     //[TestCase(SupportedDatabase.OpenSearch)]
-    //[TestCase(SupportedDatabase.Postgres)]
+    [TestCase(SupportedDatabase.Postgres)]
     [TestCase(SupportedDatabase.SqLite)]
     [TestCase(SupportedDatabase.Mongo)]
     //[TestCase(SupportedDatabase.DuckDb)]
@@ -348,7 +348,7 @@ public async Task MetadataSearch_Ok(SupportedDatabase database)
 
     [TestCase(SupportedDatabase.InMemory)]
     //[TestCase(SupportedDatabase.OpenSearch)]
-    //[TestCase(SupportedDatabase.Postgres)]
+    [TestCase(SupportedDatabase.Postgres)]
     [TestCase(SupportedDatabase.SqLite)]
     [TestCase(SupportedDatabase.Mongo)]
     //[TestCase(SupportedDatabase.DuckDb)]
diff --git a/src/Postgres/src/PostgresDbClient.cs b/src/Postgres/src/PostgresDbClient.cs
index 6712456..aec4f2c 100644
--- a/src/Postgres/src/PostgresDbClient.cs
+++ b/src/Postgres/src/PostgresDbClient.cs
@@ -1,11 +1,10 @@
+using LangChain.Databases.JsonConverters;
+using Npgsql;
+using NpgsqlTypes;
 using System.Data;
 using System.Diagnostics.CodeAnalysis;
 using System.Runtime.CompilerServices;
 using System.Text.Json;
-using LangChain.Databases.JsonConverters;
-using Npgsql;
-using NpgsqlTypes;
-using Pgvector;
 
 namespace LangChain.Databases.Postgres;
 
@@ -81,7 +80,6 @@ public async Task IsTableExistsAsync(string tableName, CancellationToken c
         }
     }
 
-
     public async Task> ListTablesAsync(CancellationToken cancellationToken = default)
     {
         var connection = await _dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
@@ -489,6 +487,73 @@ private async Task ReadEntryAsync(
     ///
     ///
     private string GetFullTableName(string tableName) => $"{_schema}.\"{tableName}\"";
+
+    /// <summary>
+    /// Returns the records of a table whose metadata contains every key/value pair in <paramref name="filters"/>.
+    /// </summary>
+    /// <param name="tableName">Table to query; qualified and quoted through GetFullTableName.</param>
+    /// <param name="filters">Metadata key/value pairs that must all match; an empty dictionary applies no restriction.</param>
+    /// <param name="withEmbeddings">When true, the embedding column is selected as well.</param>
+    /// <param name="cancellationToken">Token forwarded to the connection, command and reader calls.</param>
+    /// <returns>The matching records, in the order the query returns them.</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="filters"/> is null.</exception>
+    [CLSCompliant(false)]
+    public async Task> GetRecordsByMetadataAsync(
+        string tableName,
+        Dictionary filters,
+        bool withEmbeddings = false,
+        CancellationToken cancellationToken = default)
+    {
+        filters = filters ?? throw new ArgumentNullException(nameof(filters));
+
+        var whereClauses = new List();
+        var parameters = new List();
+
+        int paramIndex = 0;
+        foreach (var kvp in filters)
+        {
+            string paramName = $"@p{paramIndex}";
+
+            // Use the JSONB containment operator @> for metadata filtering
+            whereClauses.Add($"metadata @> {paramName}::jsonb");
+
+            // Serialize the key-value pair to JSON
+            var jsonValue = JsonSerializer.Serialize(new Dictionary { { kvp.Key, kvp.Value } });
+
+            parameters.Add(new NpgsqlParameter(paramName, NpgsqlDbType.Jsonb) { Value = jsonValue });
+            paramIndex++;
+        }
+
+        // With no filters there is nothing to restrict: omit WHERE rather than emit a dangling "WHERE ".
+        string whereClause = whereClauses.Count > 0
+            ? " WHERE " + string.Join(" AND ", whereClauses)
+            : string.Empty;
+
+        var queryColumns = withEmbeddings
+            ? "id, content, metadata, timestamp, embedding"
+            : "id, content, metadata, timestamp";
+
+        string query = $"SELECT {queryColumns} FROM {GetFullTableName(tableName)}{whereClause}";
+
+        var records = new List();
+
+        using var connection = await _dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
+
+        using var cmd = connection.CreateCommand();
+        cmd.CommandText = query;
+        cmd.Parameters.AddRange(parameters.ToArray());
+
+        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+
+        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+        {
+            var record = await ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
+            records.Add(record);
+        }
+
+        return records;
+    }
 }
 
 ///
diff --git a/src/Postgres/src/PostgresVectorCollection.cs b/src/Postgres/src/PostgresVectorCollection.cs
index bca1851..8a69a1c 100644
--- a/src/Postgres/src/PostgresVectorCollection.cs
+++ b/src/Postgres/src/PostgresVectorCollection.cs
@@ -1,4 +1,3 @@
-
 namespace LangChain.Databases.Postgres;
 
 ///
@@ -125,8 +124,35 @@ public Task IsEmptyAsync(CancellationToken cancellationToken = default)
         throw new NotImplementedException();
     }
 
-    Task> IVectorCollection.SearchByMetadata(Dictionary filters, CancellationToken cancellationToken)
+    /// <summary>
+    /// Returns the vectors of this collection whose metadata contains every key/value pair in <paramref name="filters"/>.
+    /// </summary>
+    /// <param name="filters">Metadata key/value pairs to match.</param>
+    /// <param name="cancellationToken">Token forwarded to the database client.</param>
+    /// <returns>Matching vectors with id, text and metadata populated; embeddings are not loaded.</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="filters"/> is null.</exception>
+    public async Task> SearchByMetadata(
+        Dictionary filters,
+        CancellationToken cancellationToken = default)
     {
-        throw new NotImplementedException();
+        filters = filters ?? throw new ArgumentNullException(nameof(filters));
+
+        var records = await client
+            .GetRecordsByMetadataAsync(
+                Name,
+                filters,
+                withEmbeddings: false,
+                cancellationToken: cancellationToken)
+            .ConfigureAwait(false);
+
+        var vectors = records.Select(record => new Vector
+        {
+            Id = record.Id,
+            Text = record.Content,
+            Metadata = record.Metadata,
+            // Embedding is null since withEmbeddings is false
+        }).ToList();
+
+        return vectors;
     }
 }
\ No newline at end of file
diff --git a/src/Postgres/src/PostgresVectorDatabase.cs b/src/Postgres/src/PostgresVectorDatabase.cs
index 0d7af88..814b485 100644
--- a/src/Postgres/src/PostgresVectorDatabase.cs
+++ b/src/Postgres/src/PostgresVectorDatabase.cs
@@ -68,4 +68,4 @@ public Task> ListCollectionsAsync(CancellationToken cancel
     {
         return _client.ListTablesAsync(cancellationToken);
     }
-}
+}
\ No newline at end of file