diff --git a/webapi/Plugins/Chat/ChatPlugin.cs b/webapi/Plugins/Chat/ChatPlugin.cs
index 109cb3c08..874fe329f 100644
--- a/webapi/Plugins/Chat/ChatPlugin.cs
+++ b/webapi/Plugins/Chat/ChatPlugin.cs
@@ -107,89 +107,6 @@ public ChatPlugin(
this._contentSafety = contentSafety;
}
- ///
- /// Extract user intent from the conversation history.
- ///
- /// The KernelArguments.
- /// The cancellation token.
- private async Task ExtractUserIntentAsync(KernelArguments kernelArguments, CancellationToken cancellationToken = default)
- {
- var tokenLimit = this._promptOptions.CompletionTokenLimit;
- var historyTokenBudget =
- tokenLimit -
- this._promptOptions.ResponseTokenLimit -
- TokenUtils.TokenCount(string.Join("\n", new string[]
- {
- this._promptOptions.SystemDescription,
- this._promptOptions.SystemIntent,
- this._promptOptions.SystemIntentContinuation
- })
- );
-
- // Clone the context to avoid modifying the original context variables.
- KernelArguments intentExtractionContext = new(kernelArguments);
- intentExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());
- intentExtractionContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate;
-
- var completionFunction = this._kernel.CreateFunctionFromPrompt(
- this._promptOptions.SystemIntentExtraction,
- this.CreateIntentCompletionSettings(),
- functionName: nameof(ChatPlugin),
- description: "Complete the prompt.");
-
- var result = await completionFunction.InvokeAsync(
- this._kernel,
- intentExtractionContext,
- cancellationToken
- );
-
- // Get token usage from ChatCompletion result and add to context
- TokenUtils.GetFunctionTokenUsage(result, intentExtractionContext, this._logger, "SystemIntentExtraction");
-
- return $"User intent: {result}";
- }
-
- ///
- /// Extract the list of participants from the conversation history.
- /// Note that only those who have spoken will be included.
- ///
- /// The SKContext.
- /// The cancellation token.
- private async Task ExtractAudienceAsync(KernelArguments context, CancellationToken cancellationToken = default)
- {
- var tokenLimit = this._promptOptions.CompletionTokenLimit;
- var historyTokenBudget =
- tokenLimit -
- this._promptOptions.ResponseTokenLimit -
- TokenUtils.TokenCount(string.Join("\n", new string[]
- {
- this._promptOptions.SystemAudience,
- this._promptOptions.SystemAudienceContinuation,
- })
- );
-
- // Clone the context to avoid modifying the original context variables.
- KernelArguments audienceExtractionContext = new(context);
- audienceExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());
-
- var completionFunction = this._kernel.CreateFunctionFromPrompt(
- this._promptOptions.SystemAudienceExtraction,
- this.CreateIntentCompletionSettings(),
- functionName: nameof(ChatPlugin),
- description: "Complete the prompt.");
-
- var result = await completionFunction.InvokeAsync(
- this._kernel,
- audienceExtractionContext,
- cancellationToken
- );
-
- // Get token usage from ChatCompletion result and add to context
- TokenUtils.GetFunctionTokenUsage(result, context, this._logger, "SystemAudienceExtraction");
-
- return $"List of participants: {result}";
- }
-
///
/// Method that wraps GetAllowedChatHistoryAsync to get allotted history messages as one string.
/// GetAllowedChatHistoryAsync optionally updates a ChatHistory object with the allotted messages,
@@ -324,7 +241,7 @@ private async Task GetChatResponseAsync(string chatId, strin
// Render system instruction components and create the meta-prompt template
var systemInstructions = await AsyncUtils.SafeInvokeAsync(
() => this.RenderSystemInstructions(chatId, chatContext, cancellationToken), nameof(RenderSystemInstructions));
- ChatHistory chatHistory = new(systemInstructions);
+ ChatHistory metaPrompt = new(systemInstructions);
// Bypass audience extraction if Auth is disabled
var audience = string.Empty;
@@ -334,41 +251,42 @@ private async Task GetChatResponseAsync(string chatId, strin
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting audience", cancellationToken);
audience = await AsyncUtils.SafeInvokeAsync(
() => this.GetAudienceAsync(chatContext, cancellationToken), nameof(GetAudienceAsync));
- chatHistory.AddSystemMessage(audience);
+ metaPrompt.AddSystemMessage(audience);
}
// Extract user intent from the conversation history.
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting user intent", cancellationToken);
var userIntent = await AsyncUtils.SafeInvokeAsync(
() => this.GetUserIntentAsync(chatContext, cancellationToken), nameof(GetUserIntentAsync));
- chatHistory.AddSystemMessage(userIntent);
+ metaPrompt.AddSystemMessage(userIntent);
- // Calculate the remaining token budget.
- var remainingTokenBudget = this.GetChatContextTokenLimit(chatHistory, userMessage.ToFormattedString());
+ // Calculate max amount of tokens to use for memories
+ int maxRequestTokenBudget = this.GetMaxRequestTokenBudget();
+ // Calculate tokens used so far: system instructions, audience extraction and user intent
+ int tokensUsed = TokenUtils.GetContextMessagesTokenCount(metaPrompt);
+ int chatMemoryTokenBudget = maxRequestTokenBudget
+ - tokensUsed
+ - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userMessage.ToFormattedString());
+ chatMemoryTokenBudget = (int)(chatMemoryTokenBudget * this._promptOptions.MemoriesResponseContextWeight);
// Query relevant semantic and document memories
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting semantic and document memories", cancellationToken);
- var chatMemoriesTokenLimit = (int)(remainingTokenBudget * this._promptOptions.MemoriesResponseContextWeight);
- (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoriesTokenLimit);
-
+ (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoryTokenBudget);
if (!string.IsNullOrWhiteSpace(memoryText))
{
- chatHistory.AddSystemMessage(memoryText);
+ metaPrompt.AddSystemMessage(memoryText);
+ tokensUsed += TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText);
}
- // Fill in the chat history with remaining token budget.
- string allowedChatHistory = string.Empty;
- var allowedChatHistoryTokenBudget = remainingTokenBudget - TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText);
-
- // Append previous messages
+ // Add as many chat history messages to meta-prompt as the token budget will allow
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting chat history", cancellationToken);
- allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, allowedChatHistoryTokenBudget, chatHistory, cancellationToken);
+ string allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, maxRequestTokenBudget - tokensUsed, metaPrompt, cancellationToken);
- // Calculate token usage of prompt template
- chatContext[TokenUtils.GetFunctionKey(this._logger, "SystemMetaPrompt")!] = TokenUtils.GetContextMessagesTokenCount(chatHistory).ToString(CultureInfo.CurrentCulture);
+ // Store token usage of prompt template
+ chatContext[TokenUtils.GetFunctionKey("SystemMetaPrompt")] = TokenUtils.GetContextMessagesTokenCount(metaPrompt).ToString(CultureInfo.CurrentCulture);
// Stream the response to the client
- var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, chatHistory);
+ var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, metaPrompt);
return await this.HandleBotResponseAsync(chatId, userId, chatContext, promptView, citationMap.Values.AsEnumerable(), cancellationToken);
}
@@ -429,7 +347,7 @@ await AsyncUtils.SafeInvokeAsync(
cancellationToken), nameof(SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync));
// Calculate total token usage for dependency functions and prompt template
- await this.UpdateBotResponseStatusOnClientAsync(chatId, "Calculating token usage", cancellationToken);
+ await this.UpdateBotResponseStatusOnClientAsync(chatId, "Saving token usage", cancellationToken);
chatMessage.TokenUsage = this.GetTokenUsages(chatContext, chatMessage.Content);
// Update the message on client and in chat history with final completion token usage
@@ -449,16 +367,38 @@ private async Task GetAudienceAsync(KernelArguments context, Cancellatio
{
// Clone the context to avoid modifying the original context variables
KernelArguments audienceContext = new(context);
- var audience = await this.ExtractAudienceAsync(audienceContext, cancellationToken);
+ int historyTokenBudget =
+ this._promptOptions.CompletionTokenLimit -
+ this._promptOptions.ResponseTokenLimit -
+ TokenUtils.TokenCount(string.Join("\n\n", new string[]
+ {
+ this._promptOptions.SystemAudience,
+ this._promptOptions.SystemAudienceContinuation,
+ })
+ );
+
+ audienceContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());
- // Copy token usage into original chat context
- var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemAudienceExtraction")!;
- if (audienceContext.TryGetValue(functionKey, out object? tokenUsage))
+ var completionFunction = this._kernel.CreateFunctionFromPrompt(
+ this._promptOptions.SystemAudienceExtraction,
+ this.CreateIntentCompletionSettings(),
+ functionName: "SystemAudienceExtraction",
+ description: "Extract audience");
+
+ var result = await completionFunction.InvokeAsync(this._kernel, audienceContext, cancellationToken);
+
+ // Get token usage from ChatCompletion result and add to original context
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger);
+ if (tokenUsage is not null)
{
- context[functionKey] = tokenUsage;
+ context[TokenUtils.GetFunctionKey("SystemAudienceExtraction")] = tokenUsage;
+ }
+ else
+ {
+ this._logger.LogError("Unable to determine token usage for audienceExtraction");
}
- return audience;
+ return $"List of participants: {result}";
}
///
@@ -470,16 +410,41 @@ private async Task GetUserIntentAsync(KernelArguments context, Cancellat
{
// Clone the context to avoid modifying the original context variables
KernelArguments intentContext = new(context);
- string userIntent = await this.ExtractUserIntentAsync(intentContext, cancellationToken);
- // Copy token usage into original chat context
- var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemIntentExtraction")!;
- if (intentContext.TryGetValue(functionKey!, out object? tokenUsage))
+ int tokenBudget =
+ this._promptOptions.CompletionTokenLimit -
+ this._promptOptions.ResponseTokenLimit -
+ TokenUtils.TokenCount(string.Join("\n", new string[]
+ {
+ this._promptOptions.SystemPersona,
+ this._promptOptions.SystemIntent,
+ this._promptOptions.SystemIntentContinuation
+ })
+ );
+
+ intentContext["tokenLimit"] = tokenBudget.ToString(new NumberFormatInfo());
+ intentContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate;
+
+ var completionFunction = this._kernel.CreateFunctionFromPrompt(
+ this._promptOptions.SystemIntentExtraction,
+ this.CreateIntentCompletionSettings(),
+ functionName: "UserIntentExtraction",
+ description: "Extract user intent");
+
+ var result = await completionFunction.InvokeAsync(this._kernel, intentContext, cancellationToken);
+
+ // Get token usage from ChatCompletion result and add to original context
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger);
+ if (tokenUsage is not null)
{
- context[functionKey!] = tokenUsage;
+ context[TokenUtils.GetFunctionKey("SystemIntentExtraction")] = tokenUsage;
+ }
+ else
+ {
+ this._logger.LogError("Unable to determine token usage for userIntentExtraction");
}
- return userIntent;
+ return $"User intent: {result}";
}
///
@@ -610,24 +575,18 @@ private OpenAIPromptExecutionSettings CreateIntentCompletionSettings()
}
///
- /// Calculate the remaining token budget for the chat response prompt.
- /// This is the token limit minus the token count of the user intent, audience, and the system commands.
+ /// Calculate the maximum number of tokens that can be sent in a request
///
- /// All current messages to use for chat completion
- /// The user message.
- /// The remaining token limit.
- private int GetChatContextTokenLimit(ChatHistory promptTemplate, string userInput = "")
+ private int GetMaxRequestTokenBudget()
{
// OpenAI inserts a message under the hood:
// "content": "Assistant is a large language model.","role": "system"
// This burns just under 20 tokens which need to be accounted for.
const int ExtraOpenAiMessageTokens = 20;
- return this._promptOptions.CompletionTokenLimit
+ return this._promptOptions.CompletionTokenLimit // Total token limit
- ExtraOpenAiMessageTokens
- - TokenUtils.GetContextMessagesTokenCount(promptTemplate)
- - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userInput) // User message has to be included in chat history allowance
- - this._promptOptions.ResponseTokenLimit;
+ - this._promptOptions.ResponseTokenLimit; // Token count reserved for model to generate a response
}
///
diff --git a/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs b/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs
index 9b1c6d4eb..6c7f1d90f 100644
--- a/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs
+++ b/webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs
@@ -91,10 +91,19 @@ async Task ExtractCognitiveMemoryAsync(string memoryType, st
cancellationToken);
// Get token usage from ChatCompletion result and add to context
- // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type.
- TokenUtils.GetFunctionTokenUsage(result, kernelArguments, logger, $"SystemCognitive_{memoryType}");
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, logger);
+ if (tokenUsage is not null)
+ {
+ // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type.
+ kernelArguments[TokenUtils.GetFunctionKey($"SystemCognitive_{memoryType}")] = tokenUsage;
+ }
+ else
+ {
+ logger.LogError("Unable to determine token usage for {0}", $"SystemCognitive_{memoryType}");
+ }
SemanticChatMemory memory = SemanticChatMemory.FromJson(result.ToString());
+
return memory;
}
diff --git a/webapi/Plugins/Utils/TokenUtils.cs b/webapi/Plugins/Utils/TokenUtils.cs
index 4d9f816e4..f3adad225 100644
--- a/webapi/Plugins/Utils/TokenUtils.cs
+++ b/webapi/Plugins/Utils/TokenUtils.cs
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
-using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
@@ -22,7 +21,7 @@ public static class TokenUtils
/// Semantic dependencies of ChatPlugin.
/// If you add a new semantic dependency, please add it here.
///
- public static readonly Dictionary semanticFunctions = new()
+ internal static readonly Dictionary semanticFunctions = new()
{
{ "SystemAudienceExtraction", "audienceExtraction" },
{ "SystemIntentExtraction", "userIntentExtraction" },
@@ -38,21 +37,19 @@ public static class TokenUtils
///
internal static Dictionary EmptyTokenUsages()
{
- return semanticFunctions.Values.ToDictionary(v => v, v => 0, StringComparer.OrdinalIgnoreCase);
+ return semanticFunctions.Values.ToDictionary(v => v, v => 0);
}
///
/// Gets key used to identify function token usage in context variables.
///
- /// The logger instance to use for logging errors.
/// Name of semantic function.
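- /// The key corresponding to the semantic function name, or null if the function name is unknown.
+ /// The key corresponding to the semantic function name. Throws KeyNotFoundException if the function name is null or unknown.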
- internal static string? GetFunctionKey(ILogger logger, string? functionName)
+ internal static string GetFunctionKey(string? functionName)
{
if (functionName == null || !semanticFunctions.TryGetValue(functionName, out string? key))
{
- logger.LogError("Unknown token dependency {0}. Please define function as semanticFunctions entry in TokenUtils.cs", functionName);
- return null;
+ throw new KeyNotFoundException($"Unknown token dependency {functionName}. Please define function as semanticFunctions entry in TokenUtils.cs");
};
return $"{key}TokenUsage";
@@ -62,50 +59,32 @@ internal static Dictionary EmptyTokenUsages()
- /// Gets the total token usage from a Chat or Text Completion result context and adds it as a variable to response context.
+ /// Gets the total token usage reported in the "Usage" metadata of a Chat or Text Completion result.
+ /// The value is only returned; this method does not write it to any context, so the caller decides where to store it.
///
/// Result context from chat model
- /// Context maintained during response generation.
- /// The logger instance to use for logging errors.
- /// Name of the function that invoked the chat completion.
- /// true if token usage is found in result context; otherwise, false.
- internal static void GetFunctionTokenUsage(FunctionResult result, KernelArguments kernelArguments, ILogger logger, string? functionName = null)
+ /// The logger instance to use for logging errors.
+ /// String representation (invariant culture) of the number of tokens used by the function, or null when the usage metadata is missing or malformed
+ internal static string? GetFunctionTokenUsage(FunctionResult result, ILogger logger)
{
+ if (result.Metadata is null ||
+ !result.Metadata.TryGetValue("Usage", out object? usageObject) || usageObject is null)
+ {
+ logger.LogError("No usage metadata provided");
+
+ return null;
+ }
+
- try
- {
- var functionKey = GetFunctionKey(logger, functionName);
- if (functionKey == null)
- {
- return;
- }
-
- if (result.Metadata is null)
- {
- logger.LogError("No metadata provided to capture usage details.");
- return;
- }
-
- if (!result.Metadata.TryGetValue("Usage", out object? usageObject) || usageObject is null)
- {
- logger.LogError("Unable to determine token usage for {0}", functionKey);
- return;
- }
-
- var tokenUsage = 0;
- try
- {
- var jsonObject = JsonSerializer.Deserialize(JsonSerializer.Serialize(usageObject));
- tokenUsage = jsonObject.GetProperty("TotalTokens").GetInt32();
- }
- catch (KeyNotFoundException)
- {
- logger.LogError("Usage details not found in model result.");
- }
-
- kernelArguments[functionKey!] = tokenUsage.ToString(CultureInfo.InvariantCulture);
- }
- catch (Exception e)
- {
- logger.LogError(e, "Unable to determine token usage for {0}", functionName);
- throw e;
+ // Project the usage object to JSON so it can be inspected without depending on its concrete type.
+ // Serialization failures are not handled here and propagate to the caller.
+ JsonElement usage = JsonSerializer.SerializeToElement(usageObject);
+
+ // Probe with the Try* APIs instead of GetProperty/GetInt32: those throw InvalidOperationException
+ // (wrong JSON kind) or FormatException (value out of Int32 range) in addition to KeyNotFoundException,
+ // which would escape this method even though it promises null for malformed usage data.
+ if (usage.ValueKind != JsonValueKind.Object ||
+ !usage.TryGetProperty("TotalTokens", out JsonElement totalTokens) ||
+ totalTokens.ValueKind != JsonValueKind.Number ||
+ !totalTokens.TryGetInt32(out int tokenUsage))
+ {
+ logger.LogError("Usage details not found in model result.");
+
+ return null;
+ }
+
+ return tokenUsage.ToString(CultureInfo.InvariantCulture);
}
///