Skip to content

Commit

Permalink
Add HTTP API, token based auth and middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
viralparmarme committed Sep 25, 2023
1 parent 12b85fc commit f57b2f5
Show file tree
Hide file tree
Showing 35 changed files with 1,047 additions and 41 deletions.
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ dropdb:
migrateup:
migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up

migrateup1:
migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up 1

migratedown:
migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down

migratedown1:
migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down 1

sqlc:
sqlc generate

Expand All @@ -25,4 +31,4 @@ server:
mock:
mockgen -package mockdb -destination db/mock/store.go github.com/viralparmarme/simple-bank/db/sqlc Store

.PHONY: postgres createdb dropdb migrateup migratedown sqlc test server mock
.PHONY: postgres createdb dropdb migrateup migratedown sqlc test server mock migrateup1 migratedown1
25 changes: 23 additions & 2 deletions api/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package api

import (
"database/sql"
"errors"
"net/http"

"github.com/gin-gonic/gin"
"github.com/lib/pq"
db "github.com/viralparmarme/simple-bank/db/sqlc"
"github.com/viralparmarme/simple-bank/token"
)

func errorResponse(err error) gin.H {
return gin.H{"error": err.Error()}
}

type createAccountRequest struct {
Owner string `json:"owner" binding:"required"`
Currency string `json:"currency" binding:"required,currency"`
}

Expand All @@ -25,14 +27,23 @@ func (server *Server) CreateAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authPayloadKey).(*token.Payload)

arg := db.CreateAccountParams{
Owner: req.Owner,
Owner: authPayload.Username,
Currency: req.Currency,
Balance: 0,
}

account, err := server.store.CreateAccount(ctx, arg)
if err != nil {
if pqErr, ok := err.(*pq.Error); ok {
switch pqErr.Code.Name() {
case "foreign_key_violation", "unique_violation":
ctx.JSON(http.StatusForbidden, errorResponse(err))
return
}
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
Expand Down Expand Up @@ -62,6 +73,13 @@ func (server *Server) GetAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authPayloadKey).(*token.Payload)
if account.Owner != authPayload.Username {
err := errors.New("account does not belong to this user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.JSON(http.StatusOK, account)
}

Expand All @@ -78,7 +96,10 @@ func (server *Server) ListAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authPayloadKey).(*token.Payload)

arg := db.ListAccountsParams{
Owner: authPayload.Username,
Limit: req.PageSize,
Offset: (req.PageID - 1) * req.PageSize,
}
Expand Down
55 changes: 51 additions & 4 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,33 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
mockdb "github.com/viralparmarme/simple-bank/db/mock"
db "github.com/viralparmarme/simple-bank/db/sqlc"
"github.com/viralparmarme/simple-bank/token"
"github.com/viralparmarme/simple-bank/util"
)

func TestGetAccountAPI(t *testing.T) {
account := randomAccount()
user, _ := randomUser(t)
account := randomAccount(user.Username)

testCases := []struct {
name string
accountID int64
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
buildStubs func(store *mockdb.MockStore)
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
name: "OK",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -43,6 +50,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "NotFound",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -56,6 +66,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InternalError",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -69,6 +82,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InvalidID",
accountID: 0,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Any()).
Expand All @@ -78,6 +94,36 @@ func TestGetAccountAPI(t *testing.T) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
},
},
{
name: "UnauthorizedUser",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authTypeBearer, "unauthorized_user", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Times(1).
Return(account, nil)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "NoAuthorization",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Any()).
Times(0)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
}

for i := range testCases {
Expand All @@ -90,23 +136,24 @@ func TestGetAccountAPI(t *testing.T) {
store := mockdb.NewMockStore(ctrl)
tc.buildStubs(store)

server := NewServer(store)
server := newTestServer(t, store)
recorder := httptest.NewRecorder()

url := fmt.Sprintf("/accounts/%d", tc.accountID)
request, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}

func randomAccount() db.Account {
func randomAccount(owner string) db.Account {
return db.Account{
ID: util.RandomInt(1, 1000),
Owner: util.RandomOwner(),
Owner: owner,
Balance: util.RandomMoney(),
Currency: util.RandomCurrency(),
}
Expand Down
16 changes: 16 additions & 0 deletions api/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,26 @@ package api
import (
"os"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
db "github.com/viralparmarme/simple-bank/db/sqlc"
"github.com/viralparmarme/simple-bank/util"
)

func newTestServer(t *testing.T, store db.Store) *Server {
config := util.Config{
TokenSymmetricKey: util.RandomString(32),
AccessTokenDuration: time.Minute,
}

server, err := NewServer(config, store)
require.NoError(t, err)

return server
}

func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode)

Expand Down
52 changes: 52 additions & 0 deletions api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package api

import (
"errors"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/viralparmarme/simple-bank/token"
)

const (
authHeaderKey = "authorization"
authTypeBearer = "bearer"
authPayloadKey = "authorization_payload"
)

func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
return func(ctx *gin.Context) {
authHeader := ctx.GetHeader(authHeaderKey)
if len(authHeader) == 0 {
err := errors.New("authorization header is not provided")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

fields := strings.Fields(authHeader)
if len(fields) < 2 {
err := errors.New("authorization header is invalid")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

authType := strings.ToLower(fields[0])
if authType != authTypeBearer {
err := errors.New("authorization header is unsupported")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

accessToken := fields[1]
payload, err := tokenMaker.VerifyToken(accessToken)
if err != nil {
err := errors.New("authorization header format is invalid")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.Set(authPayloadKey, payload)
ctx.Next()
}
}
Loading

0 comments on commit f57b2f5

Please sign in to comment.