Skip to content

Commit

Permalink
Merge pull request #105 from teknologi-umum/gemini
Browse files Browse the repository at this point in the history
Gemini
  • Loading branch information
ronnygunawan authored Jan 31, 2024
2 parents f8f95d4 + b6229b3 commit 09976d3
Show file tree
Hide file tree
Showing 19 changed files with 413 additions and 0 deletions.
132 changes: 132 additions & 0 deletions BotNet.CommandHandlers/AI/Gemini/GeminiTextPromptHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Commands.ChatAggregate;
using BotNet.Commands.CommandPrioritization;
using BotNet.Services.Gemini;
using BotNet.Services.Gemini.Models;
using BotNet.Services.MarkdownV2;
using BotNet.Services.RateLimit;
using Microsoft.Extensions.Logging;
using Telegram.Bot;
using Telegram.Bot.Types;
using Telegram.Bot.Types.Enums;

namespace BotNet.CommandHandlers.AI.Gemini {
public sealed class GeminiTextPromptHandler(
ITelegramBotClient telegramBotClient,
GeminiClient geminiClient,
ITelegramMessageCache telegramMessageCache,
CommandPriorityCategorizer commandPriorityCategorizer,
ILogger<GeminiTextPromptHandler> logger
) : ICommandHandler<GeminiTextPrompt> {
internal static readonly RateLimiter CHAT_RATE_LIMITER = RateLimiter.PerChat(60, TimeSpan.FromMinutes(1));

private readonly ITelegramBotClient _telegramBotClient = telegramBotClient;
private readonly GeminiClient _geminiClient = geminiClient;
private readonly ITelegramMessageCache _telegramMessageCache = telegramMessageCache;
private readonly CommandPriorityCategorizer _commandPriorityCategorizer = commandPriorityCategorizer;
private readonly ILogger<GeminiTextPromptHandler> _logger = logger;

public Task Handle(GeminiTextPrompt textPrompt, CancellationToken cancellationToken) {
if (textPrompt.Command.Chat is not HomeGroupChat) {
return _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: MarkdownV2Sanitizer.Sanitize("Gemini tidak bisa dipakai di sini."),
parseMode: ParseMode.MarkdownV2,
replyToMessageId: textPrompt.Command.MessageId,
cancellationToken: cancellationToken
);
}

try {
CHAT_RATE_LIMITER.ValidateActionRate(
chatId: textPrompt.Command.Chat.Id,
userId: textPrompt.Command.Sender.Id
);
} catch (RateLimitExceededException exc) {
return _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: $"<code>Anda terlalu banyak memanggil AI. Coba lagi {exc.Cooldown}.</code>",
parseMode: ParseMode.Html,
replyToMessageId: textPrompt.Command.MessageId,
cancellationToken: cancellationToken
);
}

// Fire and forget
Task.Run(async () => {
List<Content> messages = [];

// Merge adjacent messages from same role
foreach (MessageBase message in textPrompt.Thread.Reverse()) {
Content content = Content.FromText(
role: message.Sender.GeminiRole,
text: message.Text
);

if (messages.Count > 0
&& messages[^1].Role == message.Sender.GeminiRole) {
messages[^1].Add(content);
} else {
messages.Add(content);
}
}

// Trim thread longer than 10 messages
while (messages.Count > 10) {
messages.RemoveAt(0);
}

// Thread must start with user message
while (messages.Count > 0
&& messages[0].Role != "user") {
messages.RemoveAt(0);
}

messages.Add(
Content.FromText("user", textPrompt.Prompt)
);

Message responseMessage = await _telegramBotClient.SendTextMessageAsync(
chatId: textPrompt.Command.Chat.Id,
text: MarkdownV2Sanitizer.Sanitize("… ⏳"),
parseMode: ParseMode.MarkdownV2,
replyToMessageId: textPrompt.Command.MessageId
);

string response = await _geminiClient.ChatAsync(
messages: messages,
maxTokens: 512,
cancellationToken: cancellationToken
);

// Finalize message
try {
responseMessage = await telegramBotClient.EditMessageTextAsync(
chatId: textPrompt.Command.Chat.Id,
messageId: responseMessage.MessageId,
text: MarkdownV2Sanitizer.Sanitize(response),
parseMode: ParseMode.MarkdownV2,
cancellationToken: cancellationToken
);
} catch (Exception exc) {
_logger.LogError(exc, null);
throw;
}

// Track thread
_telegramMessageCache.Add(
message: AIResponseMessage.FromMessage(
message: responseMessage,
replyToMessage: textPrompt.Command,
callSign: "Gemini",
commandPriorityCategorizer: _commandPriorityCategorizer
)
);
});

return Task.CompletedTask;
}
}
}
12 changes: 12 additions & 0 deletions BotNet.CommandHandlers/BotUpdate/Message/AICallCommandHandler.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.AI.OpenAI;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Services.OpenAI;
Expand Down Expand Up @@ -37,6 +38,17 @@ await _commandQueue.DispatchAsync(
);
break;
}
case "Gemini" when command.ImageFileId is null && command.ReplyToMessage?.ImageFileId is null: {
await _commandQueue.DispatchAsync(
command: GeminiTextPrompt.FromAICallCommand(
aiCallCommand: command,
thread: command.ReplyToMessage is { } replyToMessage
? _telegramMessageCache.GetThread(replyToMessage)
: Enumerable.Empty<MessageBase>()
)
);
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotNet.Commands;
using BotNet.Commands.AI.Gemini;
using BotNet.Commands.AI.OpenAI;
using BotNet.Commands.BotUpdate.Message;

Expand All @@ -25,6 +26,18 @@ await _commandQueue.DispatchAsync(
)
);
break;
case "Gemini":
await _commandQueue.DispatchAsync(
command: GeminiTextPrompt.FromAIFollowUpMessage(
aIFollowUpMessage: command,
thread: command.ReplyToMessage is null
? Enumerable.Empty<MessageBase>()
: _telegramMessageCache.GetThread(
firstMessage: command.ReplyToMessage
)
)
);
break;
}
}
}
Expand Down
77 changes: 77 additions & 0 deletions BotNet.Commands/AI/Gemini/GeminiTextPrompt.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using BotNet.Commands.BotUpdate.Message;

