Skip to content

Commit

Permalink
Merge pull request #73 from tryAGI/Metadata-search-for-postgres
Browse files Browse the repository at this point in the history
Added metadata search for Postgres
  • Loading branch information
robalexclark authored Oct 24, 2024
2 parents 503d753 + f1e1250 commit 09b1567
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/IntegrationTests/DatabaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public async Task SimilaritySearchWithScores_Ok(SupportedDatabase database)

[TestCase(SupportedDatabase.InMemory)]
//[TestCase(SupportedDatabase.OpenSearch)]
//[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
//[TestCase(SupportedDatabase.DuckDb)]
Expand Down Expand Up @@ -348,7 +348,7 @@ public async Task MetadataSearch_Ok(SupportedDatabase database)

[TestCase(SupportedDatabase.InMemory)]
//[TestCase(SupportedDatabase.OpenSearch)]
//[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.Postgres)]
[TestCase(SupportedDatabase.SqLite)]
[TestCase(SupportedDatabase.Mongo)]
//[TestCase(SupportedDatabase.DuckDb)]
Expand Down
62 changes: 57 additions & 5 deletions src/Postgres/src/PostgresDbClient.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
using LangChain.Databases.JsonConverters;
using Npgsql;
using NpgsqlTypes;
using System.Data;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Text.Json;
using LangChain.Databases.JsonConverters;
using Npgsql;
using NpgsqlTypes;
using Pgvector;

namespace LangChain.Databases.Postgres;

Expand Down Expand Up @@ -81,7 +80,6 @@ public async Task<bool> IsTableExistsAsync(string tableName, CancellationToken c
}
}


public async Task<IReadOnlyList<string>> ListTablesAsync(CancellationToken cancellationToken = default)
{
var connection = await _dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -489,6 +487,60 @@ private async Task<EmbeddingTableRecord> ReadEntryAsync(
/// <param name="tableName"></param>
/// <returns></returns>
private string GetFullTableName(string tableName) => $"{_schema}.\"{tableName}\"";

[CLSCompliant(false)]
public async Task<List<EmbeddingTableRecord>> GetRecordsByMetadataAsync(
string tableName,
Dictionary<string, object> filters,
bool withEmbeddings = false,
CancellationToken cancellationToken = default)
{
filters = filters ?? throw new ArgumentNullException(nameof(filters));

var whereClauses = new List<string>();
var parameters = new List<NpgsqlParameter>();

int paramIndex = 0;
foreach (var kvp in filters)
{
string paramName = $"@p{paramIndex}";

// Use the JSONB containment operator @> for metadata filtering
whereClauses.Add($"metadata @> {paramName}::jsonb");

// Serialize the key-value pair to JSON
var jsonValue = JsonSerializer.Serialize(new Dictionary<string, object> { { kvp.Key, kvp.Value } });

parameters.Add(new NpgsqlParameter(paramName, NpgsqlDbType.Jsonb) { Value = jsonValue });
paramIndex++;
}

string whereClause = string.Join(" AND ", whereClauses);

var queryColumns = withEmbeddings
? "id, content, metadata, timestamp, embedding"
: "id, content, metadata, timestamp";

string query = $"SELECT {queryColumns} FROM {GetFullTableName(tableName)} WHERE {whereClause}";

var records = new List<EmbeddingTableRecord>();

using var connection = await _dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

using var cmd = connection.CreateCommand();
cmd.CommandText = query;
cmd.Parameters.AddRange(parameters.ToArray());

using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
var record = await ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
records.Add(record);
}

return records;
}
}

/// <summary>
Expand Down
25 changes: 22 additions & 3 deletions src/Postgres/src/PostgresVectorCollection.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

namespace LangChain.Databases.Postgres;

/// <summary>
Expand Down Expand Up @@ -125,8 +124,28 @@ public Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
throw new NotImplementedException();
}

Task<List<Vector>> IVectorCollection.SearchByMetadata(Dictionary<string, object> filters, CancellationToken cancellationToken)
public async Task<List<Vector>> SearchByMetadata(
Dictionary<string, object> filters,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
filters = filters ?? throw new ArgumentNullException(nameof(filters));

var records = await client
.GetRecordsByMetadataAsync(
Name,
filters,
withEmbeddings: false,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

var vectors = records.Select(record => new Vector
{
Id = record.Id,
Text = record.Content,
Metadata = record.Metadata,
// Embedding is null since withEmbeddings is false
}).ToList();

return vectors;
}
}
2 changes: 1 addition & 1 deletion src/Postgres/src/PostgresVectorDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ public Task<IReadOnlyList<string>> ListCollectionsAsync(CancellationToken cancel
{
return _client.ListTablesAsync(cancellationToken);
}
}
}

0 comments on commit 09b1567

Please sign in to comment.