Skip to content

Commit

Permalink
Merge pull request #71 from teknologi-umum/ai-overhaul
Browse files Browse the repository at this point in the history
Add vision bot
  • Loading branch information
ronnygunawan authored Dec 3, 2023
2 parents 7a081e8 + 81a81c1 commit ceae82d
Show file tree
Hide file tree
Showing 12 changed files with 305 additions and 75 deletions.
94 changes: 83 additions & 11 deletions BotNet.Services/BotCommands/OpenAI.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading;
Expand All @@ -10,6 +11,7 @@
using BotNet.Services.RateLimit;
using Microsoft.Extensions.DependencyInjection;
using RG.Ninja;
using SkiaSharp;
using Telegram.Bot;
using Telegram.Bot.Types;
using Telegram.Bot.Types.Enums;
Expand Down Expand Up @@ -446,7 +448,7 @@ await botClient.SendTextMessageAsync(
}

[Obsolete("Use StreamChatWithFriendlyBotAsync instead.", error: true)]
public static async Task<Message?> ChatWithFriendlyBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, ImmutableList<(string Sender, string Text)> thread, CancellationToken cancellationToken) {
public static async Task<Message?> ChatWithFriendlyBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, ImmutableList<(string Sender, string? Text, string? ImageBase64)> thread, CancellationToken cancellationToken) {
try {
(message.Chat.Type == ChatType.Private
? CHAT_PRIVATE_RATE_LIMITER
Expand Down Expand Up @@ -592,7 +594,7 @@ await botClient.SendTextMessageAsync(
return null;
}

public static async Task<Message?> ChatWithSarcasticBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, ImmutableList<(string Sender, string Text)> thread, string callSign, CancellationToken cancellationToken) {
public static async Task<Message?> ChatWithSarcasticBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, ImmutableList<(string Sender, string? Text, string? ImageBase64)> thread, string callSign, CancellationToken cancellationToken) {
try {
(message.Chat.Type == ChatType.Private
? CHAT_PRIVATE_RATE_LIMITER
Expand Down Expand Up @@ -761,18 +763,55 @@ await botClient.SendTextMessageAsync(
}
}

public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, string callSign, Message message, CancellationToken cancellationToken) {
public static async Task StreamChatWithFriendlyBotAsync(
ITelegramBotClient botClient,
IServiceProvider serviceProvider,
Message message,
CancellationToken cancellationToken
) {
try {
(message.Chat.Type == ChatType.Private
? CHAT_PRIVATE_RATE_LIMITER
: CHAT_GROUP_RATE_LIMITER
).ValidateActionRate(message.Chat.Id, message.From!.Id);
await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
message: message.Text!,
callSign: callSign,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);

PhotoSize? photoSize;
string? caption;
if (message is { Photo.Length: > 0, Caption: { } }) {
photoSize = message.Photo.OrderByDescending(photoSize => photoSize.Width).First();
caption = message.Caption;
} else if (message.ReplyToMessage is { Photo.Length: > 0, Caption: { } }) {
photoSize = message.ReplyToMessage.Photo.OrderByDescending(photoSize => photoSize.Width).First();
caption = message.ReplyToMessage.Caption;
} else {
photoSize = null;
caption = null;
}

if (photoSize != null && caption != null) {
(string? imageBase64, string? error) = await GetImageBase64Async(botClient, photoSize, cancellationToken);
if (error != null) {
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: $"<code>{error}</code>",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
return;
}
await serviceProvider.GetRequiredService<VisionBot>().StreamChatAsync(
message: caption,
imageBase64: imageBase64!,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);
} else {
await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
message: message.Text!,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
if (message.Chat.Type == ChatType.Private) {
await botClient.SendTextMessageAsync(
Expand Down Expand Up @@ -802,7 +841,7 @@ await botClient.SendTextMessageAsync(
}
}

public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, string callSign, Message message, ImmutableList<(string Sender, string Text)> thread, CancellationToken cancellationToken) {
public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, ImmutableList<(string Sender, string? Text, string? ImageBase64)> thread, CancellationToken cancellationToken) {
try {
(message.Chat.Type == ChatType.Private
? CHAT_PRIVATE_RATE_LIMITER
Expand All @@ -811,7 +850,6 @@ public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botCl
await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
message: message.Text!,
thread: thread,
callSign: callSign,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);
Expand Down Expand Up @@ -843,5 +881,39 @@ await botClient.SendTextMessageAsync(
cancellationToken: cancellationToken);
}
}

/// <summary>
/// Downloads a Telegram photo and returns its original bytes encoded as base64,
/// or an error message when the photo is rejected by one of the input limits.
/// </summary>
/// <param name="botClient">Telegram client used to download the file.</param>
/// <param name="photoSize">Photo variant whose <c>FileId</c> is downloaded.</param>
/// <param name="cancellationToken">Token forwarded to the download call.</param>
/// <returns>
/// <c>(base64, null)</c> when the photo is accepted; <c>(null, error)</c> when it is
/// larger than 200KB, cannot be decoded, is not a JPEG, or exceeds 1280 pixels in
/// either dimension.
/// </returns>
private static async Task<(string? ImageBase64, string? Error)> GetImageBase64Async(ITelegramBotClient botClient, PhotoSize photoSize, CancellationToken cancellationToken) {
    // Download photo
    using MemoryStream originalImageStream = new();
    await botClient.GetInfoAndDownloadFileAsync(
        fileId: photoSize.FileId,
        destination: originalImageStream,
        cancellationToken: cancellationToken);
    byte[] originalImage = originalImageStream.ToArray();

    // Limit input image to 200KB
    if (originalImage.Length > 200 * 1024) {
        return (null, "Image larger than 200KB");
    }

    // Decode image
    originalImageStream.Position = 0;
    using SKCodec codec = SKCodec.Create(originalImageStream, out SKCodecResult codecResult);
    if (codec is null || codecResult != SKCodecResult.Success) {
        return (null, "Invalid image");
    }
    if (codec.EncodedFormat != SKEncodedImageFormat.Jpeg) {
        return (null, "Image must be compressed image");
    }

    // The bitmap wraps native memory, so it is disposed when the method exits.
    // Decode can yield null when the pixel data cannot be read.
    using SKBitmap? bitmap = SKBitmap.Decode(codec);
    if (bitmap is null) {
        return (null, "Invalid image");
    }

    // Limit input image to 1280x1280 (both dimensions are checked)
    if (bitmap.Width > 1280 || bitmap.Height > 1280) {
        return (null, "Image larger than 1280x1280");
    }

    // Encode the original (not re-encoded) bytes as base64
    return (Convert.ToBase64String(originalImage), null);
}
}
}
4 changes: 2 additions & 2 deletions BotNet.Services/OpenAI/IntentDetector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public async Task<ChatIntent> DetectChatIntentAsync(
CancellationToken cancellationToken
) {
List<ChatMessage> messages = [
new("user", $$"""
ChatMessage.FromText("user", $$"""
These are available intents that one might query when they provide a text prompt:
Question,
Expand Down Expand Up @@ -49,7 +49,7 @@ public async Task<ImagePromptIntent> DetectImagePromptIntentAsync(
CancellationToken cancellationToken
) {
List<ChatMessage> messages = [
new("user", $$"""
ChatMessage.FromText("user", $$"""
These are available intents that one might query when they provide a prompt which contain an image:
Vision,
Expand Down
60 changes: 58 additions & 2 deletions BotNet.Services/OpenAI/Models/ChatMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
namespace BotNet.Services.OpenAI.Models {
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace BotNet.Services.OpenAI.Models {
public record ChatMessage(
string Role,
string Content
List<ChatContent> Content
) {
/// <summary>
/// Builds a chat message whose content is a single text part.
/// </summary>
/// <param name="role">Role assigned to the message.</param>
/// <param name="text">Text of the single content part; must not be null.</param>
/// <returns>A <see cref="ChatMessage"/> with one content part of type "text".</returns>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
public static ChatMessage FromText(string role, string text) {
    if (text is null) {
        throw new ArgumentNullException(nameof(text));
    }

    ChatContent textPart = new(Type: "text", Text: text, ImageUrl: null);
    return new ChatMessage(Role: role, Content: [textPart]);
}

/// <summary>
/// Builds a chat message with two content parts: a text part followed by an image part
/// whose URL is a <c>data:image/jpeg;base64,</c> URI wrapping the given base64 payload.
/// </summary>
/// <param name="role">Role assigned to the message.</param>
/// <param name="text">Text of the first content part; must not be null.</param>
/// <param name="imageBase64">Base64 image payload for the second content part; must not be null.</param>
/// <returns>A <see cref="ChatMessage"/> with a "text" part and an "image_url" part, in that order.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="text"/> or <paramref name="imageBase64"/> is null (text is checked first).
/// </exception>
public static ChatMessage FromTextWithImageBase64(string role, string text, string imageBase64) {
    if (text is null) {
        throw new ArgumentNullException(nameof(text));
    }
    if (imageBase64 is null) {
        throw new ArgumentNullException(nameof(imageBase64));
    }

    ChatContent textPart = new(Type: "text", Text: text, ImageUrl: null);
    ChatContent imagePart = new(
        Type: "image_url",
        Text: null,
        ImageUrl: new ImageUrl(Url: $"data:image/jpeg;base64,{imageBase64}")
    );
    return new ChatMessage(Role: role, Content: [textPart, imagePart]);
}

/// <summary>
/// Builds a chat message whose content is a single image part, with the URL set to a
/// <c>data:image/jpeg;base64,</c> URI wrapping the given base64 payload.
/// </summary>
/// <param name="role">Role assigned to the message.</param>
/// <param name="imageBase64">Base64 image payload; must not be null.</param>
/// <returns>A <see cref="ChatMessage"/> with one content part of type "image_url".</returns>
/// <exception cref="ArgumentNullException"><paramref name="imageBase64"/> is null.</exception>
public static ChatMessage FromImageBase64(string role, string imageBase64) {
    if (imageBase64 is null) {
        throw new ArgumentNullException(nameof(imageBase64));
    }

    ChatContent imagePart = new(
        Type: "image_url",
        Text: null,
        ImageUrl: new ImageUrl(Url: $"data:image/jpeg;base64,{imageBase64}")
    );
    return new ChatMessage(Role: role, Content: [imagePart]);
}
}

/// <summary>
/// One part of a <see cref="ChatMessage"/>'s content list: either a text part or an image part.
/// </summary>
/// <param name="Type">Kind of content part; the <see cref="ChatMessage"/> factory methods use "text" and "image_url".</param>
/// <param name="Text">Text of the part, or null for an image part. Omitted from serialized JSON when null.</param>
/// <param name="ImageUrl">Image reference of the part, or null for a text part. Omitted from serialized JSON when null.</param>
public record ChatContent(
string Type,
[property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] string? Text,
[property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] ImageUrl? ImageUrl
);

/// <summary>
/// Image reference carried by an "image_url" <see cref="ChatContent"/> part.
/// </summary>
/// <param name="Url">Location of the image; the <see cref="ChatMessage"/> factory methods pass a <c>data:image/jpeg;base64,</c> URI.</param>
public record ImageUrl(
string Url
);
}
9 changes: 7 additions & 2 deletions BotNet.Services/OpenAI/Models/Choice.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
public record Choice(
string? Text,
int? Index,
ChatMessage? Message,
ChatMessage? Delta,
ChoiceChatMessage? Message,
ChoiceChatMessage? Delta,
Logprobs? Logprobs,
string? FinishReason
);

/// <summary>
/// Chat message shape used for the <c>Message</c> and <c>Delta</c> members of <see cref="Choice"/>,
/// where the content is a plain string rather than a list of content parts.
/// </summary>
/// <param name="Role">Role of the message author.</param>
/// <param name="Content">Message content as a single string.</param>
public record ChoiceChatMessage(
string Role,
string Content
);
}
15 changes: 11 additions & 4 deletions BotNet.Services/OpenAI/OpenAIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,24 @@
using System.Threading.Tasks;
using BotNet.Services.Json;
using BotNet.Services.OpenAI.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using RG.Ninja;

namespace BotNet.Services.OpenAI {
public class OpenAIClient(
HttpClient httpClient,
IOptions<OpenAIOptions> openAIOptionsAccessor
IOptions<OpenAIOptions> openAIOptionsAccessor,
ILogger<OpenAIClient> logger
) {
private const string COMPLETION_URL_TEMPLATE = "https://api.openai.com/v1/engines/{0}/completions";
private const string CHAT_URL = "https://api.openai.com/v1/chat/completions";
private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() {
PropertyNamingPolicy = new SnakeCaseNamingPolicy()
};
private readonly HttpClient _httpClient = httpClient;
private readonly string _apiKey = openAIOptionsAccessor.Value.ApiKey!;
private readonly string _apiKey = openAIOptionsAccessor.Value.ApiKey!;
private readonly ILogger<OpenAIClient> _logger = logger;

public async Task<string> AutocompleteAsync(string engine, string prompt, string[]? stop, int maxTokens, double frequencyPenalty, double presencePenalty, double temperature, double topP, CancellationToken cancellationToken) {
using HttpRequestMessage request = new(HttpMethod.Post, string.Format(COMPLETION_URL_TEMPLATE, engine)) {
Expand Down Expand Up @@ -111,8 +114,12 @@ [EnumeratorCancellation] CancellationToken cancellationToken
options: JSON_SERIALIZER_OPTIONS
)
};
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) {
string errorMessage = await response.Content.ReadAsStringAsync(cancellationToken);
_logger.LogError(errorMessage);
response.EnsureSuccessStatusCode();
}

StringBuilder result = new();
using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken);
Expand Down
30 changes: 19 additions & 11 deletions BotNet.Services/OpenAI/OpenAIStreamingClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,16 @@ int replyToMessageId

// Task for continuously consuming the stream
Task downstreamTask = Task.Run(async () => {
await foreach ((string result, bool stop) in enumerable) {
lastResult = result;
try {
await foreach ((string result, bool stop) in enumerable) {
lastResult = result;

if (stop) {
break;
if (stop) {
break;
}
}
} catch (Exception exc) {
_logger.LogError(exc, null);
}
});

Expand All @@ -59,10 +63,12 @@ await Task.WhenAny(
);

// If downstream task is completed, send the last result
if (downstreamTask.IsCompleted) {
if (downstreamTask.IsCompletedSuccessfully) {
if (lastResult is null) return;

Message completeMessage = await telegramBotClient.SendTextMessageAsync(
chatId: chatId,
text: MarkdownV2Sanitizer.Sanitize(lastResult!),
text: MarkdownV2Sanitizer.Sanitize(lastResult),
parseMode: ParseMode.MarkdownV2,
replyToMessageId: replyToMessageId
);
Expand All @@ -72,7 +78,8 @@ await Task.WhenAny(
threadTracker.TrackMessage(
messageId: completeMessage.MessageId,
sender: callSign,
text: lastResult!,
text: lastResult,
imageBase64: null,
replyToMessageId: replyToMessageId
);

Expand All @@ -81,10 +88,10 @@ await Task.WhenAny(
}

// Otherwise, send incomplete result and continue streaming
string lastSent = lastResult!;
string lastSent = lastResult ?? "";
Message incompleteMessage = await telegramBotClient.SendTextMessageAsync(
chatId: chatId,
text: MarkdownV2Sanitizer.Sanitize(lastResult!) + "… ⏳", // ellipsis, nbsp, hourglass emoji
text: MarkdownV2Sanitizer.Sanitize(lastResult ?? "") + "… ⏳", // ellipsis, nbsp, hourglass emoji
parseMode: ParseMode.MarkdownV2,
replyToMessageId: replyToMessageId
);
Expand All @@ -104,7 +111,7 @@ await Task.WhenAny(
await telegramBotClient.EditMessageTextAsync(
chatId: chatId,
messageId: incompleteMessage.MessageId,
text: MarkdownV2Sanitizer.Sanitize(lastResult!) + "… ⏳", // ellipsis, nbsp, hourglass emoji
text: MarkdownV2Sanitizer.Sanitize(lastResult ?? "") + "… ⏳", // ellipsis, nbsp, hourglass emoji
parseMode: ParseMode.MarkdownV2,
cancellationToken: cts.Token
);
Expand All @@ -128,7 +135,7 @@ await telegramBotClient.EditMessageTextAsync(
await telegramBotClient.EditMessageTextAsync(
chatId: chatId,
messageId: incompleteMessage.MessageId,
text: MarkdownV2Sanitizer.Sanitize(lastResult!),
text: MarkdownV2Sanitizer.Sanitize(lastResult ?? ""),
parseMode: ParseMode.MarkdownV2,
cancellationToken: cts.Token
);
Expand All @@ -143,6 +150,7 @@ await telegramBotClient.EditMessageTextAsync(
messageId: incompleteMessage.MessageId,
sender: callSign,
text: lastResult!,
imageBase64: null,
replyToMessageId: replyToMessageId
);
} catch {
Expand Down
1 change: 1 addition & 0 deletions BotNet.Services/OpenAI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public static IServiceCollection AddOpenAIClient(this IServiceCollection service
services.AddTransient<AttachmentGenerator>();
services.AddTransient<TldrGenerator>();
services.AddTransient<IntentDetector>();
services.AddTransient<VisionBot>();
return services;
}
}
Expand Down
Loading

0 comments on commit ceae82d

Please sign in to comment.