Skip to content

Commit

Permalink
Implement SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
ronnygunawan committed Feb 17, 2024
1 parent f2a4779 commit 4fa3945
Show file tree
Hide file tree
Showing 13 changed files with 606 additions and 0 deletions.
58 changes: 58 additions & 0 deletions BotNet.CommandHandlers/BotUpdate/Message/MessageUpdateHandler.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using BotNet.Commands;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Commands.CommandPrioritization;
using BotNet.Commands.SQL;
using BotNet.Services.BotProfile;
using BotNet.Services.SocialLink;
using RG.Ninja;
using SqlParser;
using SqlParser.Ast;
using Telegram.Bot;
using Telegram.Bot.Types.Enums;

Expand Down Expand Up @@ -157,6 +160,61 @@ out AIFollowUpMessage? aiFollowUpMessage
);

await _commandQueue.DispatchAsync(aiFollowUpMessage);
return;
}

// Handle SQL
if (update.Message is {
ReplyToMessage: null,
Text: { } text
} && text.StartsWith("select", StringComparison.OrdinalIgnoreCase)) {
try {
Sequence<Statement> ast = new SqlParser.Parser().ParseSql(text);
if (ast.Count > 1) {
// Fire and forget
Task _ = Task.Run(async () => {
try {
await _telegramBotClient.SendTextMessageAsync(
chatId: update.Message.Chat.Id,
text: $"Your SQL contains more than one statement.",
replyToMessageId: update.Message.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
// Terminate gracefully
}
});
return;
}
if (ast[0] is not Statement.Select selectStatement) {
// Fire and forget
Task _ = Task.Run(async () => {
try {
await _telegramBotClient.SendTextMessageAsync(
chatId: update.Message.Chat.Id,
text: $"Your SQL is not a SELECT statement.",
replyToMessageId: update.Message.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
// Terminate gracefully
}
});
return;
}
if (SQLCommand.TryCreate(
message: update.Message,
commandPriorityCategorizer: _commandPriorityCategorizer,
sqlCommand: out SQLCommand? sqlCommand
)) {
await _commandQueue.DispatchAsync(
command: sqlCommand
);
return;
}
} catch {
// Suppress
}
}
}
}
Expand Down
184 changes: 184 additions & 0 deletions BotNet.CommandHandlers/SQL/SQLCommandHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
using System.Text;
using BotNet.Commands.SQL;
using BotNet.Services.SQL;
using BotNet.Services.Sqlite;
using Microsoft.Extensions.DependencyInjection;
using SqlParser.Ast;
using Telegram.Bot;
using Telegram.Bot.Types.Enums;

namespace BotNet.CommandHandlers.SQL {
public sealed class SQLCommandHandler(
ITelegramBotClient telegramBotClient,
IServiceProvider serviceProvider
) : ICommandHandler<SQLCommand> {
private readonly ITelegramBotClient _telegramBotClient = telegramBotClient;
private readonly IServiceProvider _serviceProvider = serviceProvider;

public async Task Handle(SQLCommand command, CancellationToken cancellationToken) {
if (command.SelectStatement.Query.Body.AsSelectExpression().Select.From is not { } froms
|| froms.Count == 0) {
await _telegramBotClient.SendTextMessageAsync(
chatId: command.Chat.Id,
text: "No FROM clause found.",
replyToMessageId: command.SQLMessageId,
cancellationToken: cancellationToken
);
return;
}

// Collect table names from query
HashSet<string> tables = new();
foreach (TableWithJoins from in froms) {
if (from.Relation != null) {
CollectTableNames(ref tables, from.Relation);
}

if (from.Joins != null) {
foreach (Join join in from.Joins) {
if (join.Relation != null) {
CollectTableNames(ref tables, join.Relation);
}
}
}
}

// Create scoped for scoped database
using IServiceScope serviceScope = _serviceProvider.CreateScope();

// Load tables into memory
foreach (string table in tables) {
IScopedDataSource? dataSource = serviceScope.ServiceProvider.GetKeyedService<IScopedDataSource>(table);
if (dataSource == null) {
await _telegramBotClient.SendTextMessageAsync(
chatId: command.Chat.Id,
text: $$"""
Table '{{table}}' not found. Available tables are:
- pilpres
""",
replyToMessageId: command.SQLMessageId,
cancellationToken: cancellationToken
);
return;
}

await dataSource.LoadTableAsync(cancellationToken);
}

// Execute query
using ScopedDatabase scopedDatabase = serviceScope.ServiceProvider.GetRequiredService<ScopedDatabase>();
StringBuilder resultBuilder = new();
scopedDatabase.ExecuteReader(
commandText: command.RawStatement,
readAction: (reader) => {
string[] values = new string[reader.FieldCount];

// Get column names
for (int i = 0; i < reader.FieldCount; i++) {
values[i] = '"' + reader.GetName(i).Replace("\"", "\"\"") + '"';
}
resultBuilder.AppendLine(string.Join(',', values));

// Get rows
while (reader.Read()) {
for (int i = 0; i < reader.FieldCount; i++) {
if (reader.IsDBNull(i)) {
values[i] = "";
continue;
}

Type fieldType = reader.GetFieldType(i);
if (fieldType == typeof(string)) {
values[i] = '"' + reader.GetString(i).Replace("\"", "\"\"") + '"';
} else if (fieldType == typeof(int)) {
values[i] = reader.GetInt32(i).ToString();
} else if (fieldType == typeof(long)) {
values[i] = reader.GetInt64(i).ToString();
} else if (fieldType == typeof(float)) {
values[i] = reader.GetFloat(i).ToString();
} else if (fieldType == typeof(double)) {
values[i] = reader.GetDouble(i).ToString();
} else if (fieldType == typeof(decimal)) {
values[i] = reader.GetDecimal(i).ToString();
} else if (fieldType == typeof(bool)) {
values[i] = reader.GetBoolean(i).ToString();
} else if (fieldType == typeof(DateTime)) {
values[i] = reader.GetDateTime(i).ToString();
} else if (fieldType == typeof(byte[])) {
values[i] = BitConverter.ToString(reader.GetFieldValue<byte[]>(i)).Replace("-", "");
} else {
values[i] = reader[i].ToString();
}
}
resultBuilder.AppendLine(string.Join(',', values));
}
}
);

// Send result
await _telegramBotClient.SendTextMessageAsync(
chatId: command.Chat.Id,
text: "```csv\n" + resultBuilder.ToString() + "```",
parseMode: ParseMode.MarkdownV2,
replyToMessageId: command.SQLMessageId,
cancellationToken: cancellationToken
);

return;
}

private static void CollectTableNames(ref HashSet<string> tables, TableFactor tableFactor) {
switch (tableFactor) {
case TableFactor.Derived derived:
if (derived.SubQuery.Body.AsSelectExpression().Select.From is { } derivedFroms) {
foreach (TableWithJoins derivedFrom in derivedFroms) {
if (derivedFrom.Relation != null) {
CollectTableNames(ref tables, derivedFrom.Relation);
}

if (derivedFrom.Joins != null) {
foreach (Join join in derivedFrom.Joins) {
if (join.Relation != null) {
CollectTableNames(ref tables, join.Relation);
}
}
}
}
}
break;
case TableFactor.Function function:
break;
case TableFactor.JsonTable jsonTable:
break;
case TableFactor.NestedJoin nestedJoin:
if (nestedJoin.TableWithJoins != null) {
if (nestedJoin.TableWithJoins.Relation != null) {
CollectTableNames(ref tables, nestedJoin.TableWithJoins.Relation);
}

if (nestedJoin.TableWithJoins.Joins != null) {
foreach (Join join in nestedJoin.TableWithJoins.Joins) {
if (join.Relation != null) {
CollectTableNames(ref tables, join.Relation);
}
}
}
}
break;
case TableFactor.Pivot pivot:
CollectTableNames(ref tables, pivot.TableFactor);
break;
case TableFactor.Table table:
tables.Add(table.Name.ToString());
break;
case TableFactor.TableFunction tableFunction:
break;
case TableFactor.UnNest unNest:
break;
case TableFactor.Unpivot unpivot:
tables.Add(unpivot.Name.ToString());
break;
}
}
}
}
2 changes: 2 additions & 0 deletions BotNet.Commands/BotNet.Commands.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

