Skip to content

Commit

Permalink
fix differentiating between global and guild commands
Browse files Browse the repository at this point in the history
  • Loading branch information
PsychicNoodles committed Oct 28, 2022
1 parent 64c44f5 commit 8cb9418
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 45 deletions.
10 changes: 6 additions & 4 deletions src/commands.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::str::FromStr;
use std::{collections::HashMap, hash::Hash, str::FromStr};

use async_trait::async_trait;
use serenity::{
builder::CreateApplicationCommand,
model::prelude::interaction::application_command::ApplicationCommandInteraction,
prelude::Context,
model::prelude::{interaction::application_command::ApplicationCommandInteraction, CommandId},
prelude::{Context, TypeMapKey},
};

use crate::{util::LocalizedString, Handler, HandlerError, MessageDbData};
Expand All @@ -30,7 +30,9 @@ trait AppCmd {
}

#[async_trait]
pub trait CommandsEnum: FromStr {
pub trait CommandsEnum:
FromStr + TypeMapKey<Value = HashMap<CommandId, Self>> + std::fmt::Debug + Copy + Eq + Hash
{
async fn handle(
self,
cmd: &ApplicationCommandInteraction,
Expand Down
12 changes: 8 additions & 4 deletions src/commands/global.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::str::FromStr;
use std::{collections::HashMap, str::FromStr};

use async_trait::async_trait;
use serenity::{
builder::CreateApplicationCommand,
model::prelude::interaction::application_command::ApplicationCommandInteraction,
prelude::Context,
model::prelude::{interaction::application_command::ApplicationCommandInteraction, CommandId},
prelude::{Context, TypeMapKey},
};
use strum::IntoEnumIterator;
use strum_macros::{AsRefStr, Display, EnumIter};
Expand All @@ -25,7 +25,7 @@ pub mod list_emotes;
pub mod stats;
pub mod user_settings;

#[derive(Debug, Clone, Copy, AsRefStr, Display, EnumIter)]
#[derive(Debug, Clone, Copy, AsRefStr, Display, EnumIter, PartialEq, Eq, Hash)]
pub enum GlobalCommands {
EmoteSelect,
UserSettings,
Expand Down Expand Up @@ -86,6 +86,10 @@ impl CommandsEnum for GlobalCommands {
}
}

impl TypeMapKey for GlobalCommands {
type Value = HashMap<CommandId, Self>;
}

#[derive(Debug, Clone, Error)]
#[error("Not a valid command: {0}")]
pub struct InvalidGlobalCommand(String);
Expand Down
12 changes: 8 additions & 4 deletions src/commands/guild.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
pub mod server_settings;
pub mod stats;

use std::str::FromStr;
use std::{collections::HashMap, str::FromStr};

use async_trait::async_trait;
use serenity::{
builder::CreateApplicationCommand,
model::prelude::interaction::application_command::ApplicationCommandInteraction,
prelude::Context,
model::prelude::{interaction::application_command::ApplicationCommandInteraction, CommandId},
prelude::{Context, TypeMapKey},
};
use strum::IntoEnumIterator;
use strum_macros::{AsRefStr, Display, EnumIter};
Expand All @@ -19,7 +19,7 @@ use self::{server_settings::ServerSettingsCmd, stats::GuildStatsCmd};

use super::{AppCmd, CommandsEnum};

#[derive(Debug, Clone, Copy, AsRefStr, Display, EnumIter)]
#[derive(Debug, Clone, Copy, AsRefStr, Display, EnumIter, PartialEq, Eq, Hash)]
pub enum GuildCommands {
ServerSettings,
Stats,
Expand Down Expand Up @@ -64,6 +64,10 @@ impl CommandsEnum for GuildCommands {
}
}

impl TypeMapKey for GuildCommands {
type Value = HashMap<CommandId, Self>;
}

#[derive(Debug, Clone, Error)]
#[error("Not a valid command: {0}")]
pub struct InvalidGuildCommand(String);
Expand Down
124 changes: 91 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ use db::{
models::{DbGuild, DbUser},
Db,
};
use futures::future::{try_join_all, TryFutureExt};
use futures::{
future::{try_join_all, TryFutureExt},
stream, StreamExt, TryStreamExt,
};
use sqlx::PgPool;
use std::{borrow::Cow, sync::Arc, time::Duration};
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::sync::OnceCell;
use tracing::*;
Expand Down Expand Up @@ -144,6 +147,10 @@ pub enum HandlerError {
EmoteLogCountNoParams,
#[error("Internal error, could not build response")]
CountNone,
#[error("Received command info for unknown command")]
CommandRegisterUnknown,
#[error("Internal error, could not build response")]
TypeMapNotFound,
}

impl HandlerError {
Expand Down Expand Up @@ -219,7 +226,7 @@ impl EventHandler for Handler {
cmd.guild_id.as_ref().map(ToString::to_string),
);

if let Err(err) | Ok(Err(err)) = self
if let Err(err) = self
.try_handle_commands::<GlobalCommands>(&context, &cmd, &message_db_data)
.or_else(|_| {
self.try_handle_commands::<GuildCommands>(&context, &cmd, &message_db_data)
Expand Down Expand Up @@ -247,50 +254,99 @@ impl EventHandler for Handler {

#[instrument(skip(self, context))]
async fn ready(&self, context: Context, ready: Ready) {
async fn save_command_ids<T>(
context: &Context,
commands: impl Iterator<Item = Command>,
) -> Result<(), HandlerError>
where
T: CommandsEnum,
{
let mut cmd_map = HashMap::new();
for cmd in commands {
let cmd_enum =
T::from_str(&cmd.name).map_err(|_| HandlerError::CommandRegisterUnknown)?;
if let Some(prev) = cmd_map.insert(cmd.id, cmd_enum) {
warn!("overwrote previous command with same id: {:?}", prev);
}
}
context.data.write().await.insert::<T>(cmd_map);
Ok(())
}

info!("{} is connected", ready.user.name);

info!(
"guilds: {:?}",
ready.guilds.iter().map(|ug| ug.id).collect::<Vec<_>>()
);
// global commands

match Command::set_global_application_commands(&context, |create| {
let global_commands = match Command::set_global_application_commands(&context, |create| {
create.set_application_commands(GlobalCommands::application_commands().collect());
create
})
.await
{
Ok(res) => {
info!(
"registered global commands: {:?}",
res.into_iter().map(|c| c.name).collect::<Vec<_>>()
);
}
Err(err) => {
error!("error registering global application commands: {:?}", err);
context.shard.shutdown_clean();
return;
}
}
Ok(commands) => commands,
};

match try_join_all(ready.guilds.iter().map(|g| {
g.id.set_application_commands(&context, |create| {
create.set_application_commands(GuildCommands::application_commands().collect());
create
})
}))
.await
info!(
"registered global commands: {:?}",
global_commands.iter().map(|c| &c.name).collect::<Vec<_>>()
);
if let Err(err) =
save_command_ids::<GlobalCommands>(&context, global_commands.into_iter()).await
{
Ok(res) => {
if let Some(sample) = res.first() {
info!(
"registered guild commands: {:?}",
sample.into_iter().map(|c| &c.name).collect::<Vec<_>>()
);
error!("error saving global application command data: {:?}", err);
context.shard.shutdown_clean();
return;
}

// guild commands

if !ready.guilds.is_empty() {
let guild_commands = match try_join_all(ready.guilds.iter().map(|g| {
g.id.set_application_commands(&context, |create| {
create
.set_application_commands(GuildCommands::application_commands().collect());
create
})
}))
.await
{
Err(err) => {
error!("error registering guild application commands: {:?}", err);
context.shard.shutdown_clean();
return;
}
Ok(commands) => commands,
};

if let Some(first) = guild_commands.first() {
info!(
"registered guild commands: {:?}",
first.iter().map(|c| &c.name).collect::<Vec<_>>()
);
} else {
error!("guilds list is not empty, but no guild commands were registered");
context.shard.shutdown_clean();
return;
}
Err(err) => {
error!("error registering guild application commands: {:?}", err);
if let Err(err) = stream::iter(guild_commands.into_iter())
.map(Ok)
.try_for_each(|cmds| async {
save_command_ids::<GuildCommands>(&context, cmds.into_iter()).await
})
.await
{
error!("error saving guild application command data: {:?}", err);
context.shard.shutdown_clean();
return;
}
}
}
Expand Down Expand Up @@ -423,16 +479,18 @@ impl Handler {
context: &Context,
cmd: &ApplicationCommandInteraction,
message_db_data: &MessageDbData<'a>,
) -> Result<Result<(), HandlerError>, HandlerError>
) -> Result<(), HandlerError>
where
T: CommandsEnum,
{
if let Ok(app_cmd) = T::from_str(cmd.data.name.as_str()) {
trace!("handing off to app command handler");
Ok(app_cmd.handle(cmd, self, context, message_db_data).await)
} else {
Err(HandlerError::UnrecognizedCommand(cmd.data.name.to_string()))
}
let read = context.data.read().await;
let app_cmd = read
.get::<T>()
.ok_or(HandlerError::TypeMapNotFound)?
.get(&cmd.data.id)
.ok_or(HandlerError::UnrecognizedCommand(cmd.data.name.to_string()))?;
trace!("handing off to app command handler: {:?}", app_cmd);
app_cmd.handle(cmd, self, context, message_db_data).await
}

#[instrument(skip(self))]
Expand Down

0 comments on commit 8cb9418

Please sign in to comment.