diff --git a/aria/bot.go b/aria/bot.go index 7a1022e..006c4c0 100644 --- a/aria/bot.go +++ b/aria/bot.go @@ -14,6 +14,8 @@ import ( const chanTimeout = 30 * time.Second +var errNotInVoice = errors.New("user not in voice") + type bot struct { sync.RWMutex *discordgo.Session @@ -21,6 +23,7 @@ type bot struct { token string prefix string keepMsg keepMsgMap + voice voiceState ariaRecv <-chan *packet ariaSend chan<- *request @@ -38,6 +41,7 @@ type bot struct { func newBot( config *config, + voice voiceState, cliToBot <-chan *packet, botToCli chan<- *request, stream <-chan []byte, @@ -58,6 +62,7 @@ func newBot( } b.keepMsg = config.keepMsg + b.voice = voice b.stream = stream b.ariaRecv = cliToBot b.ariaSend = botToCli @@ -227,15 +232,15 @@ func (b *bot) onDisconnect(_ *discordgo.Session, _ *discordgo.Disconnect) { b.cancel() } -// TODO: when ready, join all AutoJoin channels func (b *bot) onReady(s *discordgo.Session, r *discordgo.Ready) { b.Lock() + b.botUser = r.User defer b.Unlock() - b.botUser = r.User + b.recoverVoiceConnections() } -// parse message from discord, fire cmdHandlers +// onMessage parses message from discord and fire cmdHandlers func (b *bot) onMessage(s *discordgo.Session, m *discordgo.MessageCreate) { b.RLock() if m.Author.ID == b.botUser.ID { @@ -275,6 +280,41 @@ func (b *bot) onMessage(s *discordgo.Session, m *discordgo.MessageCreate) { // utilities +func (b *bot) recoverVoiceConnections() { + v := b.voice.cloneJoined() + for c, g := range v { + if err := b.joinVoice(g, c); err != nil { + log.Printf("failedc to recover voice: %v\n", err) + } + } +} + +func (b *bot) joinVoice(guildID, channelID string) error { + _, err := b.ChannelVoiceJoin(guildID, channelID, false, false) + if err != nil { + return err + } + b.voice.recordJoin(guildID, channelID) + return nil +} + +func (b *bot) disconnectVoice(guildID string) error { + b.Session.RLock() + v, ok := b.Session.VoiceConnections[guildID] + b.Session.RUnlock() + + if !ok { + return errNotInVoice + } + + if err := v.Disconnect(); err != nil { + return err + } + + b.voice.recordDisconnect(v.ChannelID) + return nil +} + func (b *bot) resolveCommand(raw string) (cmd string) { _, ok := b.cmdHandlers[raw] if ok { diff --git a/aria/commands.go b/aria/commands.go index 9b5b527..715a730 100644 --- a/aria/commands.go +++ b/aria/commands.go @@ -324,24 +324,19 @@ func (b *bot) cmdSummon(m *discordgo.Message, _ []string) { return } - _, err = b.ChannelVoiceJoin(g.ID, vid, false, false) - if err != nil { - log.Printf("failed to join voice: %v\n", err) + if err = b.joinVoice(m.GuildID, vid); err != nil { + log.Printf("failed to summon voice: %v\n", err) } } func (b *bot) cmdDisconnect(m *discordgo.Message, _ []string) { - b.Session.RLock() - v, ok := b.VoiceConnections[m.GuildID] - b.Session.RUnlock() - - if !ok { - sendErrorResponse(b, m.ChannelID, "Not in voice channel.") - return - } - - if err := v.Disconnect(); err != nil { - log.Printf("failed to disconnect voice: %v\n", err) + if err := b.disconnectVoice(m.GuildID); err != nil { + switch err { + case errNotInVoice: + sendErrorResponse(b, m.ChannelID, "Not in voice channel.") + default: + log.Printf("failed to disconnect voice: %v\n", err) + } } } diff --git a/aria/launcher.go b/aria/launcher.go index 5281725..e603274 100644 --- a/aria/launcher.go +++ b/aria/launcher.go @@ -30,6 +30,7 @@ func setupLogger() { type launcher struct { config *config + voice voiceState cliToBot chan *packet botToCli chan *request @@ -39,6 +40,7 @@ type launcher struct { func newLauncher(config *config) *launcher { return &launcher{ config: config, + voice: newVoiceState(), cliToBot: make(chan *packet), botToCli: make(chan *request), stream: make(chan []byte), @@ -91,7 +93,7 @@ func (l *launcher) launchBot(ctx context.Context, errChan chan<- error) { } } - b, err := newBot(l.config, l.cliToBot, l.botToCli, l.stream) + b, err := newBot(l.config, l.voice, l.cliToBot, l.botToCli, l.stream) if err != nil { errChan <- fmt.Errorf("failed to initialize bot: %w", err) return diff --git a/aria/voice.go b/aria/voice.go new file mode 100644 index 0000000..760899d --- /dev/null +++ b/aria/voice.go @@ -0,0 +1,49 @@ +package aria + +import ( + "sync" +) + +type voice struct { + sync.RWMutex + joined map[string]string // channelID -> guildID +} + +type voiceState interface { + // cloneJoined returns map contains channelID -> guildID + cloneJoined() map[string]string + recordJoin(guildID, channelID string) + recordDisconnect(channelID string) +} + +func newVoiceState() voiceState { + return &voice{ + joined: make(map[string]string), + } +} + +func (v *voice) cloneJoined() map[string]string { + v.RLock() + defer v.RUnlock() + + ret := make(map[string]string) + for k, v := range v.joined { + ret[k] = v + } + + return ret +} + +func (v *voice) recordJoin(guildID, channelID string) { + v.Lock() + defer v.Unlock() + + v.joined[channelID] = guildID +} + +func (v *voice) recordDisconnect(channelID string) { + v.Lock() + defer v.Unlock() + + delete(v.joined, channelID) +}