<ItemGroup>
<PackageReference Include="MediatR" Version="12.2.0" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.2" />
<PackageReference Include="SqlParserCS" Version="0.2.2" />
<PackageReference Include="Telegram.Bot" Version="19.0.0" />
</ItemGroup>

Expand Down
74 changes: 74 additions & 0 deletions BotNet.Commands/SQL/SQLCommand.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using System.Diagnostics.CodeAnalysis;
using BotNet.Commands.BotUpdate.Message;
using BotNet.Commands.ChatAggregate;
using BotNet.Commands.CommandPrioritization;
using SqlParser;
using SqlParser.Ast;

namespace BotNet.Commands.SQL {
public sealed record SQLCommand : ICommand {
public string RawStatement { get; }
public Statement.Select SelectStatement { get; }
public MessageId SQLMessageId { get; }
public ChatBase Chat { get; }

private SQLCommand(
string rawStatement,
Statement.Select selectStatement,
MessageId sqlMessageId,
ChatBase chat
) {
RawStatement = rawStatement;
SelectStatement = selectStatement;
SQLMessageId = sqlMessageId;
Chat = chat;
}

public static bool TryCreate(
Telegram.Bot.Types.Message message,
CommandPriorityCategorizer commandPriorityCategorizer,
[NotNullWhen(true)] out SQLCommand? sqlCommand
) {
// Must start with select
if (message.Text is not { } text || !text.StartsWith("select", StringComparison.OrdinalIgnoreCase)) {
sqlCommand = null;
return false;
}

// Chat must be private or group
if (!ChatBase.TryCreate(message.Chat, commandPriorityCategorizer, out ChatBase? chat)) {
sqlCommand = null;
return false;
}

// Must be a valid SQL statement
Sequence<Statement> ast;
try {
ast = new SqlParser.Parser().ParseSql(text);
} catch {
sqlCommand = null;
return false;
}

// Can only contain one statement
if (ast.Count != 1) {
sqlCommand = null;
return false;
}

// Must be a SELECT statement
if (ast[0] is not Statement.Select selectStatement) {
sqlCommand = null;
return false;
}

sqlCommand = new(
rawStatement: text,
selectStatement: selectStatement,
sqlMessageId: new(message.MessageId),
chat: chat
);
return true;
}
}
}
1 change: 1 addition & 0 deletions BotNet.Services/BotNet.Services.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
<PackageReference Include="Grpc.Net.Client" Version="2.60.0" />
<PackageReference Include="Microsoft.ClearScript" Version="7.4.4" />
<PackageReference Include="Microsoft.ClearScript.V8.Native.linux-x64" Version="7.4.4" />
<PackageReference Include="Microsoft.Data.Sqlite.Core" Version="8.0.2" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.0" />
Expand Down
Loading

0 comments on commit 4fa3945

Please sign in to comment.