namespace BotNet.Commands.AI.Gemini {
public sealed record GeminiTextPrompt : ICommand {
public string Prompt { get; }
public HumanMessageBase Command { get; }
public IEnumerable<MessageBase> Thread { get; }

private GeminiTextPrompt(
string prompt,
HumanMessageBase command,
IEnumerable<MessageBase> thread
) {
Prompt = prompt;
Command = command;
Thread = thread;
}

public static GeminiTextPrompt FromAICallCommand(AICallCommand aiCallCommand, IEnumerable<MessageBase> thread) {
// Call sign must be Gemini
if (aiCallCommand.CallSign != "Gemini") {
throw new ArgumentException("Call sign must be Gemini", nameof(aiCallCommand));
}

// Prompt must be non-empty
if (string.IsNullOrWhiteSpace(aiCallCommand.Text)) {
throw new ArgumentException("Prompt must be non-empty", nameof(aiCallCommand));
}

// Non-empty thread must begin with reply to message
if (thread.FirstOrDefault() is {
MessageId: { } firstMessageId,
Chat.Id: { } firstChatId
}) {
if (firstMessageId != aiCallCommand.ReplyToMessage?.MessageId
|| firstChatId != aiCallCommand.Chat.Id) {
throw new ArgumentException("Thread must begin with reply to message", nameof(thread));
}
}

return new(
prompt: aiCallCommand.Text,
command: aiCallCommand,
thread: thread
);
}

public static GeminiTextPrompt FromAIFollowUpMessage(AIFollowUpMessage aIFollowUpMessage, IEnumerable<MessageBase> thread) {
// Call sign must be Gemini
if (aIFollowUpMessage.CallSign != "Gemini") {
throw new ArgumentException("Call sign must be Gemini", nameof(aIFollowUpMessage));
}

// Prompt must be non-empty
if (string.IsNullOrWhiteSpace(aIFollowUpMessage.Text)) {
throw new ArgumentException("Prompt must be non-empty", nameof(aIFollowUpMessage));
}

// Non-empty thread must begin with reply to message
if (thread.FirstOrDefault() is {
MessageId: { } firstMessageId,
Chat.Id: { } firstChatId
}) {
if (firstMessageId != aIFollowUpMessage.ReplyToMessage?.MessageId
|| firstChatId != aIFollowUpMessage.Chat.Id) {
throw new ArgumentException("Thread must begin with reply to message", nameof(thread));
}
}

return new(
prompt: aIFollowUpMessage.Text,
command: aIFollowUpMessage,
thread: thread
);
}
}
}
3 changes: 3 additions & 0 deletions BotNet.Commands/SenderAggregate/Sender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ public abstract record SenderBase(
string Name
) {
public abstract string ChatGPTRole { get; }
public abstract string GeminiRole { get; }
}

