diff --git a/process_update.go b/process_update.go index 6bdcc1f..b24eb2e 100644 --- a/process_update.go +++ b/process_update.go @@ -26,24 +26,22 @@ func (b *Bot) ProcessUpdate(ctx context.Context, upd *models.Update) { }() if upd.Message != nil { - h = b.findHandler(HandlerTypeMessageText, upd) + h = b.findHandler(upd) return } if upd.CallbackQuery != nil { - h = b.findHandler(HandlerTypeCallbackQueryData, upd) + h = b.findHandler(upd) return } } -func (b *Bot) findHandler(handlerType HandlerType, upd *models.Update) HandlerFunc { +func (b *Bot) findHandler(upd *models.Update) HandlerFunc { b.handlersMx.RLock() defer b.handlersMx.RUnlock() for _, h := range b.handlers { - if h.handlerType == handlerType { - if h.match(upd) { - return h.handler - } + if h.match(upd) { + return h.handler } } diff --git a/process_update_test.go b/process_update_test.go index 6f749cd..a5f4bfb 100644 --- a/process_update_test.go +++ b/process_update_test.go @@ -79,6 +79,35 @@ func TestProcessUpdate_WithMiddlewares(t *testing.T) { } } +func TestProcessUpdate_WithMatchTypeFunc(t *testing.T) { + var called string + h1 := func(ctx context.Context, bot *Bot, update *models.Update) { + called = "h1" + } + h2 := func(ctx context.Context, bot *Bot, update *models.Update) { + called = "h2" + } + m := func(update *models.Update) bool { + return update.CallbackQuery.GameShortName == "game" + } + + bot := &Bot{ + defaultHandlerFunc: h1, + handlersMx: &sync.RWMutex{}, + handlers: map[string]handler{}, + } + + bot.RegisterHandlerMatchFunc(m, h2) + + ctx := context.Background() + upd := &models.Update{ID: 42, CallbackQuery: &models.CallbackQuery{ID: "test", GameShortName: "game"}} + + bot.ProcessUpdate(ctx, upd) + if called != "h2" { + t.Fatalf("Expected h2 handler to be called but %s handler was called", called) + } +} + func Test_findHandler(t *testing.T) { var called bool h := func(ctx context.Context, bot *Bot, update *models.Update) { @@ -102,7 +131,7 @@ func Test_findHandler(t *testing.T) { ctx := context.Background() upd := &models.Update{Message: &models.Message{Text: "test"}} - handler := bot.findHandler(HandlerTypeMessageText, upd) + handler := bot.findHandler(upd) handler(ctx, bot, upd) if !called { @@ -125,7 +154,7 @@ func Test_findHandler_Default(t *testing.T) { ctx := context.Background() upd := &models.Update{Message: &models.Message{Text: "test"}} - handler := bot.findHandler(HandlerTypeCallbackQueryData, upd) + handler := bot.findHandler(upd) handler(ctx, bot, upd) if !called {