From ad0f00c7b8387fd6cead9cad73850ffe14f77df6 Mon Sep 17 00:00:00 2001 From: Ronny Gunawan <3048897+ronnygunawan@users.noreply.github.com> Date: Wed, 31 Jan 2024 00:07:37 +0700 Subject: [PATCH] Vision --- .../AI/OpenAI/OpenAIImagePromptHandler.cs | 243 ++++++++++++++++++ .../AI/OpenAI/OpenAITextPromptHandler.cs | 2 +- .../BotUpdate/Message/AICallCommandHandler.cs | 35 ++- .../AI/OpenAI/OpenAIImagePrompt.cs | 61 +++++ 4 files changed, 327 insertions(+), 14 deletions(-) create mode 100644 BotNet.CommandHandlers/AI/OpenAI/OpenAIImagePromptHandler.cs create mode 100644 BotNet.Commands/AI/OpenAI/OpenAIImagePrompt.cs diff --git a/BotNet.CommandHandlers/AI/OpenAI/OpenAIImagePromptHandler.cs b/BotNet.CommandHandlers/AI/OpenAI/OpenAIImagePromptHandler.cs new file mode 100644 index 0000000..9c9dad1 --- /dev/null +++ b/BotNet.CommandHandlers/AI/OpenAI/OpenAIImagePromptHandler.cs @@ -0,0 +1,243 @@ +using BotNet.CommandHandlers.Art; +using BotNet.Commands; +using BotNet.Commands.AI.OpenAI; +using BotNet.Commands.AI.Stability; +using BotNet.Commands.BotUpdate.Message; +using BotNet.Commands.ChatAggregate; +using BotNet.Commands.SenderAggregate; +using BotNet.Services.MarkdownV2; +using BotNet.Services.OpenAI; +using BotNet.Services.OpenAI.Models; +using BotNet.Services.RateLimit; +using Microsoft.Extensions.Logging; +using SkiaSharp; +using Telegram.Bot; +using Telegram.Bot.Types; +using Telegram.Bot.Types.Enums; + +namespace BotNet.CommandHandlers.AI.OpenAI { + public sealed class OpenAIImagePromptHandler( + ITelegramBotClient telegramBotClient, + ICommandQueue commandQueue, + ITelegramMessageCache telegramMessageCache, + OpenAIClient openAIClient, + ILogger logger + ) : ICommandHandler { + internal static readonly RateLimiter VISION_RATE_LIMITER = RateLimiter.PerUserPerChat(1, TimeSpan.FromMinutes(15)); + + private readonly ITelegramBotClient _telegramBotClient = telegramBotClient; + private readonly ICommandQueue _commandQueue = commandQueue; + private readonly 
ITelegramMessageCache _telegramMessageCache = telegramMessageCache; + private readonly OpenAIClient _openAIClient = openAIClient; + private readonly ILogger _logger = logger; + + public Task Handle(OpenAIImagePrompt imagePrompt, CancellationToken cancellationToken) { + if (imagePrompt.Command.Sender is not VIPSender + && imagePrompt.Command.Chat is not HomeGroupChat) { + return _telegramBotClient.SendTextMessageAsync( + chatId: imagePrompt.Command.Chat.Id, + text: MarkdownV2Sanitizer.Sanitize("Vision tidak bisa dipakai di sini."), + parseMode: ParseMode.MarkdownV2, + replyToMessageId: imagePrompt.Command.MessageId, + cancellationToken: cancellationToken + ); + } + + try { + VISION_RATE_LIMITER.ValidateActionRate( + chatId: imagePrompt.Command.Chat.Id, + userId: imagePrompt.Command.Sender.Id + ); + } catch (RateLimitExceededException exc) { + return _telegramBotClient.SendTextMessageAsync( + chatId: imagePrompt.Command.Chat.Id, + text: $"Anda terlalu banyak menggunakan vision. Coba lagi {exc.Cooldown}.", + parseMode: ParseMode.Html, + replyToMessageId: imagePrompt.Command.MessageId, + cancellationToken: cancellationToken + ); + } + + // Fire and forget + Task.Run(async () => { + (string? imageBase64, string? error) = await GetImageBase64Async( + botClient: _telegramBotClient, + fileId: imagePrompt.ImageFileId, + cancellationToken: cancellationToken + ); + + if (error is not null) { + await _telegramBotClient.SendTextMessageAsync( + chatId: imagePrompt.Command.Chat.Id, + text: $"{error}", + parseMode: ParseMode.Html, + replyToMessageId: imagePrompt.Command.MessageId, + cancellationToken: cancellationToken + ); + return; + } + + List messages = [ + ChatMessage.FromText("system", "The following is a conversation with an AI assistant. The assistant is helpful, creative, direct, concise, and always get to the point. 
When user asks for an image to be generated, the AI assistant should respond with \"ImageGeneration:\" followed by comma separated list of features to be expected from the generated image.") + ]; + + messages.AddRange( + from message in imagePrompt.Thread.Take(10).Reverse() + select ChatMessage.FromText( + role: message.Sender.ChatGPTRole, + text: message.Text + ) + ); + + messages.Add( + ChatMessage.FromTextWithImageBase64("user", imagePrompt.Prompt, imageBase64!) + ); + + Message responseMessage = await _telegramBotClient.SendTextMessageAsync( + chatId: imagePrompt.Command.Chat.Id, + text: MarkdownV2Sanitizer.Sanitize("… ⏳"), + parseMode: ParseMode.MarkdownV2, + replyToMessageId: imagePrompt.Command.MessageId + ); + + string response = await _openAIClient.ChatAsync( + model: imagePrompt switch { + ({ Command: { Sender: VIPSender } or { Chat: HomeGroupChat } }) => "gpt-4-1106-preview", + _ => "gpt-3.5-turbo" + }, + messages: messages, + maxTokens: 512, + cancellationToken: cancellationToken + ); + + // Handle image generation intent + if (response.StartsWith("ImageGeneration:")) { + if (imagePrompt.Command.Sender is not VIPSender) { + try { + ArtCommandHandler.IMAGE_GENERATION_RATE_LIMITER.ValidateActionRate(imagePrompt.Command.Chat.Id, imagePrompt.Command.Sender.Id); + } catch (RateLimitExceededException exc) { + await _telegramBotClient.SendTextMessageAsync( + chatId: imagePrompt.Command.Chat.Id, + text: $"Anda belum mendapat giliran. 
Coba lagi {exc.Cooldown}.", + parseMode: ParseMode.Html, + replyToMessageId: imagePrompt.Command.MessageId, + cancellationToken: cancellationToken + ); + return; + } + } + + string imageGenerationPrompt = response.Substring(response.IndexOf(':') + 1).Trim(); + switch (imagePrompt.Command) { + case { Sender: VIPSender }: + await _commandQueue.DispatchAsync( + command: new OpenAIImageGenerationPrompt( + callSign: imagePrompt.CallSign, + prompt: imageGenerationPrompt, + promptMessageId: imagePrompt.Command.MessageId, + responseMessageId: new(responseMessage.MessageId), + chat: imagePrompt.Command.Chat, + sender: imagePrompt.Command.Sender + ) + ); + break; + case { Chat: HomeGroupChat }: + await _commandQueue.DispatchAsync( + command: new StabilityTextToImagePrompt( + callSign: imagePrompt.CallSign, + prompt: imageGenerationPrompt, + promptMessageId: imagePrompt.Command.MessageId, + responseMessageId: new(responseMessage.MessageId), + chat: imagePrompt.Command.Chat, + sender: imagePrompt.Command.Sender + ) + ); + break; + default: + await _telegramBotClient.EditMessageTextAsync( + chatId: imagePrompt.Command.Chat.Id, + messageId: responseMessage.MessageId, + text: MarkdownV2Sanitizer.Sanitize("Image generation tidak bisa dipakai di sini."), + parseMode: ParseMode.MarkdownV2, + cancellationToken: cancellationToken + ); + break; + } + return; + } + + // Finalize message + try { + responseMessage = await telegramBotClient.EditMessageTextAsync( + chatId: imagePrompt.Command.Chat.Id, + messageId: responseMessage.MessageId, + text: MarkdownV2Sanitizer.Sanitize(response), + parseMode: ParseMode.MarkdownV2, + cancellationToken: cancellationToken + ); + } catch (Exception exc) { + _logger.LogError(exc, null); + throw; + } + + // Track thread + _telegramMessageCache.Add( + message: AIResponseMessage.FromMessage( + message: responseMessage, + replyToMessage: imagePrompt.Command, + callSign: imagePrompt.CallSign + ) + ); + }); + + return Task.CompletedTask; + } + + private static 
async Task<(string? ImageBase64, string? Error)> GetImageBase64Async(ITelegramBotClient botClient, string fileId, CancellationToken cancellationToken) { + // Download photo + using MemoryStream originalImageStream = new(); + await botClient.GetInfoAndDownloadFileAsync( + fileId: fileId, + destination: originalImageStream, + cancellationToken: cancellationToken); + byte[] originalImage = originalImageStream.ToArray(); + + // Limit input image to 300KB + if (originalImage.Length > 300 * 1024) { + return (null, "Image larger than 300KB"); + } + + // Decode image + originalImageStream.Position = 0; + using SKCodec codec = SKCodec.Create(originalImageStream, out SKCodecResult codecResult); + if (codecResult != SKCodecResult.Success) { + return (null, "Invalid image"); + } + + if (codec.EncodedFormat != SKEncodedImageFormat.Jpeg + && codec.EncodedFormat != SKEncodedImageFormat.Webp) { + return (null, "Image must be compressed image"); + } + using SKBitmap bitmap = SKBitmap.Decode(codec); + if (bitmap is null) { + return (null, "Invalid image"); + } + + // Limit input image to 1280x1280 + if (bitmap.Width > 1280 || bitmap.Height > 1280) { + return (null, "Image larger than 1280x1280"); + } + + // Handle stickers + if (codec.EncodedFormat == SKEncodedImageFormat.Webp) { + using SKImage image = SKImage.FromBitmap(bitmap); + using SKData data = image.Encode(SKEncodedImageFormat.Jpeg, 20); + using MemoryStream jpegStream = new(); + data.SaveTo(jpegStream); + + // Encode image as base64 + return (Convert.ToBase64String(jpegStream.ToArray()), null); + } + + // Encode image as base64 + return (Convert.ToBase64String(originalImage), null); + } + } +} diff --git a/BotNet.CommandHandlers/AI/OpenAI/OpenAITextPromptHandler.cs b/BotNet.CommandHandlers/AI/OpenAI/OpenAITextPromptHandler.cs index 59430e5..a95e55b 100644 --- a/BotNet.CommandHandlers/AI/OpenAI/OpenAITextPromptHandler.cs +++ b/BotNet.CommandHandlers/AI/OpenAI/OpenAITextPromptHandler.cs @@ -49,7 +49,7 @@ public Task Handle(OpenAITextPrompt textPrompt, CancellationToken cancellationTo // Fire and forget
Task.Run(async () => { List messages = [ - ChatMessage.FromText("system", "The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly. When user asks for an image to be generated, the AI assistant should respond with \"ImageGeneration:\" followed by comma separated list of features to be expected from the generated image.") + ChatMessage.FromText("system", "The following is a conversation with an AI assistant. The assistant is helpful, creative, direct, concise, and always get to the point. When user asks for an image to be generated, the AI assistant should respond with \"ImageGeneration:\" followed by comma separated list of features to be expected from the generated image.") ]; messages.AddRange( diff --git a/BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs b/BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs index 56926fb..b98fb4b 100644 --- a/BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs +++ b/BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs @@ -15,19 +15,28 @@ IntentDetector intentDetector public async Task Handle(AICallCommand command, CancellationToken cancellationToken) { switch (command.CallSign) { - case "AI" or "Bot" or "GPT" when command.ImageFileId is null: - await _commandQueue.DispatchAsync( - command: OpenAITextPrompt.FromAICallCommand( - aiCallCommand: command, - thread: command.ReplyToMessage is { } replyToMessage - ? _telegramMessageCache.GetThread(replyToMessage) - : Enumerable.Empty() - ) - ); - break; - case "AI" or "Bot" or "GPT" when command.ImageFileId is { } imageFileId: - // TODO: Implement GPT-4 Vision - break; + case "AI" or "Bot" or "GPT" when command.ImageFileId is null: { + await _commandQueue.DispatchAsync( + command: OpenAITextPrompt.FromAICallCommand( + aiCallCommand: command, + thread: command.ReplyToMessage is { } replyToMessage + ? 
_telegramMessageCache.GetThread(replyToMessage) + : Enumerable.Empty() + ) + ); + break; + } + case "AI" or "Bot" or "GPT" when command.ImageFileId is { } imageFileId: { + await _commandQueue.DispatchAsync( + command: OpenAIImagePrompt.FromAICallCommand( + aiCallCommand: command, + thread: command.ReplyToMessage is { } replyToMessage + ? _telegramMessageCache.GetThread(replyToMessage) + : Enumerable.Empty() + ) + ); + break; + } } } } diff --git a/BotNet.Commands/AI/OpenAI/OpenAIImagePrompt.cs b/BotNet.Commands/AI/OpenAI/OpenAIImagePrompt.cs new file mode 100644 index 0000000..b38f78d --- /dev/null +++ b/BotNet.Commands/AI/OpenAI/OpenAIImagePrompt.cs @@ -0,0 +1,61 @@ +using BotNet.Commands.BotUpdate.Message; + +namespace BotNet.Commands.AI.OpenAI { + public sealed record OpenAIImagePrompt : ICommand { + public string CallSign { get; } + public string Prompt { get; } + public string ImageFileId { get; } + public HumanMessageBase Command { get; } + public IEnumerable Thread { get; } + + private OpenAIImagePrompt( + string callSign, + string prompt, + string imageFileId, + HumanMessageBase command, + IEnumerable thread + ) { + CallSign = callSign; + Prompt = prompt; + ImageFileId = imageFileId; + Command = command; + Thread = thread; + } + + public static OpenAIImagePrompt FromAICallCommand(AICallCommand aiCallCommand, IEnumerable thread) { + // Call sign must be AI, Bot, or GPT + if (aiCallCommand.CallSign is not "AI" and not "Bot" and not "GPT") { + throw new ArgumentException("Call sign must be AI, Bot, or GPT.", nameof(aiCallCommand)); + } + + // Prompt must be non-empty + if (string.IsNullOrWhiteSpace(aiCallCommand.Text)) { + throw new ArgumentException("Prompt must be non-empty.", nameof(aiCallCommand)); + } + + // File ID must be non-empty + if (string.IsNullOrWhiteSpace(aiCallCommand.ImageFileId)) { + throw new ArgumentException("File ID must be non-empty.", nameof(aiCallCommand)); + } + + // Non-empty thread must begin with reply to message + if 
(thread.FirstOrDefault() is { + MessageId: { } firstMessageId, + Chat.Id: { } firstChatId + }) { + if (firstMessageId != aiCallCommand.ReplyToMessage?.MessageId + || firstChatId != aiCallCommand.Chat.Id) { + throw new ArgumentException("Thread must begin with reply to message.", nameof(thread)); + } + } + + return new( + callSign: aiCallCommand.CallSign, + prompt: aiCallCommand.Text, + imageFileId: aiCallCommand.ImageFileId, + command: aiCallCommand, + thread: thread + ); + } + } +}