public record HumanSender(
SenderId Id,
string Name
) : SenderBase(Id, Name) {
public override string ChatGPTRole => "user";
public override string GeminiRole => "user";

public static bool TryCreate(
Telegram.Bot.Types.User user,
Expand Down Expand Up @@ -51,6 +53,7 @@ public sealed record BotSender(
string Name
) : SenderBase(Id, Name) {
public override string ChatGPTRole => "assistant";
public override string GeminiRole => "model";

public static bool TryCreate(
Telegram.Bot.Types.User user,
Expand Down
54 changes: 54 additions & 0 deletions BotNet.Services/Gemini/GeminiClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using BotNet.Services.Gemini.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace BotNet.Services.Gemini {
public class GeminiClient(
HttpClient httpClient,
IOptions<GeminiOptions> geminiOptionsAccessor,
ILogger<GeminiClient> logger
) {
private const string BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent";
private readonly HttpClient _httpClient = httpClient;
private readonly string _apiKey = geminiOptionsAccessor.Value.ApiKey!;
private readonly ILogger<GeminiClient> _logger = logger;

public async Task<string> ChatAsync(IEnumerable<Content> messages, int maxTokens, CancellationToken cancellationToken) {
GeminiRequest geminiRequest = new(
Contents: messages.ToImmutableList(),
SafetySettings: null,
GenerationConfig: new(
MaxOutputTokens: maxTokens
)
);
using HttpRequestMessage request = new(HttpMethod.Post, BASE_URL + $"?key={_apiKey}") {
Headers = {
{ "Accept", "application/json" }
},
Content = JsonContent.Create(
inputValue: geminiRequest
)
};
using HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
string responseContent = await response.Content.ReadAsStringAsync(cancellationToken);
response.EnsureSuccessStatusCode();

GeminiResponse? geminiResponse = JsonSerializer.Deserialize<GeminiResponse>(responseContent);
if (geminiResponse == null) return "";
if (geminiResponse.Candidates == null) return "";
if (geminiResponse.Candidates.Count == 0) return "";
Content? content = geminiResponse.Candidates[0].Content;
if (content == null) return "";
if (content.Parts == null) return "";
if (content.Parts.Count == 0) return "";
return content.Parts[0].Text ?? "";
}
}
}
5 changes: 5 additions & 0 deletions BotNet.Services/Gemini/GeminiOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace BotNet.Services.Gemini {
public class GeminiOptions {
public string? ApiKey { get; set; }
}
}
11 changes: 11 additions & 0 deletions BotNet.Services/Gemini/Models/Candidate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Collections.Immutable;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public sealed record Candidate(
[property: JsonPropertyName("content")] Content? Content,
[property: JsonPropertyName("finishReason")] string? FinishReason,
[property: JsonPropertyName("index")] int? Index,
[property: JsonPropertyName("safetyRatings")] ImmutableList<SafetyRating>? SafetyRatings
);
}
22 changes: 22 additions & 0 deletions BotNet.Services/Gemini/Models/Content.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public record Content(
[property: JsonPropertyName("role"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] string? Role,
[property: JsonPropertyName("parts")] List<Part>? Parts
) {
public static Content FromText(string role, string text) => new(
Role: role,
Parts: [
new(Text: text)
]
);

public void Add(Content content) {
if (content.Role != Role) throw new InvalidOperationException();
Parts!.AddRange(content.Parts!);
}
}
}
10 changes: 10 additions & 0 deletions BotNet.Services/Gemini/Models/GeminiRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.Collections.Immutable;
using System.Text.Json.Serialization;

namespace BotNet.Services.Gemini.Models {
public sealed record GeminiRequest(
[property: JsonPropertyName("contents")] ImmutableList<Content> Contents,
[property: JsonPropertyName("safetySettings"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] ImmutableList<SafetySettings>? SafetySettings,
[property: JsonPropertyName("generationConfig"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] GenerationConfig? GenerationConfig
);
}
Loading

0 comments on commit 09976d3

Please sign in to comment.