diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..05e765b --- /dev/null +++ b/errors_test.go @@ -0,0 +1,82 @@ +package bot + +import ( + "errors" + "testing" +) + +func TestTooManyRequestsError(t *testing.T) { + err := &TooManyRequestsError{ + Message: "rate limit exceeded", + RetryAfter: 30, + } + + expectedErrorMsg := "rate limit exceeded: retry_after 30" + if err.Error() != expectedErrorMsg { + t.Errorf("expected %s, got %s", expectedErrorMsg, err.Error()) + } + + if !IsTooManyRequestsError(err) { + t.Errorf("expected IsTooManyRequestsError to return true") + } + + var genericError error = err + if !IsTooManyRequestsError(genericError) { + t.Errorf("expected IsTooManyRequestsError to return true for generic error type") + } +} + +func TestMigrateError(t *testing.T) { + err := &MigrateError{ + Message: "chat migrated", + MigrateToChatID: 12345, + } + + expectedErrorMsg := "chat migrated: migrate_to_chat_id 12345" + if err.Error() != expectedErrorMsg { + t.Errorf("expected %s, got %s", expectedErrorMsg, err.Error()) + } + + if !IsMigrateError(err) { + t.Errorf("expected IsMigrateError to return true") + } + + var genericError error = err + if !IsMigrateError(genericError) { + t.Errorf("expected IsMigrateError to return true for generic error type") + } +} + +func TestStandardErrors(t *testing.T) { + tests := []struct { + err error + expected string + }{ + {ErrorForbidden, "forbidden"}, + {ErrorBadRequest, "bad request"}, + {ErrorUnauthorized, "unauthorized"}, + {ErrorTooManyRequests, "too many requests"}, + {ErrorNotFound, "not found"}, + {ErrorConflict, "conflict"}, + } + + for _, tt := range tests { + if tt.err.Error() != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, tt.err.Error()) + } + } +} + +func TestIsTooManyRequestsErrorFalse(t *testing.T) { + err := errors.New("some other error") + if IsTooManyRequestsError(err) { + t.Errorf("expected IsTooManyRequestsError to return false") + } +} + +func TestIsMigrateErrorFalse(t *testing.T) { + err := errors.New("some other error") + if IsMigrateError(err) { + t.Errorf("expected IsMigrateError to return false") + } +} diff --git a/process_update_test.go b/process_update_test.go new file mode 100644 index 0000000..6f749cd --- /dev/null +++ b/process_update_test.go @@ -0,0 +1,134 @@ +package bot + +import ( + "context" + "sync" + "testing" + + "github.com/go-telegram/bot/models" +) + +func Test_applyMiddlewares(t *testing.T) { + h := func(ctx context.Context, bot *Bot, update *models.Update) {} + + middleware1 := func(next HandlerFunc) HandlerFunc { + return func(ctx context.Context, bot *Bot, update *models.Update) { + next(ctx, bot, update) + } + } + + middleware2 := func(next HandlerFunc) HandlerFunc { + return func(ctx context.Context, bot *Bot, update *models.Update) { + next(ctx, bot, update) + } + } + + wrapped := applyMiddlewares(h, middleware1, middleware2) + if wrapped == nil { + t.Fatal("Expected wrapped handler to be non-nil") + } +} + +func TestProcessUpdate(t *testing.T) { + var called bool + h := func(ctx context.Context, bot *Bot, update *models.Update) { + called = true + } + + bot := &Bot{ + defaultHandlerFunc: h, + middlewares: []Middleware{}, + handlersMx: &sync.RWMutex{}, + handlers: map[string]handler{}, + } + + ctx := context.Background() + upd := &models.Update{Message: &models.Message{Text: "test"}} + + bot.ProcessUpdate(ctx, upd) + if !called { + t.Fatal("Expected default handler to be called") + } +} + +func TestProcessUpdate_WithMiddlewares(t *testing.T) { + var called bool + h := func(ctx context.Context, bot *Bot, update *models.Update) { + called = true + } + + middleware := func(next HandlerFunc) HandlerFunc { + return func(ctx context.Context, bot *Bot, update *models.Update) { + next(ctx, bot, update) + } + } + + bot := &Bot{ + defaultHandlerFunc: h, + middlewares: []Middleware{middleware}, + handlersMx: &sync.RWMutex{}, + handlers: map[string]handler{}, + } + + ctx := context.Background() + upd := &models.Update{Message: &models.Message{Text: "test"}} + + bot.ProcessUpdate(ctx, upd) + if !called { + t.Fatal("Expected default handler to be called") + } +} + +func Test_findHandler(t *testing.T) { + var called bool + h := func(ctx context.Context, bot *Bot, update *models.Update) { + called = true + } + + bot := &Bot{ + defaultHandlerFunc: h, + handlersMx: &sync.RWMutex{}, + handlers: map[string]handler{}, + } + + // Register a handler + bot.handlers["test"] = handler{ + handlerType: HandlerTypeMessageText, + matchType: MatchTypeExact, + pattern: "test", + handler: h, + } + + ctx := context.Background() + upd := &models.Update{Message: &models.Message{Text: "test"}} + + handler := bot.findHandler(HandlerTypeMessageText, upd) + handler(ctx, bot, upd) + + if !called { + t.Fatal("Expected registered handler to be called") + } +} + +func Test_findHandler_Default(t *testing.T) { + var called bool + h := func(ctx context.Context, bot *Bot, update *models.Update) { + called = true + } + + bot := &Bot{ + defaultHandlerFunc: h, + handlersMx: &sync.RWMutex{}, + handlers: map[string]handler{}, + } + + ctx := context.Background() + upd := &models.Update{Message: &models.Message{Text: "test"}} + + handler := bot.findHandler(HandlerTypeCallbackQueryData, upd) + handler(ctx, bot, upd) + + if !called { + t.Fatal("Expected default handler to be called") + } +} diff --git a/webhook_handler.go b/webhook_handler.go index 1734a3f..3febad1 100644 --- a/webhook_handler.go +++ b/webhook_handler.go @@ -32,7 +32,13 @@ func (b *Bot) WebhookHandler() http.HandlerFunc { case <-req.Context().Done(): b.error("some updates lost, ctx done") return + default: + } + + select { case b.updates <- update: + case <-req.Context().Done(): + b.error("failed to send update, ctx done") } } } diff --git a/webhook_handler_test.go b/webhook_handler_test.go new file mode 100644 index 0000000..c63d2ea --- /dev/null +++ b/webhook_handler_test.go @@ -0,0 +1,187 @@ +package bot + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-telegram/bot/models" +) + +type mockDebugHandler struct { + messages []string +} + +func (h *mockDebugHandler) Handle(format string, args ...interface{}) { + h.messages = append(h.messages, format) +} + +type mockErrorsHandler struct { + errors []error +} + +func (h *mockErrorsHandler) Handle(err error) { + h.errors = append(h.errors, err) +} + +type errReaderStruct struct { + err error +} + +func (e *errReaderStruct) Read(p []byte) (int, error) { + return 0, e.err +} + +func (e *errReaderStruct) Close() error { + return nil +} + +func errReader(err error) io.ReadCloser { + return &errReaderStruct{err: err} +} + +func TestWebhookHandler_Success(t *testing.T) { + debugHandler := &mockDebugHandler{} + errorsHandler := &mockErrorsHandler{} + + bot := &Bot{ + updates: make(chan *models.Update, 1), + isDebug: true, + debugHandler: func(format string, args ...interface{}) { + debugHandler.Handle(format, args...) + }, + errorsHandler: func(err error) { + errorsHandler.Handle(err) + }, + } + + update := &models.Update{ + ID: 12345, + } + updateBody, _ := json.Marshal(update) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(updateBody)) + w := httptest.NewRecorder() + + handler := bot.WebhookHandler() + handler(w, req) + + select { + case upd := <-bot.updates: + if upd.ID != update.ID { + t.Fatalf("Expected update ID %d, got %d", update.ID, upd.ID) + } + default: + t.Fatal("Expected update to be sent to bot.updates channel") + } +} + +func TestWebhookHandler_ReadBodyError(t *testing.T) { + debugHandler := &mockDebugHandler{} + errorsHandler := &mockErrorsHandler{} + + bot := &Bot{ + debugHandler: func(format string, args ...interface{}) { + debugHandler.Handle(format, args...) + }, + errorsHandler: func(err error) { + errorsHandler.Handle(err) + }, + } + + req := httptest.NewRequest(http.MethodPost, "/", errReader(errors.New("read error"))) + w := httptest.NewRecorder() + + handler := bot.WebhookHandler() + handler(w, req) + + if len(errorsHandler.errors) == 0 { + t.Fatal("Expected an error, but none occurred") + } + + if capturedError := errorsHandler.errors[0]; capturedError == nil || !containsString(capturedError.Error(), "read error") { + t.Fatalf("Expected read body error, got %v", capturedError) + } +} + +func TestWebhookHandler_DecodeError(t *testing.T) { + debugHandler := &mockDebugHandler{} + errorsHandler := &mockErrorsHandler{} + + bot := &Bot{ + debugHandler: func(format string, args ...interface{}) { + debugHandler.Handle(format, args...) + }, + errorsHandler: func(err error) { + errorsHandler.Handle(err) + }, + } + + invalidJSON := []byte("{invalid json}") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(invalidJSON)) + w := httptest.NewRecorder() + + handler := bot.WebhookHandler() + handler(w, req) + + if len(errorsHandler.errors) == 0 { + t.Fatal("Expected an error, but none occurred") + } + + if capturedError := errorsHandler.errors[0]; capturedError == nil || !containsString(capturedError.Error(), "error decode request body") { + t.Fatalf("Expected decode error, got %v", capturedError) + } +} + +func TestWebhookHandler_ContextDone(t *testing.T) { + debugHandler := &mockDebugHandler{} + errorsHandler := &mockErrorsHandler{} + + bot := &Bot{ + updates: make(chan *models.Update, 1), + debugHandler: func(format string, args ...interface{}) { + debugHandler.Handle(format, args...) + }, + errorsHandler: func(err error) { + errorsHandler.Handle(err) + }, + } + + update := &models.Update{ + ID: 12345, + } + updateBody, _ := json.Marshal(update) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(updateBody)).WithContext(ctx) + w := httptest.NewRecorder() + + handler := bot.WebhookHandler() + handler(w, req) + + select { + case <-bot.updates: + t.Fatal("Did not expect update to be sent to bot.updates channel") + default: + // Expected outcome, context was cancelled before sending the update + } + + if len(errorsHandler.errors) == 0 { + t.Fatal("Expected an error, but none occurred") + } + + if capturedError := errorsHandler.errors[0]; capturedError == nil || !containsString(capturedError.Error(), "some updates lost, ctx done") { + t.Fatalf("Expected context done error, got %v", capturedError) + } +} + +func containsString(s, substr string) bool { + return bytes.Contains([]byte(s), []byte(substr)) +}