From 458043b38695e9494975bd6cc7728b7c1d90bb2d Mon Sep 17 00:00:00 2001 From: Ronny Gunawan <3048897+ronnygunawan@users.noreply.github.com> Date: Sun, 3 Dec 2023 18:51:56 +0700 Subject: [PATCH] Add image generation skill --- BotNet.Services/BotCommands/OpenAI.cs | 46 +++++++++++++++++-- .../OpenAI/Models/ImageGenerationResult.cs | 13 ++++++ BotNet.Services/OpenAI/OpenAIClient.cs | 28 ++++++++++- .../OpenAI/ServiceCollectionExtensions.cs | 1 + .../OpenAI/Skills/ImageGenerationBot.cs | 20 ++++++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 BotNet.Services/OpenAI/Models/ImageGenerationResult.cs create mode 100644 BotNet.Services/OpenAI/Skills/ImageGenerationBot.cs diff --git a/BotNet.Services/BotCommands/OpenAI.cs b/BotNet.Services/BotCommands/OpenAI.cs index bdad8b4..d07e1d0 100644 --- a/BotNet.Services/BotCommands/OpenAI.cs +++ b/BotNet.Services/BotCommands/OpenAI.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.IO; using System.Linq; @@ -7,6 +8,7 @@ using System.Threading.Tasks; using BotNet.Services.MarkdownV2; using BotNet.Services.OpenAI; +using BotNet.Services.OpenAI.Models; using BotNet.Services.OpenAI.Skills; using BotNet.Services.RateLimit; using Microsoft.Extensions.DependencyInjection; @@ -811,11 +813,49 @@ await serviceProvider.GetRequiredService().StreamChatAsync( replyToMessageId: message.MessageId ); } else { - await serviceProvider.GetRequiredService().StreamChatAsync( + IntentDetector intentDetector = serviceProvider.GetRequiredService(); + ChatIntent chatIntent = await intentDetector.DetectChatIntentAsync( message: message.Text!, - chatId: message.Chat.Id, - replyToMessageId: message.MessageId + cancellationToken: cancellationToken ); + + switch (chatIntent) { + case ChatIntent.Question: + await serviceProvider.GetRequiredService().StreamChatAsync( + message: message.Text!, + chatId: message.Chat.Id, + replyToMessageId: message.MessageId + ); + break; + case ChatIntent.ImageGeneration: + Message busyMessage = await botClient.SendTextMessageAsync( + chatId: message.Chat.Id, + text: "Generating image… ⏳", + parseMode: ParseMode.Markdown, + replyToMessageId: message.MessageId, + cancellationToken: cancellationToken + ); + Uri generatedImageUrl = await serviceProvider.GetRequiredService().GenerateImageAsync( + prompt: message.Text!, + cancellationToken: cancellationToken + ); + try { + await botClient.DeleteMessageAsync( + chatId: busyMessage.Chat.Id, + messageId: busyMessage.MessageId, + cancellationToken: cancellationToken + ); + } catch (OperationCanceledException) { + throw; + } + await botClient.SendPhotoAsync( + chatId: message.Chat.Id, + photo: new InputFileUrl(generatedImageUrl), + replyToMessageId: message.MessageId, + cancellationToken: cancellationToken + ); + break; + } } } catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) { if (message.Chat.Type == ChatType.Private) { diff --git a/BotNet.Services/OpenAI/Models/ImageGenerationResult.cs b/BotNet.Services/OpenAI/Models/ImageGenerationResult.cs new file mode 100644 index 0000000..8a88151 --- /dev/null +++ b/BotNet.Services/OpenAI/Models/ImageGenerationResult.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace BotNet.Services.OpenAI.Models { + public record ImageGenerationResult( + [property: JsonPropertyName("created")] int CreatedUnixTime, + List Data + ); + + public record GeneratedImage( + string Url + ); +} diff --git a/BotNet.Services/OpenAI/OpenAIClient.cs b/BotNet.Services/OpenAI/OpenAIClient.cs index 4bdb54e..07635db 100644 --- a/BotNet.Services/OpenAI/OpenAIClient.cs +++ b/BotNet.Services/OpenAI/OpenAIClient.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.IO; using System.Net.Http; using System.Net.Http.Json; @@ -20,7 +21,8 @@ public class OpenAIClient( ILogger logger ) { private const string COMPLETION_URL_TEMPLATE = "https://api.openai.com/v1/engines/{0}/completions"; - private const string CHAT_URL = "https://api.openai.com/v1/chat/completions"; + private const string CHAT_URL = "https://api.openai.com/v1/chat/completions"; + private const string IMAGE_GENERATION_URL = "https://api.openai.com/v1/images/generations"; private static readonly JsonSerializerOptions JSON_SERIALIZER_OPTIONS = new() { PropertyNamingPolicy = new SnakeCaseNamingPolicy() }; @@ -172,6 +174,28 @@ [EnumeratorCancellation] CancellationToken cancellationToken ); } } + } + + public async Task GenerateImageAsync(string model, string prompt, CancellationToken cancellationToken) { + using HttpRequestMessage request = new(HttpMethod.Post, IMAGE_GENERATION_URL) { + Headers = { + { "Authorization", $"Bearer {_apiKey}" } + }, + Content = JsonContent.Create( + inputValue: new { + Model = model, + Prompt = prompt, + N = 1, + Size = "1024x1024" + }, + options: JSON_SERIALIZER_OPTIONS + ) + }; + using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + + ImageGenerationResult? imageGenerationResult = await response.Content.ReadFromJsonAsync(JSON_SERIALIZER_OPTIONS, cancellationToken); + return new(imageGenerationResult!.Data[0].Url); } } } diff --git a/BotNet.Services/OpenAI/ServiceCollectionExtensions.cs b/BotNet.Services/OpenAI/ServiceCollectionExtensions.cs index b2cf318..45541a7 100644 --- a/BotNet.Services/OpenAI/ServiceCollectionExtensions.cs +++ b/BotNet.Services/OpenAI/ServiceCollectionExtensions.cs @@ -17,6 +17,7 @@ public static IServiceCollection AddOpenAIClient(this IServiceCollection service services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); return services; } } diff --git a/BotNet.Services/OpenAI/Skills/ImageGenerationBot.cs b/BotNet.Services/OpenAI/Skills/ImageGenerationBot.cs new file mode 100644 index 0000000..c4221de --- /dev/null +++ b/BotNet.Services/OpenAI/Skills/ImageGenerationBot.cs @@ -0,0 +1,20 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace BotNet.Services.OpenAI.Skills { + public class ImageGenerationBot( + OpenAIClient openAIClient + ) { + private readonly OpenAIClient _openAIClient = openAIClient; + + public Task GenerateImageAsync( + string prompt, + CancellationToken cancellationToken + ) => _openAIClient.GenerateImageAsync( + model: "dall-e-3", + prompt: prompt, + cancellationToken: cancellationToken + ); + } +}