diff --git a/webapi/Plugins/Chat/ChatPlugin.cs b/webapi/Plugins/Chat/ChatPlugin.cs index 109cb3c08..874fe329f 100644 --- a/webapi/Plugins/Chat/ChatPlugin.cs +++ b/webapi/Plugins/Chat/ChatPlugin.cs @@ -107,89 +107,6 @@ public ChatPlugin( this._contentSafety = contentSafety; } - /// - /// Extract user intent from the conversation history. - /// - /// The KernelArguments. - /// The cancellation token. - private async Task ExtractUserIntentAsync(KernelArguments kernelArguments, CancellationToken cancellationToken = default) - { - var tokenLimit = this._promptOptions.CompletionTokenLimit; - var historyTokenBudget = - tokenLimit - - this._promptOptions.ResponseTokenLimit - - TokenUtils.TokenCount(string.Join("\n", new string[] - { - this._promptOptions.SystemDescription, - this._promptOptions.SystemIntent, - this._promptOptions.SystemIntentContinuation - }) - ); - - // Clone the context to avoid modifying the original context variables. - KernelArguments intentExtractionContext = new(kernelArguments); - intentExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo()); - intentExtractionContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate; - - var completionFunction = this._kernel.CreateFunctionFromPrompt( - this._promptOptions.SystemIntentExtraction, - this.CreateIntentCompletionSettings(), - functionName: nameof(ChatPlugin), - description: "Complete the prompt."); - - var result = await completionFunction.InvokeAsync( - this._kernel, - intentExtractionContext, - cancellationToken - ); - - // Get token usage from ChatCompletion result and add to context - TokenUtils.GetFunctionTokenUsage(result, intentExtractionContext, this._logger, "SystemIntentExtraction"); - - return $"User intent: {result}"; - } - - /// - /// Extract the list of participants from the conversation history. - /// Note that only those who have spoken will be included. - /// - /// The SKContext. - /// The cancellation token. - private async Task ExtractAudienceAsync(KernelArguments context, CancellationToken cancellationToken = default) - { - var tokenLimit = this._promptOptions.CompletionTokenLimit; - var historyTokenBudget = - tokenLimit - - this._promptOptions.ResponseTokenLimit - - TokenUtils.TokenCount(string.Join("\n", new string[] - { - this._promptOptions.SystemAudience, - this._promptOptions.SystemAudienceContinuation, - }) - ); - - // Clone the context to avoid modifying the original context variables. - KernelArguments audienceExtractionContext = new(context); - audienceExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo()); - - var completionFunction = this._kernel.CreateFunctionFromPrompt( - this._promptOptions.SystemAudienceExtraction, - this.CreateIntentCompletionSettings(), - functionName: nameof(ChatPlugin), - description: "Complete the prompt."); - - var result = await completionFunction.InvokeAsync( - this._kernel, - audienceExtractionContext, - cancellationToken - ); - - // Get token usage from ChatCompletion result and add to context - TokenUtils.GetFunctionTokenUsage(result, context, this._logger, "SystemAudienceExtraction"); - - return $"List of participants: {result}"; - } - /// /// Method that wraps GetAllowedChatHistoryAsync to get allotted history messages as one string. /// GetAllowedChatHistoryAsync optionally updates a ChatHistory object with the allotted messages, @@ -324,7 +241,7 @@ private async Task GetChatResponseAsync(string chatId, strin // Render system instruction components and create the meta-prompt template var systemInstructions = await AsyncUtils.SafeInvokeAsync( () => this.RenderSystemInstructions(chatId, chatContext, cancellationToken), nameof(RenderSystemInstructions)); - ChatHistory chatHistory = new(systemInstructions); + ChatHistory metaPrompt = new(systemInstructions); // Bypass audience extraction if Auth is disabled var audience = string.Empty; @@ -334,41 +251,42 @@ private async Task GetChatResponseAsync(string chatId, strin await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting audience", cancellationToken); audience = await AsyncUtils.SafeInvokeAsync( () => this.GetAudienceAsync(chatContext, cancellationToken), nameof(GetAudienceAsync)); - chatHistory.AddSystemMessage(audience); + metaPrompt.AddSystemMessage(audience); } // Extract user intent from the conversation history. await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting user intent", cancellationToken); var userIntent = await AsyncUtils.SafeInvokeAsync( () => this.GetUserIntentAsync(chatContext, cancellationToken), nameof(GetUserIntentAsync)); - chatHistory.AddSystemMessage(userIntent); + metaPrompt.AddSystemMessage(userIntent); - // Calculate the remaining token budget. - var remainingTokenBudget = this.GetChatContextTokenLimit(chatHistory, userMessage.ToFormattedString()); + // Calculate max amount of tokens to use for memories + int maxRequestTokenBudget = this.GetMaxRequestTokenBudget(); + // Calculate tokens used so far: system instructions, audience extraction and user intent + int tokensUsed = TokenUtils.GetContextMessagesTokenCount(metaPrompt); + int chatMemoryTokenBudget = maxRequestTokenBudget + - tokensUsed + - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userMessage.ToFormattedString()); + chatMemoryTokenBudget = (int)(chatMemoryTokenBudget * this._promptOptions.MemoriesResponseContextWeight); // Query relevant semantic and document memories await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting semantic and document memories", cancellationToken); - var chatMemoriesTokenLimit = (int)(remainingTokenBudget * this._promptOptions.MemoriesResponseContextWeight); - (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoriesTokenLimit); - + (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoryTokenBudget); if (!string.IsNullOrWhiteSpace(memoryText)) { - chatHistory.AddSystemMessage(memoryText); + metaPrompt.AddSystemMessage(memoryText); + tokensUsed += TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText); } - // Fill in the chat history with remaining token budget. - string allowedChatHistory = string.Empty; - var allowedChatHistoryTokenBudget = remainingTokenBudget - TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText); - - // Append previous messages + // Add as many chat history messages to meta-prompt as the token budget will allow await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting chat history", cancellationToken); - allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, allowedChatHistoryTokenBudget, chatHistory, cancellationToken); + string allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, maxRequestTokenBudget - tokensUsed, metaPrompt, cancellationToken); - // Calculate token usage of prompt template - chatContext[TokenUtils.GetFunctionKey(this._logger, "SystemMetaPrompt")!] = TokenUtils.GetContextMessagesTokenCount(chatHistory).ToString(CultureInfo.CurrentCulture); + // Store token usage of prompt template + chatContext[TokenUtils.GetFunctionKey("SystemMetaPrompt")] = TokenUtils.GetContextMessagesTokenCount(metaPrompt).ToString(CultureInfo.CurrentCulture); // Stream the response to the client - var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, chatHistory); + var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, metaPrompt); return await this.HandleBotResponseAsync(chatId, userId, chatContext, promptView, citationMap.Values.AsEnumerable(), cancellationToken); } @@ -429,7 +347,7 @@ await AsyncUtils.SafeInvokeAsync( cancellationToken), nameof(SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync)); // Calculate total token usage for dependency functions and prompt template - await this.UpdateBotResponseStatusOnClientAsync(chatId, "Calculating token usage", cancellationToken); + await this.UpdateBotResponseStatusOnClientAsync(chatId, "Saving token usage", cancellationToken); chatMessage.TokenUsage = this.GetTokenUsages(chatContext, chatMessage.Content); // Update the message on client and in chat history with final completion token usage @@ -449,16 +367,38 @@ private async Task GetAudienceAsync(KernelArguments context, Cancellatio { // Clone the context to avoid modifying the original context variables KernelArguments audienceContext = new(context); - var audience = await this.ExtractAudienceAsync(audienceContext, cancellationToken); + int historyTokenBudget = + this._promptOptions.CompletionTokenLimit - + this._promptOptions.ResponseTokenLimit - + TokenUtils.TokenCount(string.Join("\n\n", new string[] + { + this._promptOptions.SystemAudience, + this._promptOptions.SystemAudienceContinuation, + }) + ); + + audienceContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo()); - // Copy token usage into original chat context - var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemAudienceExtraction")!; - if (audienceContext.TryGetValue(functionKey, out object? tokenUsage)) + var completionFunction = this._kernel.CreateFunctionFromPrompt( + this._promptOptions.SystemAudienceExtraction, + this.CreateIntentCompletionSettings(), + functionName: "SystemAudienceExtraction", + description: "Extract audience"); + + var result = await completionFunction.InvokeAsync(this._kernel, audienceContext, cancellationToken); + + // Get token usage from ChatCompletion result and add to original context + string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger); + if (tokenUsage is not null) { - context[functionKey] = tokenUsage; + context[TokenUtils.GetFunctionKey("SystemAudienceExtraction")] = tokenUsage; + } + else + { + this._logger.LogError("Unable to determine token usage for audienceExtraction"); } - return audience; + return $"List of participants: {result}"; } /// @@ -470,16 +410,41 @@ private async Task GetUserIntentAsync(KernelArguments context, Cancellat { // Clone the context to avoid modifying the original context variables KernelArguments intentContext = new(context); - string userIntent = await this.ExtractUserIntentAsync(intentContext, cancellationToken); - // Copy token usage into original chat context - var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemIntentExtraction")!; - if (intentContext.TryGetValue(functionKey!, out object? tokenUsage)) + int tokenBudget = + this._promptOptions.CompletionTokenLimit - + this._promptOptions.ResponseTokenLimit - + TokenUtils.TokenCount(string.Join("\n", new string[] + { + this._promptOptions.SystemPersona, + this._promptOptions.SystemIntent, + this._promptOptions.SystemIntentContinuation + }) + ); + + intentContext["tokenLimit"] = tokenBudget.ToString(new NumberFormatInfo()); + intentContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate; + + var completionFunction = this._kernel.CreateFunctionFromPrompt( + this._promptOptions.SystemIntentExtraction, + this.CreateIntentCompletionSettings(), + functionName: "UserIntentExtraction", + description: "Extract user intent"); + + var result = await completionFunction.InvokeAsync(this._kernel, intentContext, cancellationToken); + + // Get token usage from ChatCompletion result and add to original context + string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger); + if (tokenUsage is not null) { - context[functionKey!] = tokenUsage; + context[TokenUtils.GetFunctionKey("SystemIntentExtraction")] = tokenUsage; + } + else + { + this._logger.LogError("Unable to determine token usage for userIntentExtraction"); } - return userIntent; + return $"User intent: {result}"; } /// @@ -610,24 +575,18 @@ private OpenAIPromptExecutionSettings CreateIntentCompletionSettings() } /// - /// Calculate the remaining token budget for the chat response prompt. - /// This is the token limit minus the token count of the user intent, audience, and the system commands. + /// Calculate the maximum number of tokens that can be sent in a request /// - /// All current messages to use for chat completion - /// The user message. - /// The remaining token limit. - private int GetChatContextTokenLimit(ChatHistory promptTemplate, string userInput = "") + private int GetMaxRequestTokenBudget() { // OpenAI inserts a message under the hood: // "content": "Assistant is a large language model.","role": "system" // This burns just under 20 tokens which need to be accounted for. const int ExtraOpenAiMessageTokens = 20; - return this._promptOptions.CompletionTokenLimit + return this._promptOptions.CompletionTokenLimit // Total token limit - ExtraOpenAiMessageTokens - - TokenUtils.GetContextMessagesTokenCount(promptTemplate) - - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userInput) // User message has to be included in chat history allowance - - this._promptOptions.ResponseTokenLimit; + - this._promptOptions.ResponseTokenLimit; // Token count reserved for model to generate a response } /// diff --git a/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs b/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs index 9b1c6d4eb..6c7f1d90f 100644 --- a/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs +++ b/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs @@ -91,10 +91,19 @@ async Task ExtractCognitiveMemoryAsync(string memoryType, st cancellationToken); // Get token usage from ChatCompletion result and add to context - // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type. - TokenUtils.GetFunctionTokenUsage(result, kernelArguments, logger, $"SystemCognitive_{memoryType}"); + string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, logger); + if (tokenUsage is not null) + { + // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type. + kernelArguments[TokenUtils.GetFunctionKey($"SystemCognitive_{memoryType}")] = tokenUsage; + } + else + { + logger.LogError("Unable to determine token usage for {0}", $"SystemCognitive_{memoryType}"); + } SemanticChatMemory memory = SemanticChatMemory.FromJson(result.ToString()); + return memory; } diff --git a/webapi/Plugins/Utils/TokenUtils.cs b/webapi/Plugins/Utils/TokenUtils.cs index 4d9f816e4..f3adad225 100644 --- a/webapi/Plugins/Utils/TokenUtils.cs +++ b/webapi/Plugins/Utils/TokenUtils.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Globalization; using System.Linq; @@ -22,7 +21,7 @@ public static class TokenUtils /// Semantic dependencies of ChatPlugin. /// If you add a new semantic dependency, please add it here. /// - public static readonly Dictionary semanticFunctions = new() + internal static readonly Dictionary semanticFunctions = new() { { "SystemAudienceExtraction", "audienceExtraction" }, { "SystemIntentExtraction", "userIntentExtraction" }, @@ -38,21 +37,19 @@ public static class TokenUtils /// internal static Dictionary EmptyTokenUsages() { - return semanticFunctions.Values.ToDictionary(v => v, v => 0, StringComparer.OrdinalIgnoreCase); + return semanticFunctions.Values.ToDictionary(v => v, v => 0); } /// /// Gets key used to identify function token usage in context variables. /// - /// The logger instance to use for logging errors. /// Name of semantic function. /// The key corresponding to the semantic function name, or null if the function name is unknown. - internal static string? GetFunctionKey(ILogger logger, string? functionName) + internal static string GetFunctionKey(string? functionName) { if (functionName == null || !semanticFunctions.TryGetValue(functionName, out string? key)) { - logger.LogError("Unknown token dependency {0}. Please define function as semanticFunctions entry in TokenUtils.cs", functionName); - return null; + throw new KeyNotFoundException($"Unknown token dependency {functionName}. Please define function as semanticFunctions entry in TokenUtils.cs"); }; return $"{key}TokenUsage"; @@ -62,50 +59,32 @@ internal static Dictionary EmptyTokenUsages() /// Gets the total token usage from a Chat or Text Completion result context and adds it as a variable to response context. /// /// Result context from chat model - /// Context maintained during response generation. - /// The logger instance to use for logging errors. - /// Name of the function that invoked the chat completion. - /// true if token usage is found in result context; otherwise, false. - internal static void GetFunctionTokenUsage(FunctionResult result, KernelArguments kernelArguments, ILogger logger, string? functionName = null) + /// The logger instance to use for logging errors. + /// String representation of number of tokens used by function (or null on error) + internal static string? GetFunctionTokenUsage(FunctionResult result, ILogger logger) { + if (result.Metadata is null || + !result.Metadata.TryGetValue("Usage", out object? usageObject) || usageObject is null) + { + logger.LogError("No usage metadata provided"); + + return null; + } + + var tokenUsage = 0; try { - var functionKey = GetFunctionKey(logger, functionName); - if (functionKey == null) - { - return; - } - - if (result.Metadata is null) - { - logger.LogError("No metadata provided to capture usage details."); - return; - } - - if (!result.Metadata.TryGetValue("Usage", out object? usageObject) || usageObject is null) - { - logger.LogError("Unable to determine token usage for {0}", functionKey); - return; - } - - var tokenUsage = 0; - try - { - var jsonObject = JsonSerializer.Deserialize(JsonSerializer.Serialize(usageObject)); - tokenUsage = jsonObject.GetProperty("TotalTokens").GetInt32(); - } - catch (KeyNotFoundException) - { - logger.LogError("Usage details not found in model result."); - } - - kernelArguments[functionKey!] = tokenUsage.ToString(CultureInfo.InvariantCulture); + var jsonObject = JsonSerializer.Deserialize(JsonSerializer.Serialize(usageObject)); + tokenUsage = jsonObject.GetProperty("TotalTokens").GetInt32(); } - catch (Exception e) + catch (KeyNotFoundException) { - logger.LogError(e, "Unable to determine token usage for {0}", functionName); - throw e; + logger.LogError("Usage details not found in model result."); + + return null; } + + return tokenUsage.ToString(CultureInfo.InvariantCulture); } ///