Skip to content
This repository has been archived by the owner on Oct 28, 2020. It is now read-only.

Commit

Permalink
bot: fix voice reconnect
Browse files Browse the repository at this point in the history
fixes #10 - reconnect to discord voice properly
  • Loading branch information
sarisia committed May 16, 2020
1 parent bfc7d47 commit 8e7705d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 18 deletions.
46 changes: 43 additions & 3 deletions aria/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ import (

const chanTimeout = 30 * time.Second

var errNotInVoice = errors.New("user not in voice")

type bot struct {
sync.RWMutex
*discordgo.Session

token string
prefix string
keepMsg keepMsgMap
voice voiceState
ariaRecv <-chan *packet
ariaSend chan<- *request

Expand All @@ -38,6 +41,7 @@ type bot struct {

func newBot(
config *config,
voice voiceState,
cliToBot <-chan *packet,
botToCli chan<- *request,
stream <-chan []byte,
Expand All @@ -58,6 +62,7 @@ func newBot(
}
b.keepMsg = config.keepMsg

b.voice = voice
b.stream = stream
b.ariaRecv = cliToBot
b.ariaSend = botToCli
Expand Down Expand Up @@ -227,15 +232,15 @@ func (b *bot) onDisconnect(_ *discordgo.Session, _ *discordgo.Disconnect) {
b.cancel()
}

// TODO: when ready, join all AutoJoin channels
func (b *bot) onReady(s *discordgo.Session, r *discordgo.Ready) {
b.Lock()
b.botUser = r.User
defer b.Unlock()

b.botUser = r.User
b.recoverVoiceConnections()
}

// parse message from discord, fire cmdHandlers
// onMessage parses message from discord and fire cmdHandlers
func (b *bot) onMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
b.RLock()
if m.Author.ID == b.botUser.ID {
Expand Down Expand Up @@ -275,6 +280,41 @@ func (b *bot) onMessage(s *discordgo.Session, m *discordgo.MessageCreate) {

// utilities

func (b *bot) recoverVoiceConnections() {
v := b.voice.cloneJoined()
for c, g := range v {
if err := b.joinVoice(g, c); err != nil {
log.Printf("failedc to recover voice: %v\n", err)
}
}
}

func (b *bot) joinVoice(guildID, channelID string) error {
_, err := b.ChannelVoiceJoin(guildID, channelID, false, false)
if err != nil {
return err
}
b.voice.recordJoin(guildID, channelID)
return nil
}

func (b *bot) disconnectVoice(guildID string) error {
b.Session.RLock()
v, ok := b.Session.VoiceConnections[guildID]
b.Session.RUnlock()

if !ok {
return errNotInVoice
}

if err := v.Disconnect(); err != nil {
return err
}

b.voice.recordDisconnect(v.ChannelID)
return nil
}

func (b *bot) resolveCommand(raw string) (cmd string) {
_, ok := b.cmdHandlers[raw]
if ok {
Expand Down
23 changes: 9 additions & 14 deletions aria/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,24 +324,19 @@ func (b *bot) cmdSummon(m *discordgo.Message, _ []string) {
return
}

_, err = b.ChannelVoiceJoin(g.ID, vid, false, false)
if err != nil {
log.Printf("failed to join voice: %v\n", err)
if err = b.joinVoice(m.GuildID, vid); err != nil {
log.Printf("failed to summon voice: %v\n", err)
}
}

func (b *bot) cmdDisconnect(m *discordgo.Message, _ []string) {
b.Session.RLock()
v, ok := b.VoiceConnections[m.GuildID]
b.Session.RUnlock()

if !ok {
sendErrorResponse(b, m.ChannelID, "Not in voice channel.")
return
}

if err := v.Disconnect(); err != nil {
log.Printf("failed to disconnect voice: %v\n", err)
if err := b.disconnectVoice(m.GuildID); err != nil {
switch err {
case errNotInVoice:
sendErrorResponse(b, m.ChannelID, "Not in voice channel.")
default:
log.Printf("failed to disconnect voice: %v\n", err)
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion aria/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func setupLogger() {

type launcher struct {
config *config
voice voiceState

cliToBot chan *packet
botToCli chan *request
Expand All @@ -39,6 +40,7 @@ type launcher struct {
func newLauncher(config *config) *launcher {
return &launcher{
config: config,
voice: newVoiceState(),
cliToBot: make(chan *packet),
botToCli: make(chan *request),
stream: make(chan []byte),
Expand Down Expand Up @@ -91,7 +93,7 @@ func (l *launcher) launchBot(ctx context.Context, errChan chan<- error) {
}
}

b, err := newBot(l.config, l.cliToBot, l.botToCli, l.stream)
b, err := newBot(l.config, l.voice, l.cliToBot, l.botToCli, l.stream)
if err != nil {
errChan <- fmt.Errorf("failed to initialize bot: %w", err)
return
Expand Down
49 changes: 49 additions & 0 deletions aria/voice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package aria

import (
"sync"
)

type voice struct {
sync.RWMutex
joined map[string]string // channelID -> guildID
}

type voiceState interface {
// cloneJoined returns map contains channelID -> guildID
cloneJoined() map[string]string
recordJoin(guildID, channelID string)
recordDisconnect(channelID string)
}

func newVoiceState() voiceState {
return &voice{
joined: make(map[string]string),
}
}

func (v *voice) cloneJoined() map[string]string {
v.RLock()
defer v.RUnlock()

ret := make(map[string]string)
for k, v := range v.joined {
ret[k] = v
}

return ret
}

func (v *voice) recordJoin(guildID, channelID string) {
v.Lock()
defer v.Unlock()

v.joined[channelID] = guildID
}

func (v *voice) recordDisconnect(channelID string) {
v.Lock()
defer v.Unlock()

delete(v.joined, channelID)
}

0 comments on commit 8e7705d

Please sign in to comment.