From 8d58ab82b2558eec34a48ed41df77a83aaf95d0e Mon Sep 17 00:00:00 2001 From: Ronny Gunawan <3048897+ronnygunawan@users.noreply.github.com> Date: Tue, 5 Dec 2023 00:45:22 +0700 Subject: [PATCH] Add Stability image variation bot --- .../Stability/ServiceCollectionExtensions.cs | 1 + .../Stability/Skills/ImageVariationBot.cs | 23 +++++++ BotNet.Services/Stability/StabilityClient.cs | 66 +++++++++++++++++++ 3 files changed, 90 insertions(+) create mode 100644 BotNet.Services/Stability/Skills/ImageVariationBot.cs diff --git a/BotNet.Services/Stability/ServiceCollectionExtensions.cs b/BotNet.Services/Stability/ServiceCollectionExtensions.cs index 8c995c8..60f1223 100644 --- a/BotNet.Services/Stability/ServiceCollectionExtensions.cs +++ b/BotNet.Services/Stability/ServiceCollectionExtensions.cs @@ -6,6 +6,7 @@ public static class ServiceCollectionExtensions { public static IServiceCollection AddStabilityClient(this IServiceCollection services) { services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); return services; } } diff --git a/BotNet.Services/Stability/Skills/ImageVariationBot.cs b/BotNet.Services/Stability/Skills/ImageVariationBot.cs new file mode 100644 index 0000000..2ace9bb --- /dev/null +++ b/BotNet.Services/Stability/Skills/ImageVariationBot.cs @@ -0,0 +1,23 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace BotNet.Services.Stability.Skills { + public sealed class ImageVariationBot( + StabilityClient stabilityClient + ) { + private readonly StabilityClient _stabilityClient = stabilityClient; + + public async Task ModifyImageAsync( + byte[] image, + string prompt, + CancellationToken cancellationToken + ) { + return await _stabilityClient.ModifyImageAsync( + engine: "stable-diffusion-xl-1024-v1-0", + promptImage: image, + promptText: prompt, + cancellationToken: cancellationToken + ); + } + } +} diff --git a/BotNet.Services/Stability/StabilityClient.cs b/BotNet.Services/Stability/StabilityClient.cs index e1c6f57..a58a87c 100644 --- a/BotNet.Services/Stability/StabilityClient.cs +++ b/BotNet.Services/Stability/StabilityClient.cs @@ -16,6 +16,8 @@ public sealed class StabilityClient( ILogger logger ) { private const string TEXT_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/text-to-image"; + private const string IMAGE_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/image-to-image"; + private static readonly JsonSerializerOptions SNAKE_CASE_SERIALIZER_OPTIONS = new() { PropertyNamingPolicy = new SnakeCaseNamingPolicy() }; @@ -73,5 +75,69 @@ CancellationToken cancellationToken return Convert.FromBase64String(base64); } + + public async Task ModifyImageAsync( + string engine, + byte[] promptImage, + string promptText, + CancellationToken cancellationToken + ) { + string url = string.Format(IMAGE_TO_IMAGE_URL_TEMPLATE, engine); + using HttpRequestMessage request = new(HttpMethod.Post, url); + request.Headers.Add("Authorization", $"Bearer {_apiKey}"); + request.Headers.Add("Accept", "application/json"); + using MultipartFormDataContent formData = new(); + using ByteArrayContent promptImageContent = new(promptImage); + formData.Add( + content: promptImageContent, + name: "init_image", + fileName: "init_image.png" + ); + using StringContent initImageMode = new("IMAGE_STRENGTH"); + using StringContent imageStrength = new("0.35"); + using StringContent steps = new("40"); + using StringContent width = new("1024"); + using StringContent height = new("1024"); + using StringContent seed = new("0"); + using StringContent cfgScale = new("5"); + using StringContent samples = new("1"); + using StringContent textPrompts0Text = new(promptText); + using StringContent textPrompts0Weight = new("1"); + using StringContent textPrompts1Text = new("blurry, bad, saturated, high contrast, watermark, signature, label, worst quality, normal quality, low quality, low res, extra digits, cropped, jpeg artifacts, username, error, duplicate, ugly, monochrome, mutation, disgusting, bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms"); + using StringContent textPrompts1Weight = new("-1"); + formData.Add(initImageMode, "init_image_mode"); + formData.Add(imageStrength, "image_strength"); + formData.Add(steps, "steps"); + formData.Add(width, "width"); + formData.Add(height, "height"); + formData.Add(seed, "seed"); + formData.Add(cfgScale, "cfg_scale"); + formData.Add(samples, "samples"); + formData.Add(textPrompts0Text, "text_prompts[0][text]"); + formData.Add(textPrompts0Weight, "text_prompts[0][weight]"); + formData.Add(textPrompts1Text, "text_prompts[1][text]"); + formData.Add(textPrompts1Weight, "text_prompts[1][weight]"); + request.Content = formData; + using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + if (!response.IsSuccessStatusCode) { + string error = await response.Content.ReadAsStringAsync(cancellationToken); + _logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode); + response.EnsureSuccessStatusCode(); + } + + string responseJson = await response.Content.ReadAsStringAsync(cancellationToken); + + TextToImageResponse? responseData = JsonSerializer.Deserialize(responseJson, CAMEL_CASE_SERIALIZER_OPTIONS); + + if (responseData is { Artifacts: [Artifact { FinishReason: "CONTENT_FILTERED" }] }) { + throw new ContentFilteredException(); + } + + if (responseData is not { Artifacts: [Artifact { FinishReason: "SUCCESS", Base64: var base64 }] }) { + throw new HttpRequestException(); + } + + return Convert.FromBase64String(base64); + } } }