Fix token budget + Clarify code (microsoft#839)
### Motivation and Context
The user message is not added to the GPT completion request when the
token count of the prompt exceeds 50% of the configured completion token limit.

This fix is inspired by microsoft#836.

### Description
Fix the token budget calculation and clarify/simplify the surrounding code.
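
In short: the fix budgets the user message before memories take their weighted share, so the user message can no longer be squeezed out of the request. A minimal sketch of the corrected arithmetic, condensed from the `ChatPlugin.cs` diff below (`options` stands in for `this._promptOptions`; `metaPrompt` and `userMessage` are the locals of `GetChatResponseAsync`; the helpers are this repo's `TokenUtils`):

```csharp
// Sketch only: relies on this repo's TokenUtils and Semantic Kernel's AuthorRole.
// Hard ceiling for one request: model limit, minus OpenAI's hidden system
// message (just under 20 tokens), minus tokens reserved for the response.
int maxRequestTokenBudget = options.CompletionTokenLimit - 20 - options.ResponseTokenLimit;

// Tokens already spent on system instructions, audience and user intent.
int tokensUsed = TokenUtils.GetContextMessagesTokenCount(metaPrompt);

// The fix: subtract the user message BEFORE applying the memories weight.
int chatMemoryTokenBudget = (int)((maxRequestTokenBudget
    - tokensUsed
    - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userMessage.ToFormattedString()))
    * options.MemoriesResponseContextWeight);

// Retrieved memories are appended to the meta-prompt and re-counted into
// tokensUsed; whatever remains goes to chat history, which already contains
// the user message, so it always fits.
int chatHistoryTokenBudget = maxRequestTokenBudget - tokensUsed;
```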

### Contribution Checklist
- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
glahaye authored Mar 9, 2024
1 parent af797ef commit 8ca758c
Showing 3 changed files with 118 additions and 171 deletions.
207 changes: 83 additions & 124 deletions webapi/Plugins/Chat/ChatPlugin.cs
@@ -107,89 +107,6 @@ public ChatPlugin(
this._contentSafety = contentSafety;
}

- /// <summary>
- /// Extract user intent from the conversation history.
- /// </summary>
- /// <param name="kernelArguments">The KernelArguments.</param>
- /// <param name="cancellationToken">The cancellation token.</param>
- private async Task<string> ExtractUserIntentAsync(KernelArguments kernelArguments, CancellationToken cancellationToken = default)
- {
- var tokenLimit = this._promptOptions.CompletionTokenLimit;
- var historyTokenBudget =
- tokenLimit -
- this._promptOptions.ResponseTokenLimit -
- TokenUtils.TokenCount(string.Join("\n", new string[]
- {
- this._promptOptions.SystemDescription,
- this._promptOptions.SystemIntent,
- this._promptOptions.SystemIntentContinuation
- })
- );
-
- // Clone the context to avoid modifying the original context variables.
- KernelArguments intentExtractionContext = new(kernelArguments);
- intentExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());
- intentExtractionContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate;
-
- var completionFunction = this._kernel.CreateFunctionFromPrompt(
- this._promptOptions.SystemIntentExtraction,
- this.CreateIntentCompletionSettings(),
- functionName: nameof(ChatPlugin),
- description: "Complete the prompt.");
-
- var result = await completionFunction.InvokeAsync(
- this._kernel,
- intentExtractionContext,
- cancellationToken
- );
-
- // Get token usage from ChatCompletion result and add to context
- TokenUtils.GetFunctionTokenUsage(result, intentExtractionContext, this._logger, "SystemIntentExtraction");
-
- return $"User intent: {result}";
- }
-
- /// <summary>
- /// Extract the list of participants from the conversation history.
- /// Note that only those who have spoken will be included.
- /// </summary>
- /// <param name="context">The SKContext.</param>
- /// <param name="cancellationToken">The cancellation token.</param>
- private async Task<string> ExtractAudienceAsync(KernelArguments context, CancellationToken cancellationToken = default)
- {
- var tokenLimit = this._promptOptions.CompletionTokenLimit;
- var historyTokenBudget =
- tokenLimit -
- this._promptOptions.ResponseTokenLimit -
- TokenUtils.TokenCount(string.Join("\n", new string[]
- {
- this._promptOptions.SystemAudience,
- this._promptOptions.SystemAudienceContinuation,
- })
- );
-
- // Clone the context to avoid modifying the original context variables.
- KernelArguments audienceExtractionContext = new(context);
- audienceExtractionContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());
-
- var completionFunction = this._kernel.CreateFunctionFromPrompt(
- this._promptOptions.SystemAudienceExtraction,
- this.CreateIntentCompletionSettings(),
- functionName: nameof(ChatPlugin),
- description: "Complete the prompt.");
-
- var result = await completionFunction.InvokeAsync(
- this._kernel,
- audienceExtractionContext,
- cancellationToken
- );
-
- // Get token usage from ChatCompletion result and add to context
- TokenUtils.GetFunctionTokenUsage(result, context, this._logger, "SystemAudienceExtraction");
-
- return $"List of participants: {result}";
- }

/// <summary>
/// Method that wraps GetAllowedChatHistoryAsync to get allotted history messages as one string.
/// GetAllowedChatHistoryAsync optionally updates a ChatHistory object with the allotted messages,
@@ -324,7 +241,7 @@ private async Task<CopilotChatMessage> GetChatResponseAsync(string chatId, strin
// Render system instruction components and create the meta-prompt template
var systemInstructions = await AsyncUtils.SafeInvokeAsync(
() => this.RenderSystemInstructions(chatId, chatContext, cancellationToken), nameof(RenderSystemInstructions));
- ChatHistory chatHistory = new(systemInstructions);
+ ChatHistory metaPrompt = new(systemInstructions);

// Bypass audience extraction if Auth is disabled
var audience = string.Empty;
@@ -334,41 +251,42 @@ private async Task<CopilotChatMessage> GetChatResponseAsync(string chatId, strin
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting audience", cancellationToken);
audience = await AsyncUtils.SafeInvokeAsync(
() => this.GetAudienceAsync(chatContext, cancellationToken), nameof(GetAudienceAsync));
- chatHistory.AddSystemMessage(audience);
+ metaPrompt.AddSystemMessage(audience);
}

// Extract user intent from the conversation history.
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting user intent", cancellationToken);
var userIntent = await AsyncUtils.SafeInvokeAsync(
() => this.GetUserIntentAsync(chatContext, cancellationToken), nameof(GetUserIntentAsync));
- chatHistory.AddSystemMessage(userIntent);
+ metaPrompt.AddSystemMessage(userIntent);

- // Calculate the remaining token budget.
- var remainingTokenBudget = this.GetChatContextTokenLimit(chatHistory, userMessage.ToFormattedString());
+ // Calculate max amount of tokens to use for memories
+ int maxRequestTokenBudget = this.GetMaxRequestTokenBudget();
+ // Calculate tokens used so far: system instructions, audience extraction and user intent
+ int tokensUsed = TokenUtils.GetContextMessagesTokenCount(metaPrompt);
+ int chatMemoryTokenBudget = maxRequestTokenBudget
+ - tokensUsed
+ - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userMessage.ToFormattedString());
+ chatMemoryTokenBudget = (int)(chatMemoryTokenBudget * this._promptOptions.MemoriesResponseContextWeight);

// Query relevant semantic and document memories
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting semantic and document memories", cancellationToken);
- var chatMemoriesTokenLimit = (int)(remainingTokenBudget * this._promptOptions.MemoriesResponseContextWeight);
- (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoriesTokenLimit);
-
+ (var memoryText, var citationMap) = await this._semanticMemoryRetriever.QueryMemoriesAsync(userIntent, chatId, chatMemoryTokenBudget);
if (!string.IsNullOrWhiteSpace(memoryText))
{
- chatHistory.AddSystemMessage(memoryText);
+ metaPrompt.AddSystemMessage(memoryText);
+ tokensUsed += TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText);
}

- // Fill in the chat history with remaining token budget.
- string allowedChatHistory = string.Empty;
- var allowedChatHistoryTokenBudget = remainingTokenBudget - TokenUtils.GetContextMessageTokenCount(AuthorRole.System, memoryText);
-
- // Append previous messages
+ // Add as many chat history messages to meta-prompt as the token budget will allow
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Extracting chat history", cancellationToken);
- allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, allowedChatHistoryTokenBudget, chatHistory, cancellationToken);
+ string allowedChatHistory = await this.GetAllowedChatHistoryAsync(chatId, maxRequestTokenBudget - tokensUsed, metaPrompt, cancellationToken);

- // Calculate token usage of prompt template
- chatContext[TokenUtils.GetFunctionKey(this._logger, "SystemMetaPrompt")!] = TokenUtils.GetContextMessagesTokenCount(chatHistory).ToString(CultureInfo.CurrentCulture);
+ // Store token usage of prompt template
+ chatContext[TokenUtils.GetFunctionKey("SystemMetaPrompt")] = TokenUtils.GetContextMessagesTokenCount(metaPrompt).ToString(CultureInfo.CurrentCulture);

// Stream the response to the client
- var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, chatHistory);
+ var promptView = new BotResponsePrompt(systemInstructions, audience, userIntent, memoryText, allowedChatHistory, metaPrompt);

return await this.HandleBotResponseAsync(chatId, userId, chatContext, promptView, citationMap.Values.AsEnumerable(), cancellationToken);
}
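
For concreteness, a hypothetical trace of the corrected flow above, using assumed configuration values (not the repo defaults): `CompletionTokenLimit = 4096`, `ResponseTokenLimit = 1024`, `MemoriesResponseContextWeight = 0.6`.

```csharp
int maxRequestTokenBudget = 4096 - 20 - 1024;                // 3052
int tokensUsed = 800;                                        // system + audience + intent (assumed)
int userMessageTokens = 200;                                 // assumed

int chatMemoryTokenBudget = (int)((3052 - 800 - 200) * 0.6); // 1231 tokens offered to memories
tokensUsed += 1000;                                          // tokens the retrieved memories used (assumed)
int historyBudget = 3052 - tokensUsed;                       // 1252: covers the 200-token user message
                                                             // plus roughly 1052 tokens of prior turns
```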
@@ -429,7 +347,7 @@ await AsyncUtils.SafeInvokeAsync(
cancellationToken), nameof(SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync));

// Calculate total token usage for dependency functions and prompt template
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Calculating token usage", cancellationToken);
await this.UpdateBotResponseStatusOnClientAsync(chatId, "Saving token usage", cancellationToken);
chatMessage.TokenUsage = this.GetTokenUsages(chatContext, chatMessage.Content);

// Update the message on client and in chat history with final completion token usage
@@ -449,16 +367,38 @@ private async Task<string> GetAudienceAsync(KernelArguments context, Cancellatio
{
// Clone the context to avoid modifying the original context variables
KernelArguments audienceContext = new(context);
- var audience = await this.ExtractAudienceAsync(audienceContext, cancellationToken);
+ int historyTokenBudget =
+ this._promptOptions.CompletionTokenLimit -
+ this._promptOptions.ResponseTokenLimit -
+ TokenUtils.TokenCount(string.Join("\n\n", new string[]
+ {
+ this._promptOptions.SystemAudience,
+ this._promptOptions.SystemAudienceContinuation,
+ })
+ );

audienceContext["tokenLimit"] = historyTokenBudget.ToString(new NumberFormatInfo());

- // Copy token usage into original chat context
- var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemAudienceExtraction")!;
- if (audienceContext.TryGetValue(functionKey, out object? tokenUsage))
+ var completionFunction = this._kernel.CreateFunctionFromPrompt(
+ this._promptOptions.SystemAudienceExtraction,
+ this.CreateIntentCompletionSettings(),
+ functionName: "SystemAudienceExtraction",
+ description: "Extract audience");
+
+ var result = await completionFunction.InvokeAsync(this._kernel, audienceContext, cancellationToken);
+
+ // Get token usage from ChatCompletion result and add to original context
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger);
+ if (tokenUsage is not null)
{
- context[functionKey] = tokenUsage;
+ context[TokenUtils.GetFunctionKey("SystemAudienceExtraction")] = tokenUsage;
}
+ else
+ {
+ this._logger.LogError("Unable to determine token usage for audienceExtraction");
+ }

- return audience;
+ return $"List of participants: {result}";
}

/// <summary>
@@ -470,16 +410,41 @@ private async Task<string> GetUserIntentAsync(KernelArguments context, Cancellat
{
// Clone the context to avoid modifying the original context variables
KernelArguments intentContext = new(context);
- string userIntent = await this.ExtractUserIntentAsync(intentContext, cancellationToken);

- // Copy token usage into original chat context
- var functionKey = TokenUtils.GetFunctionKey(this._logger, "SystemIntentExtraction")!;
- if (intentContext.TryGetValue(functionKey!, out object? tokenUsage))
+ int tokenBudget =
+ this._promptOptions.CompletionTokenLimit -
+ this._promptOptions.ResponseTokenLimit -
+ TokenUtils.TokenCount(string.Join("\n", new string[]
+ {
+ this._promptOptions.SystemPersona,
+ this._promptOptions.SystemIntent,
+ this._promptOptions.SystemIntentContinuation
+ })
+ );
+
+ intentContext["tokenLimit"] = tokenBudget.ToString(new NumberFormatInfo());
+ intentContext["knowledgeCutoff"] = this._promptOptions.KnowledgeCutoffDate;

+ var completionFunction = this._kernel.CreateFunctionFromPrompt(
+ this._promptOptions.SystemIntentExtraction,
+ this.CreateIntentCompletionSettings(),
+ functionName: "UserIntentExtraction",
+ description: "Extract user intent");
+
+ var result = await completionFunction.InvokeAsync(this._kernel, intentContext, cancellationToken);
+
+ // Get token usage from ChatCompletion result and add to original context
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, this._logger);
+ if (tokenUsage is not null)
{
- context[functionKey!] = tokenUsage;
+ context[TokenUtils.GetFunctionKey("SystemIntentExtraction")] = tokenUsage;
}
+ else
+ {
+ this._logger.LogError("Unable to determine token usage for userIntentExtraction");
+ }

- return userIntent;
+ return $"User intent: {result}";
}

/// <summary>
@@ -610,24 +575,18 @@ private OpenAIPromptExecutionSettings CreateIntentCompletionSettings()
}

/// <summary>
- /// Calculate the remaining token budget for the chat response prompt.
- /// This is the token limit minus the token count of the user intent, audience, and the system commands.
+ /// Calculate the maximum number of tokens that can be sent in a request
/// </summary>
/// <param name="promptTemplate">All current messages to use for chat completion</param>
/// <param name="userIntent">The user message.</param>
/// <returns>The remaining token limit.</returns>
- private int GetChatContextTokenLimit(ChatHistory promptTemplate, string userInput = "")
+ private int GetMaxRequestTokenBudget()
{
// OpenAI inserts a message under the hood:
// "content": "Assistant is a large language model.","role": "system"
// This burns just under 20 tokens which need to be accounted for.
const int ExtraOpenAiMessageTokens = 20;

- return this._promptOptions.CompletionTokenLimit
+ return this._promptOptions.CompletionTokenLimit // Total token limit
- ExtraOpenAiMessageTokens
- - TokenUtils.GetContextMessagesTokenCount(promptTemplate)
- - TokenUtils.GetContextMessageTokenCount(AuthorRole.User, userInput) // User message has to be included in chat history allowance
- - this._promptOptions.ResponseTokenLimit;
+ - this._promptOptions.ResponseTokenLimit; // Token count reserved for model to generate a response
}

/// <summary>
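The `ExtraOpenAiMessageTokens` constant in `GetMaxRequestTokenBudget` above accounts for the system message OpenAI injects ("Assistant is a large language model."), which costs just under 20 tokens. A quick sanity check of such counts with a cl100k_base tokenizer, sketched here with the SharpToken package as an assumption (check `TokenUtils` for the encoder this webapi actually wraps):

```csharp
using SharpToken; // assumed tokenizer package, not confirmed from this diff

var encoding = GptEncoding.GetEncoding("cl100k_base");
// Content tokens of the injected message; per-message framing overhead
// (role name, separators) comes on top, hence "just under 20" in total.
int contentTokens = encoding.Encode("Assistant is a large language model.").Count;
Console.WriteLine($"content tokens: {contentTokens}");
```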
13 changes: 11 additions & 2 deletions webapi/Plugins/Chat/SemanticChatMemoryExtractor.cs
@@ -91,10 +91,19 @@ async Task<SemanticChatMemory> ExtractCognitiveMemoryAsync(string memoryType, st
cancellationToken);

// Get token usage from ChatCompletion result and add to context
- // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type.
- TokenUtils.GetFunctionTokenUsage(result, kernelArguments, logger, $"SystemCognitive_{memoryType}");
+ string? tokenUsage = TokenUtils.GetFunctionTokenUsage(result, logger);
+ if (tokenUsage is not null)
+ {
+ // Since there are multiple memory types, total token usage is calculated by cumulating the token usage of each memory type.
+ kernelArguments[TokenUtils.GetFunctionKey($"SystemCognitive_{memoryType}")] = tokenUsage;
+ }
+ else
+ {
+ logger.LogError("Unable to determine token usage for {0}", $"SystemCognitive_{memoryType}");
+ }

SemanticChatMemory memory = SemanticChatMemory.FromJson(result.ToString());

return memory;
}
