-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2a4779
commit 4fa3945
Showing
13 changed files
with
606 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
using System.Text; | ||
using BotNet.Commands.SQL; | ||
using BotNet.Services.SQL; | ||
using BotNet.Services.Sqlite; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using SqlParser.Ast; | ||
using Telegram.Bot; | ||
using Telegram.Bot.Types.Enums; | ||
|
||
namespace BotNet.CommandHandlers.SQL { | ||
public sealed class SQLCommandHandler( | ||
ITelegramBotClient telegramBotClient, | ||
IServiceProvider serviceProvider | ||
) : ICommandHandler<SQLCommand> { | ||
private readonly ITelegramBotClient _telegramBotClient = telegramBotClient; | ||
private readonly IServiceProvider _serviceProvider = serviceProvider; | ||
|
||
public async Task Handle(SQLCommand command, CancellationToken cancellationToken) { | ||
if (command.SelectStatement.Query.Body.AsSelectExpression().Select.From is not { } froms | ||
|| froms.Count == 0) { | ||
await _telegramBotClient.SendTextMessageAsync( | ||
chatId: command.Chat.Id, | ||
text: "No FROM clause found.", | ||
replyToMessageId: command.SQLMessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
return; | ||
} | ||
|
||
// Collect table names from query | ||
HashSet<string> tables = new(); | ||
foreach (TableWithJoins from in froms) { | ||
if (from.Relation != null) { | ||
CollectTableNames(ref tables, from.Relation); | ||
} | ||
|
||
if (from.Joins != null) { | ||
foreach (Join join in from.Joins) { | ||
if (join.Relation != null) { | ||
CollectTableNames(ref tables, join.Relation); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Create scoped for scoped database | ||
using IServiceScope serviceScope = _serviceProvider.CreateScope(); | ||
|
||
// Load tables into memory | ||
foreach (string table in tables) { | ||
IScopedDataSource? dataSource = serviceScope.ServiceProvider.GetKeyedService<IScopedDataSource>(table); | ||
if (dataSource == null) { | ||
await _telegramBotClient.SendTextMessageAsync( | ||
chatId: command.Chat.Id, | ||
text: $$""" | ||
Table '{{table}}' not found. Available tables are: | ||
- pilpres | ||
""", | ||
replyToMessageId: command.SQLMessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
return; | ||
} | ||
|
||
await dataSource.LoadTableAsync(cancellationToken); | ||
} | ||
|
||
// Execute query | ||
using ScopedDatabase scopedDatabase = serviceScope.ServiceProvider.GetRequiredService<ScopedDatabase>(); | ||
StringBuilder resultBuilder = new(); | ||
scopedDatabase.ExecuteReader( | ||
commandText: command.RawStatement, | ||
readAction: (reader) => { | ||
string[] values = new string[reader.FieldCount]; | ||
|
||
// Get column names | ||
for (int i = 0; i < reader.FieldCount; i++) { | ||
values[i] = '"' + reader.GetName(i).Replace("\"", "\"\"") + '"'; | ||
} | ||
resultBuilder.AppendLine(string.Join(',', values)); | ||
|
||
// Get rows | ||
while (reader.Read()) { | ||
for (int i = 0; i < reader.FieldCount; i++) { | ||
if (reader.IsDBNull(i)) { | ||
values[i] = ""; | ||
continue; | ||
} | ||
|
||
Type fieldType = reader.GetFieldType(i); | ||
if (fieldType == typeof(string)) { | ||
values[i] = '"' + reader.GetString(i).Replace("\"", "\"\"") + '"'; | ||
} else if (fieldType == typeof(int)) { | ||
values[i] = reader.GetInt32(i).ToString(); | ||
} else if (fieldType == typeof(long)) { | ||
values[i] = reader.GetInt64(i).ToString(); | ||
} else if (fieldType == typeof(float)) { | ||
values[i] = reader.GetFloat(i).ToString(); | ||
} else if (fieldType == typeof(double)) { | ||
values[i] = reader.GetDouble(i).ToString(); | ||
} else if (fieldType == typeof(decimal)) { | ||
values[i] = reader.GetDecimal(i).ToString(); | ||
} else if (fieldType == typeof(bool)) { | ||
values[i] = reader.GetBoolean(i).ToString(); | ||
} else if (fieldType == typeof(DateTime)) { | ||
values[i] = reader.GetDateTime(i).ToString(); | ||
} else if (fieldType == typeof(byte[])) { | ||
values[i] = BitConverter.ToString(reader.GetFieldValue<byte[]>(i)).Replace("-", ""); | ||
} else { | ||
values[i] = reader[i].ToString(); | ||
} | ||
} | ||
resultBuilder.AppendLine(string.Join(',', values)); | ||
} | ||
} | ||
); | ||
|
||
// Send result | ||
await _telegramBotClient.SendTextMessageAsync( | ||
chatId: command.Chat.Id, | ||
text: "```csv\n" + resultBuilder.ToString() + "```", | ||
parseMode: ParseMode.MarkdownV2, | ||
replyToMessageId: command.SQLMessageId, | ||
cancellationToken: cancellationToken | ||
); | ||
|
||
return; | ||
} | ||
|
||
private static void CollectTableNames(ref HashSet<string> tables, TableFactor tableFactor) { | ||
switch (tableFactor) { | ||
case TableFactor.Derived derived: | ||
if (derived.SubQuery.Body.AsSelectExpression().Select.From is { } derivedFroms) { | ||
foreach (TableWithJoins derivedFrom in derivedFroms) { | ||
if (derivedFrom.Relation != null) { | ||
CollectTableNames(ref tables, derivedFrom.Relation); | ||
} | ||
|
||
if (derivedFrom.Joins != null) { | ||
foreach (Join join in derivedFrom.Joins) { | ||
if (join.Relation != null) { | ||
CollectTableNames(ref tables, join.Relation); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
break; | ||
case TableFactor.Function function: | ||
break; | ||
case TableFactor.JsonTable jsonTable: | ||
break; | ||
case TableFactor.NestedJoin nestedJoin: | ||
if (nestedJoin.TableWithJoins != null) { | ||
if (nestedJoin.TableWithJoins.Relation != null) { | ||
CollectTableNames(ref tables, nestedJoin.TableWithJoins.Relation); | ||
} | ||
|
||
if (nestedJoin.TableWithJoins.Joins != null) { | ||
foreach (Join join in nestedJoin.TableWithJoins.Joins) { | ||
if (join.Relation != null) { | ||
CollectTableNames(ref tables, join.Relation); | ||
} | ||
} | ||
} | ||
} | ||
break; | ||
case TableFactor.Pivot pivot: | ||
CollectTableNames(ref tables, pivot.TableFactor); | ||
break; | ||
case TableFactor.Table table: | ||
tables.Add(table.Name.ToString()); | ||
break; | ||
case TableFactor.TableFunction tableFunction: | ||
break; | ||
case TableFactor.UnNest unNest: | ||
break; | ||
case TableFactor.Unpivot unpivot: | ||
tables.Add(unpivot.Name.ToString()); | ||
break; | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
using System.Diagnostics.CodeAnalysis; | ||
using BotNet.Commands.BotUpdate.Message; | ||
using BotNet.Commands.ChatAggregate; | ||
using BotNet.Commands.CommandPrioritization; | ||
using SqlParser; | ||
using SqlParser.Ast; | ||
|
||
namespace BotNet.Commands.SQL { | ||
public sealed record SQLCommand : ICommand { | ||
public string RawStatement { get; } | ||
public Statement.Select SelectStatement { get; } | ||
public MessageId SQLMessageId { get; } | ||
public ChatBase Chat { get; } | ||
|
||
private SQLCommand( | ||
string rawStatement, | ||
Statement.Select selectStatement, | ||
MessageId sqlMessageId, | ||
ChatBase chat | ||
) { | ||
RawStatement = rawStatement; | ||
SelectStatement = selectStatement; | ||
SQLMessageId = sqlMessageId; | ||
Chat = chat; | ||
} | ||
|
||
public static bool TryCreate( | ||
Telegram.Bot.Types.Message message, | ||
CommandPriorityCategorizer commandPriorityCategorizer, | ||
[NotNullWhen(true)] out SQLCommand? sqlCommand | ||
) { | ||
// Must start with select | ||
if (message.Text is not { } text || !text.StartsWith("select", StringComparison.OrdinalIgnoreCase)) { | ||
sqlCommand = null; | ||
return false; | ||
} | ||
|
||
// Chat must be private or group | ||
if (!ChatBase.TryCreate(message.Chat, commandPriorityCategorizer, out ChatBase? chat)) { | ||
sqlCommand = null; | ||
return false; | ||
} | ||
|
||
// Must be a valid SQL statement | ||
Sequence<Statement> ast; | ||
try { | ||
ast = new SqlParser.Parser().ParseSql(text); | ||
} catch { | ||
sqlCommand = null; | ||
return false; | ||
} | ||
|
||
// Can only contain one statement | ||
if (ast.Count != 1) { | ||
sqlCommand = null; | ||
return false; | ||
} | ||
|
||
// Must be a SELECT statement | ||
if (ast[0] is not Statement.Select selectStatement) { | ||
sqlCommand = null; | ||
return false; | ||
} | ||
|
||
sqlCommand = new( | ||
rawStatement: text, | ||
selectStatement: selectStatement, | ||
sqlMessageId: new(message.MessageId), | ||
chat: chat | ||
); | ||
return true; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.