Skip to content

Commit

Permalink
Add Stability image variation bot
Browse files Browse the repository at this point in the history
  • Loading branch information
ronnygunawan committed Dec 4, 2023
1 parent a5ad180 commit 8d58ab8
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
1 change: 1 addition & 0 deletions BotNet.Services/Stability/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public static class ServiceCollectionExtensions {
public static IServiceCollection AddStabilityClient(this IServiceCollection services) {
services.AddSingleton<StabilityClient>();
services.AddSingleton<ImageGenerationBot>();
services.AddSingleton<ImageVariationBot>();
return services;
}
}
Expand Down
23 changes: 23 additions & 0 deletions BotNet.Services/Stability/Skills/ImageVariationBot.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System.Threading;
using System.Threading.Tasks;

namespace BotNet.Services.Stability.Skills {
public sealed class ImageVariationBot(
StabilityClient stabilityClient
) {
private readonly StabilityClient _stabilityClient = stabilityClient;

public async Task<byte[]> ModifyImageAsync(
byte[] image,
string prompt,
CancellationToken cancellationToken
) {
return await _stabilityClient.ModifyImageAsync(
engine: "stable-diffusion-xl-1024-v1-0",
promptImage: image,
promptText: prompt,
cancellationToken: cancellationToken
);
}
}
}
66 changes: 66 additions & 0 deletions BotNet.Services/Stability/StabilityClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public sealed class StabilityClient(
ILogger<StabilityClient> logger
) {
private const string TEXT_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/text-to-image";
private const string IMAGE_TO_IMAGE_URL_TEMPLATE = "https://api.stability.ai/v1/generation/{0}/image-to-image";

private static readonly JsonSerializerOptions SNAKE_CASE_SERIALIZER_OPTIONS = new() {
PropertyNamingPolicy = new SnakeCaseNamingPolicy()
};
Expand Down Expand Up @@ -73,5 +75,69 @@ CancellationToken cancellationToken

return Convert.FromBase64String(base64);
}

public async Task<byte[]> ModifyImageAsync(
string engine,
byte[] promptImage,
string promptText,
CancellationToken cancellationToken
) {
string url = string.Format(IMAGE_TO_IMAGE_URL_TEMPLATE, engine);
using HttpRequestMessage request = new(HttpMethod.Post, url);
request.Headers.Add("Authorization", $"Bearer {_apiKey}");
request.Headers.Add("Accept", "application/json");
using MultipartFormDataContent formData = new();
using ByteArrayContent promptImageContent = new(promptImage);
formData.Add(
content: promptImageContent,
name: "init_image",
fileName: "init_image.png"
);
using StringContent initImageMode = new("IMAGE_STRENGTH");
using StringContent imageStrength = new("0.35");
using StringContent steps = new("40");
using StringContent width = new("1024");
using StringContent height = new("1024");
using StringContent seed = new("0");
using StringContent cfgScale = new("5");
using StringContent samples = new("1");
using StringContent textPrompts0Text = new(promptText);
using StringContent textPrompts0Weight = new("1");
using StringContent textPrompts1Text = new("blurry, bad, saturated, high contrast, watermark, signature, label, worst quality, normal quality, low quality, low res, extra digits, cropped, jpeg artifacts, username, error, duplicate, ugly, monochrome, mutation, disgusting, bad anatomy, bad hands, three hands, three legs, bad arms, missing legs, missing arms");
using StringContent textPrompts1Weight = new("-1");
formData.Add(initImageMode, "init_image_mode");
formData.Add(imageStrength, "image_strength");
formData.Add(steps, "steps");
formData.Add(width, "width");
formData.Add(height, "height");
formData.Add(seed, "seed");
formData.Add(cfgScale, "cfg_scale");
formData.Add(samples, "samples");
formData.Add(textPrompts0Text, "text_prompts[0][text]");
formData.Add(textPrompts0Weight, "text_prompts[0][weight]");
formData.Add(textPrompts1Text, "text_prompts[1][text]");
formData.Add(textPrompts1Weight, "text_prompts[1][weight]");
request.Content = formData;
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) {
string error = await response.Content.ReadAsStringAsync(cancellationToken);
_logger.LogError("Unable to generate image: {0}, HTTP Status Code: {1}", error, (int)response.StatusCode);
response.EnsureSuccessStatusCode();
}

string responseJson = await response.Content.ReadAsStringAsync(cancellationToken);

TextToImageResponse? responseData = JsonSerializer.Deserialize<TextToImageResponse>(responseJson, CAMEL_CASE_SERIALIZER_OPTIONS);

if (responseData is { Artifacts: [Artifact { FinishReason: "CONTENT_FILTERED" }] }) {
throw new ContentFilteredException();
}

if (responseData is not { Artifacts: [Artifact { FinishReason: "SUCCESS", Base64: var base64 }] }) {
throw new HttpRequestException();
}

return Convert.FromBase64String(base64);
}
}
}

0 comments on commit 8d58ab8

Please sign in to comment.