From 664de7121defb4b37dd80d341aced92af4d2ca7f Mon Sep 17 00:00:00 2001 From: HavenDV Date: Thu, 19 Sep 2024 03:25:12 +0400 Subject: [PATCH] fix: Removed LangChain.Core usage over repo. --- .../LangChain.Databases.Abstractions.csproj | 2 + .../MessageHistory/BaseChatMessageHistory.cs | 97 +++++++ .../src/MessageHistory/ChatMessageHistory.cs | 40 +++ .../MessageHistory/FileChatMessageHistory.cs | 81 ++++++ src/Abstractions/src/PublicAPI.Unshipped.txt | 23 +- src/Directory.Packages.props | 6 +- .../DatabaseTests.OpenSearch.cs | 242 ------------------ .../Extensions/VectorCollectionExtensions.cs | 118 +++++++++ ...angChain.Databases.IntegrationTests.csproj | 5 +- .../src/LangChain.Databases.Mongo.csproj | 1 - .../src/LangChain.Databases.Redis.csproj | 1 - 11 files changed, 366 insertions(+), 250 deletions(-) create mode 100644 src/Abstractions/src/MessageHistory/BaseChatMessageHistory.cs create mode 100644 src/Abstractions/src/MessageHistory/ChatMessageHistory.cs create mode 100644 src/Abstractions/src/MessageHistory/FileChatMessageHistory.cs delete mode 100644 src/IntegrationTests/DatabaseTests.OpenSearch.cs create mode 100644 src/IntegrationTests/Extensions/VectorCollectionExtensions.cs diff --git a/src/Abstractions/src/LangChain.Databases.Abstractions.csproj b/src/Abstractions/src/LangChain.Databases.Abstractions.csproj index 4bdea2a..1833ff6 100644 --- a/src/Abstractions/src/LangChain.Databases.Abstractions.csproj +++ b/src/Abstractions/src/LangChain.Databases.Abstractions.csproj @@ -12,6 +12,8 @@ + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/Abstractions/src/MessageHistory/BaseChatMessageHistory.cs b/src/Abstractions/src/MessageHistory/BaseChatMessageHistory.cs new file mode 100644 index 0000000..c8f85d8 --- /dev/null +++ b/src/Abstractions/src/MessageHistory/BaseChatMessageHistory.cs @@ -0,0 +1,97 @@ +using LangChain.Providers; + +namespace LangChain.Memory; + +/// +/// Abstract base class for storing chat message history. +/// +/// Implementations should over-ride the AddMessages method to handle bulk addition +/// of messages. +/// +/// The default implementation of AddMessages will correctly call AddMessage, so +/// it is not necessary to implement both methods. +/// +/// When used for updating history, users should favor usage of `AddMessages` +/// over `AddMessage` or other variants like `AddUserMessage` and `AddAiMessage` +/// to avoid unnecessary round-trips to the underlying persistence layer. +/// +public abstract class BaseChatMessageHistory +{ + /// + /// A list of messages stored in-memory. + /// + public abstract IReadOnlyList Messages { get; } + + /// + /// Convenience method for adding a human message string to the store. + /// + /// Please note that this is a convenience method. Code should favor the + /// bulk AddMessages interface instead to save on round-trips to the underlying + /// persistence layer. + /// + /// This method may be deprecated in a future release. + /// + /// The human message to add + public async Task AddUserMessage(string message) + { + await AddMessage(message.AsHumanMessage()).ConfigureAwait(false); + } + + /// + /// Convenience method for adding an AI message string to the store. + /// + /// Please note that this is a convenience method. Code should favor the bulk + /// AddMessages interface instead to save on round-trips to the underlying + /// persistence layer. + /// + /// This method may be deprecated in a future release. + /// + /// + public async Task AddAiMessage(string message) + { + await AddMessage(message.AsAiMessage()).ConfigureAwait(false); + } + + /// + /// Add a message object to the store. + /// + /// A message object to store + public abstract Task AddMessage(Message message); + + /// + /// Add a list of messages. + /// + /// Implementations should override this method to handle bulk addition of messages + /// in an efficient manner to avoid unnecessary round-trips to the underlying store. + /// + /// A list of message objects to store. + public virtual async Task AddMessages(IEnumerable messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + foreach (var message in messages) + { + await AddMessage(message).ConfigureAwait(false); + } + } + + /// + /// Replace the list of messages. + /// + /// Implementations should override this method to handle bulk addition of messages + /// in an efficient manner to avoid unnecessary round-trips to the underlying store. + /// + /// A list of message objects to store. + public virtual async Task SetMessages(IEnumerable messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + await Clear().ConfigureAwait(false); + await AddMessages(messages).ConfigureAwait(false); + } + + /// + /// Remove all messages from the store + /// + public abstract Task Clear(); +} \ No newline at end of file diff --git a/src/Abstractions/src/MessageHistory/ChatMessageHistory.cs b/src/Abstractions/src/MessageHistory/ChatMessageHistory.cs new file mode 100644 index 0000000..39525ed --- /dev/null +++ b/src/Abstractions/src/MessageHistory/ChatMessageHistory.cs @@ -0,0 +1,40 @@ +using LangChain.Providers; + +namespace LangChain.Memory; + +/// +/// In memory implementation of chat message history. +/// +/// Stores messages in an in memory list. +/// +public class ChatMessageHistory : BaseChatMessageHistory +{ + private readonly List _messages = new List(); + + /// + /// Used to inspect and filter messages on their way to the history store + /// NOTE: This is not a feature of python langchain + /// + public Predicate IsMessageAccepted { get; set; } = (x => true); + + /// + public override IReadOnlyList Messages => _messages; + + /// + public override Task AddMessage(Message message) + { + if (IsMessageAccepted(message)) + { + _messages.Add(message); + } + + return Task.CompletedTask; + } + + /// + public override Task Clear() + { + _messages.Clear(); + return Task.CompletedTask; + } +} \ No newline at end of file diff --git a/src/Abstractions/src/MessageHistory/FileChatMessageHistory.cs b/src/Abstractions/src/MessageHistory/FileChatMessageHistory.cs new file mode 100644 index 0000000..4972f98 --- /dev/null +++ b/src/Abstractions/src/MessageHistory/FileChatMessageHistory.cs @@ -0,0 +1,81 @@ +using LangChain.Providers; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace LangChain.Memory; + +/// +/// Chat message history that stores history in a local file. +/// +public class FileChatMessageHistory : BaseChatMessageHistory +{ + private string MessagesFilePath { get; } + + private List _messages = new List(); + + /// + public override IReadOnlyList Messages => _messages; + + /// + /// Initializes new history instance with provided file path + /// + /// path of the local file to store the messages + /// + private FileChatMessageHistory(string messagesFilePath) + { + MessagesFilePath = messagesFilePath ?? throw new ArgumentNullException(nameof(messagesFilePath)); + } + + /// + /// Create new history instance with provided file path + /// + /// path of the local file to store the messages + /// + public static async Task CreateAsync(string path, CancellationToken cancellationToken = default) + { + FileChatMessageHistory chatHistory = new FileChatMessageHistory(path); + await chatHistory.LoadMessages().ConfigureAwait(false); + + return chatHistory; + } + + /// + public override Task AddMessage(Message message) + { + _messages.Add(message); + SaveMessages(); + + return Task.CompletedTask; + } + + /// + public override Task Clear() + { + _messages.Clear(); + SaveMessages(); + + return Task.CompletedTask; + } + + private void SaveMessages() + { + var json = JsonSerializer.Serialize(_messages, SourceGenerationContext.Default.ListMessage); + + File.WriteAllText(MessagesFilePath, json); + } + + private async Task LoadMessages() + { + if (File.Exists(MessagesFilePath)) + { + var json = await File2.ReadAllTextAsync(MessagesFilePath).ConfigureAwait(false); + if (!string.IsNullOrWhiteSpace(json)) + { + _messages = JsonSerializer.Deserialize(json, SourceGenerationContext.Default.ListMessage) ?? new List(); + } + } + } +} + +[JsonSerializable(typeof(List))] +internal sealed partial class SourceGenerationContext : JsonSerializerContext; \ No newline at end of file diff --git a/src/Abstractions/src/PublicAPI.Unshipped.txt b/src/Abstractions/src/PublicAPI.Unshipped.txt index 91fae08..f35c19b 100644 --- a/src/Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Abstractions/src/PublicAPI.Unshipped.txt @@ -1,3 +1,6 @@ +abstract LangChain.Memory.BaseChatMessageHistory.AddMessage(LangChain.Providers.Message message) -> System.Threading.Tasks.Task! +abstract LangChain.Memory.BaseChatMessageHistory.Clear() -> System.Threading.Tasks.Task! +abstract LangChain.Memory.BaseChatMessageHistory.Messages.get -> System.Collections.Generic.IReadOnlyList! const LangChain.Databases.VectorCollection.DefaultName = "langchain" -> string! LangChain.Databases.DistanceStrategy LangChain.Databases.DistanceStrategy.Cosine = 1 -> LangChain.Databases.DistanceStrategy @@ -67,6 +70,21 @@ LangChain.Databases.VectorSearchType LangChain.Databases.VectorSearchType.MaximumMarginalRelevance = 2 -> LangChain.Databases.VectorSearchType LangChain.Databases.VectorSearchType.Similarity = 0 -> LangChain.Databases.VectorSearchType LangChain.Databases.VectorSearchType.SimilarityScoreThreshold = 1 -> LangChain.Databases.VectorSearchType +LangChain.Memory.BaseChatMessageHistory +LangChain.Memory.BaseChatMessageHistory.AddAiMessage(string! message) -> System.Threading.Tasks.Task! +LangChain.Memory.BaseChatMessageHistory.AddUserMessage(string! message) -> System.Threading.Tasks.Task! +LangChain.Memory.BaseChatMessageHistory.BaseChatMessageHistory() -> void +LangChain.Memory.ChatMessageHistory +LangChain.Memory.ChatMessageHistory.ChatMessageHistory() -> void +LangChain.Memory.ChatMessageHistory.IsMessageAccepted.get -> System.Predicate! +LangChain.Memory.ChatMessageHistory.IsMessageAccepted.set -> void +LangChain.Memory.FileChatMessageHistory +override LangChain.Memory.ChatMessageHistory.AddMessage(LangChain.Providers.Message message) -> System.Threading.Tasks.Task! +override LangChain.Memory.ChatMessageHistory.Clear() -> System.Threading.Tasks.Task! +override LangChain.Memory.ChatMessageHistory.Messages.get -> System.Collections.Generic.IReadOnlyList! +override LangChain.Memory.FileChatMessageHistory.AddMessage(LangChain.Providers.Message message) -> System.Threading.Tasks.Task! +override LangChain.Memory.FileChatMessageHistory.Clear() -> System.Threading.Tasks.Task! +override LangChain.Memory.FileChatMessageHistory.Messages.get -> System.Collections.Generic.IReadOnlyList! static LangChain.Databases.RelevanceScoreFunctions.Cosine(float distance) -> float static LangChain.Databases.RelevanceScoreFunctions.Euclidean(float distance) -> float static LangChain.Databases.RelevanceScoreFunctions.Get(LangChain.Databases.DistanceStrategy distanceStrategy) -> System.Func! @@ -74,4 +92,7 @@ static LangChain.Databases.RelevanceScoreFunctions.MaxInnerProduct(float distanc static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]! embedding) -> LangChain.Databases.VectorSearchRequest! static LangChain.Databases.VectorSearchRequest.implicit operator LangChain.Databases.VectorSearchRequest!(float[]![]! embeddings) -> LangChain.Databases.VectorSearchRequest! static LangChain.Databases.VectorSearchRequest.ToVectorSearchRequest(float[]! embedding) -> LangChain.Databases.VectorSearchRequest! -static LangChain.Databases.VectorSearchRequest.ToVectorSearchRequest(float[]![]! embeddings) -> LangChain.Databases.VectorSearchRequest! \ No newline at end of file +static LangChain.Databases.VectorSearchRequest.ToVectorSearchRequest(float[]![]! embeddings) -> LangChain.Databases.VectorSearchRequest! +static LangChain.Memory.FileChatMessageHistory.CreateAsync(string! path, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +virtual LangChain.Memory.BaseChatMessageHistory.AddMessages(System.Collections.Generic.IEnumerable! messages) -> System.Threading.Tasks.Task! +virtual LangChain.Memory.BaseChatMessageHistory.SetMessages(System.Collections.Generic.IEnumerable! messages) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 4eba02a..8237db2 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -12,11 +12,11 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - - + - + + diff --git a/src/IntegrationTests/DatabaseTests.OpenSearch.cs b/src/IntegrationTests/DatabaseTests.OpenSearch.cs deleted file mode 100644 index 201c67a..0000000 --- a/src/IntegrationTests/DatabaseTests.OpenSearch.cs +++ /dev/null @@ -1,242 +0,0 @@ -using LangChain.Extensions; -using LangChain.Providers; -using LangChain.Providers.Amazon.Bedrock; -using LangChain.Providers.Amazon.Bedrock.Predefined.Amazon; -using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic; -using LangChain.DocumentLoaders; -using static LangChain.Chains.Chain; - -namespace LangChain.Databases.IntegrationTests; - -public partial class OpenSearchTests -{ - #region Query Images - - private static async Task SetupImageTestsAsync() - { - var environment = await DatabaseTests.StartEnvironmentForAsync(SupportedDatabase.OpenSearch); - environment.Dimensions = 1024; - environment.EmbeddingModel = new TitanEmbedImageV1Model(new BedrockProvider()) - { - Settings = new BedrockEmbeddingSettings - { - Dimensions = environment.Dimensions, - } - }; - - return environment; - } - - [Test] - [Explicit] - public async Task index_test_images() - { - await using var environment = await SetupImageTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - string[] extensions = { ".bmp", ".gif", ".jpg", ".jpeg", ".png", ".tiff" }; - var files = Directory.EnumerateFiles(@"[images directory]", "*.*", SearchOption.AllDirectories) - .Where(s => extensions.Any(ext => ext == Path.GetExtension(s))); - - var images = files.ToBinaryData(); - - var documents = new List(); - - foreach (BinaryData image in images) - { - var model = new Claude3HaikuModel(new BedrockProvider()); - var message = new Message(" \"what's this an image of and describe the details?\"", MessageRole.Human); - - var chatRequest = ChatRequest.ToChatRequest(message); - chatRequest.Image = image; - - var response = await model.GenerateAsync(chatRequest); - - var document = new Document - { - PageContent = response, - Metadata = new Dictionary - { - {response, image} - } - }; - - documents.Add(document); - } - - var pages = await vectorCollection.AddDocumentsAsync(environment.EmbeddingModel, documents); - } - - [Test] - [Explicit] - public async Task can_query_image_against_images() - { - await using var environment = await SetupImageTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var path = Path.Combine(Path.GetTempPath(), "test_image.jpg"); - var imageData = await File.ReadAllBytesAsync(path); - var binaryData = new BinaryData(imageData, "image/jpg"); - - var embeddingRequest = new EmbeddingRequest - { - Strings = new List(), - Images = new List { Data.FromBytes(binaryData.ToArray()) } - }; - var embedding = await environment.EmbeddingModel.CreateEmbeddingsAsync(embeddingRequest) - .ConfigureAwait(false); - - var floats = embedding.ToSingleArray(); - var similaritySearchByVectorAsync = await vectorCollection.SearchAsync(floats).ConfigureAwait(false); - - Console.WriteLine("Count: " + similaritySearchByVectorAsync.Items.Count); - } - - [Test] - [Explicit] - public async Task can_query_text_against_images() - { - await using var environment = await SetupImageTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var llm = new Claude3SonnetModel(new BedrockProvider()); - - var promptText = - @"Use the following pieces of context to answer the question at the end. If the answer is not in context then just say that you don't know, don't try to make up an answer. Keep the answer as short as possible. - -{context} - -Question: {question} -Helpful Answer:"; - - var chain = - Set("tell me about the orange shirt", outputKey: "question") // set the question - | RetrieveDocuments(vectorCollection, environment.EmbeddingModel, inputKey: "question", outputKey: "documents", amount: 10) // take 5 most similar documents - | StuffDocuments(inputKey: "documents", outputKey: "context") // combine documents together and put them into context - | Template(promptText) // replace context and question in the prompt with their values - | LLM(llm); // send the result to the language model - - var res = await chain.RunAsync("text"); - Console.WriteLine(res); - } - - #endregion - - #region Query Simple Documents - - private static async Task SetupDocumentTestsAsync() - { - var environment = await DatabaseTests.StartEnvironmentForAsync(SupportedDatabase.OpenSearch); - environment.Dimensions = 1536; - environment.EmbeddingModel = new TitanEmbedTextV1Model(new BedrockProvider()) - { - Settings = new BedrockEmbeddingSettings - { - Dimensions = environment.Dimensions, - } - }; - - return environment; - } - - [Test] - [Explicit] - public async Task index_test_documents() - { - await using var environment = await SetupDocumentTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var documents = new[] - { - "I spent entire day watching TV", - "My dog's name is Bob", - "The car is orange", - "This icecream is delicious", - "It is cold in space", - }.ToDocuments(); - - var pages = await vectorCollection.AddDocumentsAsync(environment.EmbeddingModel, documents); - Console.WriteLine("pages: " + pages.Count); - } - - [Test] - [Explicit] - public async Task can_query_test_documents() - { - await using var environment = await SetupDocumentTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var llm = new Claude3SonnetModel(new BedrockProvider()); - - const string question = "what color is the car?"; - - var promptText = - @"Use the following pieces of context to answer the question at the end. If the answer is not in context then just say that you don't know, don't try to make up an answer. Keep the answer as short as possible. - -{context} - -Question: {question} -Helpful Answer:"; - var chain = - Set(question, outputKey: "question") - | RetrieveDocuments(vectorCollection, environment.EmbeddingModel, inputKey: "question", outputKey: "documents", amount: 2) - | StuffDocuments(inputKey: "documents", outputKey: "context") - | Template(promptText) - | LLM(llm); - - - var res = await chain.RunAsync("text"); - Console.WriteLine(res); - } - - #endregion - - #region Query Pdf Book - - [Test] - [Explicit] - public async Task index_harry_potter_book() - { - await using var environment = await SetupDocumentTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var pdfSource = new PdfPigPdfLoader(); - var documents = await pdfSource.LoadAsync(DataSource.FromPath("x:\\Harry-Potter-Book-1.pdf")); - - var pages = await vectorCollection.AddDocumentsAsync(environment.EmbeddingModel, documents); - Console.WriteLine("pages: " + pages.Count()); - } - - [Test] - [Explicit] - public async Task can_query_harry_potter_book() - { - await using var environment = await SetupDocumentTestsAsync(); - var vectorCollection = await environment.VectorDatabase.GetOrCreateCollectionAsync(VectorCollection.DefaultName, environment.Dimensions); - - var llm = new Claude3SonnetModel(new BedrockProvider()); - - var promptText = - @"Use the following pieces of context to answer the question at the end. If the answer is not in context then just say that you don't know, don't try to make up an answer. Keep the answer as short as possible. - -{context} - -Question: {question} -Helpful Answer:"; - - var chain = - //Set("what color is the car?", outputKey: "question") // set the question - //Set("Hagrid was looking for the golden key. Where was it?", outputKey: "question") // set the question - // Set("Who was on the Dursleys front step?", outputKey: "question") // set the question - Set("Who was drinking a unicorn blood?", outputKey: "question") // set the question - | RetrieveDocuments(vectorCollection, environment.EmbeddingModel, inputKey: "question", outputKey: "documents", amount: 10) // take 5 most similar documents - | StuffDocuments(inputKey: "documents", outputKey: "context") // combine documents together and put them into context - | Template(promptText) // replace context and question in the prompt with their values - | LLM(llm); // send the result to the language model - - var res = await chain.RunAsync("text"); - Console.WriteLine(res); - } - - #endregion -} \ No newline at end of file diff --git a/src/IntegrationTests/Extensions/VectorCollectionExtensions.cs b/src/IntegrationTests/Extensions/VectorCollectionExtensions.cs new file mode 100644 index 0000000..5523da3 --- /dev/null +++ b/src/IntegrationTests/Extensions/VectorCollectionExtensions.cs @@ -0,0 +1,118 @@ +using LangChain.Databases; +using LangChain.Providers; +using LangChain.DocumentLoaders; + +namespace LangChain.Extensions; + +/// +/// +/// +public static class VectorCollectionExtensions +{ + /// + /// Return documents most similar to query. + /// + /// + /// + /// + /// + /// + /// + /// + public static async Task SearchAsync( + this IVectorCollection vectorCollection, + IEmbeddingModel embeddingModel, + EmbeddingRequest embeddingRequest, + EmbeddingSettings? embeddingSettings = default, + VectorSearchSettings? searchSettings = default, + CancellationToken cancellationToken = default) + { + vectorCollection = vectorCollection ?? throw new ArgumentNullException(nameof(vectorCollection)); + embeddingModel = embeddingModel ?? throw new ArgumentNullException(nameof(embeddingModel)); + searchSettings ??= new VectorSearchSettings(); + + if (searchSettings is { Type: VectorSearchType.SimilarityScoreThreshold, ScoreThreshold: null }) + { + throw new ArgumentException($"ScoreThreshold required for {searchSettings.Type}"); + } + + var response = await embeddingModel.CreateEmbeddingsAsync( + request: embeddingRequest, + settings: embeddingSettings, + cancellationToken: cancellationToken).ConfigureAwait(false); + + return await vectorCollection.SearchAsync(new VectorSearchRequest + { + Embeddings = [response.ToSingleArray()], + }, searchSettings, cancellationToken).ConfigureAwait(false); + } + + public static async Task> AddDocumentsAsync( + this IVectorCollection vectorCollection, + IEmbeddingModel embeddingModel, + IReadOnlyCollection documents, + EmbeddingSettings? embeddingSettings = default, + CancellationToken cancellationToken = default) + { + vectorCollection = vectorCollection ?? throw new ArgumentNullException(nameof(vectorCollection)); + embeddingModel = embeddingModel ?? throw new ArgumentNullException(nameof(embeddingModel)); + + return await vectorCollection.AddTextsAsync( + embeddingModel: embeddingModel, + texts: documents.Select(x => x.PageContent).ToArray(), + metadatas: documents.Select(x => x.Metadata).ToArray(), + embeddingSettings: embeddingSettings, + cancellationToken).ConfigureAwait(false); + } + + public static async Task GetDocumentByIdAsync( + this IVectorCollection vectorCollection, + string id, + CancellationToken cancellationToken = default) + { + vectorCollection = vectorCollection ?? throw new ArgumentNullException(nameof(vectorCollection)); + + var item = await vectorCollection.GetAsync(id, cancellationToken).ConfigureAwait(false); + + return item == null + ? null + : new Document(item.Text, item.Metadata?.ToDictionary(x => x.Key, x => x.Value)); + } + + public static async Task> AddTextsAsync( + this IVectorCollection vectorCollection, + IEmbeddingModel embeddingModel, + IReadOnlyCollection texts, + IReadOnlyCollection>? metadatas = null, + EmbeddingSettings? embeddingSettings = default, + CancellationToken cancellationToken = default) + { + vectorCollection = vectorCollection ?? throw new ArgumentNullException(nameof(vectorCollection)); + embeddingModel = embeddingModel ?? throw new ArgumentNullException(nameof(embeddingModel)); + + var embeddingRequest = new EmbeddingRequest + { + Strings = texts.ToArray(), + Images = metadatas? + .Select((metadata, i) => metadata.TryGetValue(texts.ElementAt(i), out object? result) + ? result as BinaryData + : null) + .Where(x => x != null) + .Select(x => Data.FromBytes(x!.ToArray())) + .ToArray() ?? [], + }; + + float[][] embeddings = await embeddingModel + .CreateEmbeddingsAsync(embeddingRequest, embeddingSettings, cancellationToken) + .ConfigureAwait(false); + + return await vectorCollection.AddAsync( + items: texts.Select((text, i) => new Vector + { + Text = text, + Metadata = metadatas?.ElementAt(i).ToDictionary(x => x.Key, x => x.Value), + Embedding = embeddings[i], + }).ToArray(), + cancellationToken).ConfigureAwait(false); + } +} \ No newline at end of file diff --git a/src/IntegrationTests/LangChain.Databases.IntegrationTests.csproj b/src/IntegrationTests/LangChain.Databases.IntegrationTests.csproj index b233b51..659ad4f 100644 --- a/src/IntegrationTests/LangChain.Databases.IntegrationTests.csproj +++ b/src/IntegrationTests/LangChain.Databases.IntegrationTests.csproj @@ -26,7 +26,9 @@ - + + + @@ -34,7 +36,6 @@ - diff --git a/src/Mongo/src/LangChain.Databases.Mongo.csproj b/src/Mongo/src/LangChain.Databases.Mongo.csproj index 0d710a9..1fc2e9f 100644 --- a/src/Mongo/src/LangChain.Databases.Mongo.csproj +++ b/src/Mongo/src/LangChain.Databases.Mongo.csproj @@ -12,7 +12,6 @@ - diff --git a/src/Redis/src/LangChain.Databases.Redis.csproj b/src/Redis/src/LangChain.Databases.Redis.csproj index a346938..5e3729c 100644 --- a/src/Redis/src/LangChain.Databases.Redis.csproj +++ b/src/Redis/src/LangChain.Databases.Redis.csproj @@ -11,7 +11,6 @@ -