-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #101 from teknologi-umum/vision
Vision
- Loading branch information
Showing
4 changed files
with
327 additions
and
14 deletions.
There are no files selected for viewing
243 changes: 243 additions & 0 deletions
243
BotNet.CommandHandlers/AI/OpenAI/OpenAIImagePromptHandler.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
using BotNet.CommandHandlers.Art; | ||
using BotNet.Commands; | ||
using BotNet.Commands.AI.OpenAI; | ||
using BotNet.Commands.AI.Stability; | ||
using BotNet.Commands.BotUpdate.Message; | ||
using BotNet.Commands.ChatAggregate; | ||
using BotNet.Commands.SenderAggregate; | ||
using BotNet.Services.MarkdownV2; | ||
using BotNet.Services.OpenAI; | ||
using BotNet.Services.OpenAI.Models; | ||
using BotNet.Services.RateLimit; | ||
using Microsoft.Extensions.Logging; | ||
using SkiaSharp; | ||
using Telegram.Bot; | ||
using Telegram.Bot.Types; | ||
using Telegram.Bot.Types.Enums; | ||
|
||
namespace BotNet.CommandHandlers.AI.OpenAI { | ||
public sealed class OpenAIImagePromptHandler( | ||
ITelegramBotClient telegramBotClient, | ||
ICommandQueue commandQueue, | ||
ITelegramMessageCache telegramMessageCache, | ||
OpenAIClient openAIClient, | ||
ILogger<OpenAIImageGenerationPromptHandler> logger | ||
) : ICommandHandler<OpenAIImagePrompt> { | ||
internal static readonly RateLimiter VISION_RATE_LIMITER = RateLimiter.PerUserPerChat(1, TimeSpan.FromMinutes(15)); | ||
|
||
private readonly ITelegramBotClient _telegramBotClient = telegramBotClient; | ||
private readonly ICommandQueue _commandQueue = commandQueue; | ||
private readonly ITelegramMessageCache _telegramMessageCache = telegramMessageCache; | ||
private readonly OpenAIClient _openAIClient = openAIClient; | ||
private readonly ILogger<OpenAIImageGenerationPromptHandler> _logger = logger; | ||
|
||
public Task Handle(OpenAIImagePrompt imagePrompt, CancellationToken cancellationToken) { | ||
if (imagePrompt.Command.Sender is not VIPSender | ||
&& imagePrompt.Command.Chat is not HomeGroupChat) { | ||
return _telegramBotClient.SendTextMessageAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
text: MarkdownV2Sanitizer.Sanitize("Vision tidak bisa dipakai di sini."), | ||
parseMode: ParseMode.MarkdownV2, | ||
replyToMessageId: imagePrompt.Command.MessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
} | ||
|
||
try { | ||
VISION_RATE_LIMITER.ValidateActionRate( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
userId: imagePrompt.Command.Sender.Id | ||
); | ||
} catch (RateLimitExceededException exc) { | ||
return _telegramBotClient.SendTextMessageAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
text: $"<code>Anda terlalu banyak menggunakan vision. Coba lagi {exc.Cooldown}.</code>", | ||
parseMode: ParseMode.Html, | ||
replyToMessageId: imagePrompt.Command.MessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
} | ||
|
||
// Fire and forget | ||
Task.Run(async () => { | ||
(string? imageBase64, string? error) = await GetImageBase64Async( | ||
botClient: _telegramBotClient, | ||
fileId: imagePrompt.ImageFileId, | ||
cancellationToken: cancellationToken | ||
); | ||
|
||
if (error is not null) { | ||
await _telegramBotClient.SendTextMessageAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
text: $"<code>{error}</code>", | ||
parseMode: ParseMode.Html, | ||
replyToMessageId: imagePrompt.Command.MessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
return; | ||
} | ||
|
||
List<ChatMessage> messages = [ | ||
ChatMessage.FromText("system", "The following is a conversation with an AI assistant. The assistant is helpful, creative, direct, concise, and always get to the point. When user asks for an image to be generated, the AI assistant should respond with \"ImageGeneration:\" followed by comma separated list of features to be expected from the generated image.") | ||
]; | ||
|
||
messages.AddRange( | ||
from message in imagePrompt.Thread.Take(10).Reverse() | ||
select ChatMessage.FromText( | ||
role: message.Sender.ChatGPTRole, | ||
text: message.Text | ||
) | ||
); | ||
|
||
messages.Add( | ||
ChatMessage.FromTextWithImageBase64("user", imagePrompt.Prompt, imageBase64!) | ||
); | ||
|
||
Message responseMessage = await _telegramBotClient.SendTextMessageAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
text: MarkdownV2Sanitizer.Sanitize("… ⏳"), | ||
parseMode: ParseMode.MarkdownV2, | ||
replyToMessageId: imagePrompt.Command.MessageId | ||
); | ||
|
||
string response = await _openAIClient.ChatAsync( | ||
model: imagePrompt switch { | ||
({ Command: { Sender: VIPSender } or { Chat: HomeGroupChat } }) => "gpt-4-1106-preview", | ||
_ => "gpt-3.5-turbo" | ||
}, | ||
messages: messages, | ||
maxTokens: 512, | ||
cancellationToken: cancellationToken | ||
); | ||
|
||
// Handle image generation intent | ||
if (response.StartsWith("ImageGeneration:")) { | ||
if (imagePrompt.Command.Sender is not VIPSender) { | ||
try { | ||
ArtCommandHandler.IMAGE_GENERATION_RATE_LIMITER.ValidateActionRate(imagePrompt.Command.Chat.Id, imagePrompt.Command.Sender.Id); | ||
} catch (RateLimitExceededException exc) { | ||
await _telegramBotClient.SendTextMessageAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
text: $"Anda belum mendapat giliran. Coba lagi {exc.Cooldown}.", | ||
parseMode: ParseMode.Html, | ||
replyToMessageId: imagePrompt.Command.MessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
return; | ||
} | ||
} | ||
|
||
string imageGenerationPrompt = response.Substring(response.IndexOf(':') + 1).Trim(); | ||
switch (imagePrompt.Command) { | ||
case { Sender: VIPSender }: | ||
await _commandQueue.DispatchAsync( | ||
command: new OpenAIImageGenerationPrompt( | ||
callSign: imagePrompt.CallSign, | ||
prompt: imageGenerationPrompt, | ||
promptMessageId: imagePrompt.Command.MessageId, | ||
responseMessageId: new(responseMessage.MessageId), | ||
chat: imagePrompt.Command.Chat, | ||
sender: imagePrompt.Command.Sender | ||
) | ||
); | ||
break; | ||
case { Chat: HomeGroupChat }: | ||
await _commandQueue.DispatchAsync( | ||
command: new StabilityTextToImagePrompt( | ||
callSign: imagePrompt.CallSign, | ||
prompt: imageGenerationPrompt, | ||
promptMessageId: imagePrompt.Command.MessageId, | ||
responseMessageId: new(responseMessage.MessageId), | ||
chat: imagePrompt.Command.Chat, | ||
sender: imagePrompt.Command.Sender | ||
) | ||
); | ||
break; | ||
default: | ||
await _telegramBotClient.EditMessageTextAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
messageId: responseMessage.MessageId, | ||
text: MarkdownV2Sanitizer.Sanitize("Image generation tidak bisa dipakai di sini."), | ||
parseMode: ParseMode.MarkdownV2, | ||
cancellationToken: cancellationToken | ||
); | ||
break; | ||
} | ||
return; | ||
} | ||
|
||
// Finalize message | ||
try { | ||
responseMessage = await telegramBotClient.EditMessageTextAsync( | ||
chatId: imagePrompt.Command.Chat.Id, | ||
messageId: responseMessage.MessageId, | ||
text: MarkdownV2Sanitizer.Sanitize(response), | ||
parseMode: ParseMode.MarkdownV2, | ||
cancellationToken: cancellationToken | ||
); | ||
} catch (Exception exc) { | ||
_logger.LogError(exc, null); | ||
throw; | ||
} | ||
|
||
// Track thread | ||
_telegramMessageCache.Add( | ||
message: AIResponseMessage.FromMessage( | ||
message: responseMessage, | ||
replyToMessage: imagePrompt.Command, | ||
callSign: imagePrompt.CallSign | ||
) | ||
); | ||
}); | ||
|
||
return Task.CompletedTask; | ||
} | ||
|
||
private static async Task<(string? ImageBase64, string? Error)> GetImageBase64Async(ITelegramBotClient botClient, string fileId, CancellationToken cancellationToken) { | ||
// Download photo | ||
using MemoryStream originalImageStream = new(); | ||
await botClient.GetInfoAndDownloadFileAsync( | ||
fileId: fileId, | ||
destination: originalImageStream, | ||
cancellationToken: cancellationToken); | ||
byte[] originalImage = originalImageStream.ToArray(); | ||
|
||
// Limit input image to 300KB | ||
if (originalImage.Length > 300 * 1024) { | ||
return (null, "Image larger than 300KB"); | ||
} | ||
|
||
// Decode image | ||
originalImageStream.Position = 0; | ||
using SKCodec codec = SKCodec.Create(originalImageStream, out SKCodecResult codecResult); | ||
if (codecResult != SKCodecResult.Success) { | ||
return (null, "Invalid image"); | ||
} | ||
|
||
if (codec.EncodedFormat != SKEncodedImageFormat.Jpeg | ||
&& codec.EncodedFormat != SKEncodedImageFormat.Webp) { | ||
return (null, "Image must be compressed image"); | ||
} | ||
SKBitmap bitmap = SKBitmap.Decode(codec); | ||
|
||
// Limit input image to 1280x1280 | ||
if (bitmap.Width > 1280 || bitmap.Width > 1280) { | ||
return (null, "Image larger than 1280x1280"); | ||
} | ||
|
||
// Handle stickers | ||
if (codec.EncodedFormat == SKEncodedImageFormat.Webp) { | ||
SKImage image = SKImage.FromBitmap(bitmap); | ||
SKData data = image.Encode(SKEncodedImageFormat.Jpeg, 20); | ||
using MemoryStream jpegStream = new(); | ||
data.SaveTo(jpegStream); | ||
|
||
// Encode image as base64 | ||
return (Convert.ToBase64String(jpegStream.ToArray()), null); | ||
} | ||
|
||
// Encode image as base64 | ||
return (Convert.ToBase64String(originalImage), null); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
using BotNet.Commands.BotUpdate.Message; | ||
|
||
namespace BotNet.Commands.AI.OpenAI { | ||
public sealed record OpenAIImagePrompt : ICommand { | ||
public string CallSign { get; } | ||
public string Prompt { get; } | ||
public string ImageFileId { get; } | ||
public HumanMessageBase Command { get; } | ||
public IEnumerable<MessageBase> Thread { get; } | ||
|
||
private OpenAIImagePrompt( | ||
string callSign, | ||
string prompt, | ||
string imageFileId, | ||
HumanMessageBase command, | ||
IEnumerable<MessageBase> thread | ||
) { | ||
CallSign = callSign; | ||
Prompt = prompt; | ||
ImageFileId = imageFileId; | ||
Command = command; | ||
Thread = thread; | ||
} | ||
|
||
public static OpenAIImagePrompt FromAICallCommand(AICallCommand aiCallCommand, IEnumerable<MessageBase> thread) { | ||
// Call sign must be AI, Bot, or GPT | ||
if (aiCallCommand.CallSign is not "AI" and not "Bot" and not "GPT") { | ||
throw new ArgumentException("Call sign must be AI, Bot, or GPT.", nameof(aiCallCommand)); | ||
} | ||
|
||
// Prompt must be non-empty | ||
if (string.IsNullOrWhiteSpace(aiCallCommand.Text)) { | ||
throw new ArgumentException("Prompt must be non-empty.", nameof(aiCallCommand)); | ||
} | ||
|
||
// File ID must be non-empty | ||
if (string.IsNullOrWhiteSpace(aiCallCommand.ImageFileId)) { | ||
throw new ArgumentException("File ID must be non-empty.", nameof(aiCallCommand)); | ||
} | ||
|
||
// Non-empty thread must begin with reply to message | ||
if (thread.FirstOrDefault() is { | ||
MessageId: { } firstMessageId, | ||
Chat.Id: { } firstChatId | ||
}) { | ||
if (firstMessageId != aiCallCommand.ReplyToMessage?.MessageId | ||
|| firstChatId != aiCallCommand.Chat.Id) { | ||
throw new ArgumentException("Thread must begin with reply to message.", nameof(thread)); | ||
} | ||
} | ||
|
||
return new( | ||
callSign: aiCallCommand.CallSign, | ||
prompt: aiCallCommand.Text, | ||
imageFileId: aiCallCommand.ImageFileId, | ||
command: aiCallCommand, | ||
thread: thread | ||
); | ||
} | ||
} | ||
} |