diff --git a/go.mod b/go.mod index 1b2235c..cc1446f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/hashicorp/vault v1.16.3 github.com/hashicorp/vault/api v1.14.0 github.com/labstack/echo/v4 v4.12.0 + github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.24.0 ) @@ -241,7 +242,6 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/stretchr/testify v1.9.0 // indirect github.com/tencentcloud/tencentcloud-sdk-go v1.0.162 // indirect github.com/tklauser/go-sysconf v0.3.10 // indirect github.com/tklauser/numcpus v0.4.0 // indirect diff --git a/internal/handlers_test.go b/internal/handlers_test.go index 06e930b..ed6b3a5 100644 --- a/internal/handlers_test.go +++ b/internal/handlers_test.go @@ -4,10 +4,10 @@ import ( "errors" "net/http" "net/http/httptest" - "reflect" "testing" "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" ) type FakeSecretMsgStorer struct { @@ -37,23 +37,11 @@ func TestGetMsgHandlerSuccess(t *testing.T) { s := &FakeSecretMsgStorer{msg: "secret"} h := newSecretHandlers(s) err := h.GetMsgHandler(c) - if err != nil { - t.Fatalf("got error %v, none expected", err) - } - - if s.lastUsedToken != "secrettoken" { - t.Fatalf("Storer::Get was called with %s, expected %s", s.lastUsedToken, "secrettoken") - } - - if rec.Code != http.StatusOK { - t.Fatalf("got statusCode %d, expected %d", rec.Code, http.StatusOK) - } - expected := "{\"msg\":\"secret\"}\n" - actual := rec.Body.String() - if expected != actual { - t.Fatalf("got body %s, expected %s", expected, actual) - } + assert.NoError(t, err) + assert.Equal(t, "secrettoken", s.lastUsedToken) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "{\"msg\":\"secret\"}\n", rec.Body.String()) } func TestGetMsgHandlerError(t *testing.T) { @@ -65,17 +53,11 @@ func TestGetMsgHandlerError(t *testing.T) { s := &FakeSecretMsgStorer{msg: "secret", err: errors.New("expired")} h := newSecretHandlers(s) err := h.GetMsgHandler(c) - if err == nil { - t.Fatalf("got no error, expected one") - } - v, ok := err.(*echo.HTTPError) - if !ok { - t.Fatalf("expected an HTTPError, got %s", reflect.TypeOf(v)) - } - - if v.Code != http.StatusInternalServerError { - t.Fatalf("got statusCode %d, expected %d", v.Code, http.StatusInternalServerError) + assert.Error(t, err) + if assert.IsType(t, &echo.HTTPError{}, err) { + v, _ := err.(*echo.HTTPError) + assert.Equal(t, http.StatusInternalServerError, v.Code) } } @@ -86,13 +68,9 @@ func TestHealthHandler(t *testing.T) { c := e.NewContext(req, rec) err := healthHandler(c) - if err != nil { - t.Fatalf("error returned %v, expected nil", err) - } - if rec.Code != http.StatusOK { - t.Fatalf("got statusCode %d, expected %d", rec.Code, http.StatusOK) - } + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) } func TestRedirectHandler(t *testing.T) { @@ -102,16 +80,8 @@ func TestRedirectHandler(t *testing.T) { c := e.NewContext(req, rec) err := redirectHandler(c) - if err != nil { - t.Fatalf("error returned %v, expected nil", err) - } + assert.NoError(t, err) - if rec.Code != http.StatusPermanentRedirect { - t.Fatalf("got statusCode %d, expected %d", rec.Code, http.StatusOK) - } - - l := rec.Result().Header.Get("Location") - if l != "/msg" { - t.Fatalf("redirect Location is %s, expected %s", l, "/msg") - } + assert.Equal(t, http.StatusPermanentRedirect, rec.Code) + assert.Equal(t, "/msg", rec.Result().Header.Get("Location")) } diff --git a/internal/server.go b/internal/server.go index 18743f8..dd328ea 100644 --- a/internal/server.go +++ b/internal/server.go @@ -24,9 +24,17 @@ func Serve(cnf conf) { e.AutoTLSManager.Cache = autocert.DirCache("/var/www/.cache") } - e.Use(middleware.Logger()) + // // Limit to 10 RPS (only human should use this service) + e.Use(middleware.RateLimiter(middleware.NewRateLimiterMemoryStore(10))) + // do not log the /health endpoint + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ + Skipper: func(c echo.Context) bool { + return c.Path() == "/health" + }, + })) e.Use(middleware.BodyLimit("50M")) e.Use(middleware.Secure()) + e.Use(middleware.Recover()) e.GET("/", redirectHandler) e.File("/robots.txt", "static/robots.txt") diff --git a/internal/vault.go b/internal/vault.go index 9def58e..db71b10 100644 --- a/internal/vault.go +++ b/internal/vault.go @@ -31,17 +31,17 @@ func (v vault) Store(msg string, ttl string) (token string, err error) { // Default TTL if ttl == "" { ttl = "48h" - } - - // Verify duration - d, err := time.ParseDuration(ttl) - if err != nil { - return "", fmt.Errorf("cannot parse duration %v", err) - } + } else { + // Verify duration + d, err := time.ParseDuration(ttl) + if err != nil { + return "", fmt.Errorf("cannot parse duration %v", err) + } - // validate duration length - if d > 168*time.Hour || d == 0*time.Hour { - return "", fmt.Errorf("cannot set ttl to infinte or more than 7 days %v", err) + // validate duration length + if d > 168*time.Hour || d == 0*time.Hour { + return "", fmt.Errorf("cannot set ttl to infinte or more than 7 days %v", err) + } } t, err := v.createOneTimeToken(ttl) diff --git a/internal/vault_test.go b/internal/vault_test.go index 248db20..6686e2f 100644 --- a/internal/vault_test.go +++ b/internal/vault_test.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/vault/api" vaulthttp "github.com/hashicorp/vault/http" hashivault "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/assert" ) func createTestVault(t *testing.T) (net.Listener, *api.Client) { @@ -24,14 +25,11 @@ func createTestVault(t *testing.T) (net.Listener, *api.Client) { conf.Address = addr c, err := api.NewClient(conf) - if err != nil { - t.Fatal(err) - } - c.SetToken(rootToken) - _, err = c.Sys().Health() - if err != nil { - t.Fatal(err) + if assert.NoError(t, err) { + c.SetToken(rootToken) + _, err = c.Sys().Health() + assert.NoError(t, err) } return ln, c @@ -44,17 +42,10 @@ func TestStoreAndGet(t *testing.T) { v := newVault(c.Address(), "secret/test/", c.Token()) secret := "my secret" token, err := v.Store(secret, "") - if err != nil { - t.Fatalf("no error expected, got %v", err) - } - - msg, err := v.Get(token) - if err != nil { - t.Fatalf("no error expected, got %v", err) - } - - if msg != secret { - t.Fatalf("expected message %s, got: %s", secret, msg) + if assert.NoError(t, err) { + msg, err := v.Get(token) + assert.NoError(t, err) + assert.Equal(t, secret, msg) } } @@ -65,17 +56,11 @@ func TestMsgCanOnlyBeAccessedOnce(t *testing.T) { v := newVault(c.Address(), "secret/test/", c.Token()) secret := "my secret" token, err := v.Store(secret, "") - if err != nil { - t.Fatalf("no error expected, got %v", err) - } - - _, err = v.Get(token) - if err != nil { - t.Fatalf("no error expected, got %v", err) - } + if assert.NoError(t, err) { + _, err = v.Get(token) + assert.NoError(t, err) - _, err = v.Get(token) - if err == nil { - t.Fatal("error expected, got nil") + _, err = v.Get(token) + assert.Error(t, err) } }