Skip to content

Commit

Permalink
Fix 'findHandler' to be correct for 'RegisterHandlerMatchFunc' with '…
Browse files Browse the repository at this point in the history
…CallbackQuery'
  • Loading branch information
oanhnn committed Aug 19, 2024
1 parent b3b73e9 commit ef3c539
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
12 changes: 5 additions & 7 deletions process_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,22 @@ func (b *Bot) ProcessUpdate(ctx context.Context, upd *models.Update) {
}()

if upd.Message != nil {
h = b.findHandler(HandlerTypeMessageText, upd)
h = b.findHandler(upd)
return
}
if upd.CallbackQuery != nil {
h = b.findHandler(HandlerTypeCallbackQueryData, upd)
h = b.findHandler(upd)
return
}
}

func (b *Bot) findHandler(handlerType HandlerType, upd *models.Update) HandlerFunc {
func (b *Bot) findHandler(upd *models.Update) HandlerFunc {
b.handlersMx.RLock()
defer b.handlersMx.RUnlock()

for _, h := range b.handlers {
if h.handlerType == handlerType {
if h.match(upd) {
return h.handler
}
if h.match(upd) {
return h.handler
}
}

Expand Down
33 changes: 31 additions & 2 deletions process_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,35 @@ func TestProcessUpdate_WithMiddlewares(t *testing.T) {
}
}

func TestProcessUpdate_WithMatchTypeFunc(t *testing.T) {
var called string
h1 := func(ctx context.Context, bot *Bot, update *models.Update) {
called = "h1"
}
h2 := func(ctx context.Context, bot *Bot, update *models.Update) {
called = "h2"
}
m := func(update *models.Update) bool {
return update.CallbackQuery.GameShortName == "game"
}

bot := &Bot{
defaultHandlerFunc: h1,
handlersMx: &sync.RWMutex{},
handlers: map[string]handler{},
}

bot.RegisterHandlerMatchFunc(m, h2)

ctx := context.Background()
upd := &models.Update{ID: 42, CallbackQuery: &models.CallbackQuery{ID: "test", GameShortName: "game"}}

bot.ProcessUpdate(ctx, upd)
if called != "h2" {
t.Fatalf("Expected h2 handler to be called but %s handler was called", called)
}
}

func Test_findHandler(t *testing.T) {
var called bool
h := func(ctx context.Context, bot *Bot, update *models.Update) {
Expand All @@ -102,7 +131,7 @@ func Test_findHandler(t *testing.T) {
ctx := context.Background()
upd := &models.Update{Message: &models.Message{Text: "test"}}

handler := bot.findHandler(HandlerTypeMessageText, upd)
handler := bot.findHandler(upd)
handler(ctx, bot, upd)

if !called {
Expand All @@ -125,7 +154,7 @@ func Test_findHandler_Default(t *testing.T) {
ctx := context.Background()
upd := &models.Update{Message: &models.Message{Text: "test"}}

handler := bot.findHandler(HandlerTypeCallbackQueryData, upd)
handler := bot.findHandler(upd)
handler(ctx, bot, upd)

if !called {
Expand Down

0 comments on commit ef3c539

Please sign in to comment.