diff --git a/Makefile b/Makefile index 7651e36..3d2c434 100644 --- a/Makefile +++ b/Makefile @@ -10,9 +10,15 @@ dropdb: migrateup: migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up +migrateup1: + migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up 1 + migratedown: migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down +migratedown1: + migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down 1 + sqlc: sqlc generate @@ -25,4 +31,4 @@ server: mock: mockgen -package mockdb -destination db/mock/store.go github.com/viralparmarme/simple-bank/db/sqlc Store -.PHONY: postgres createdb dropdb migrateup migratedown sqlc test server mock \ No newline at end of file +.PHONY: postgres createdb dropdb migrateup migratedown sqlc test server mock migrateup1 migratedown1 \ No newline at end of file diff --git a/api/account.go b/api/account.go index 24668e2..3858706 100644 --- a/api/account.go +++ b/api/account.go @@ -2,10 +2,13 @@ package api import ( "database/sql" + "errors" "net/http" "github.com/gin-gonic/gin" + "github.com/lib/pq" db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/token" ) func errorResponse(err error) gin.H { @@ -13,7 +16,6 @@ func errorResponse(err error) gin.H { } type createAccountRequest struct { - Owner string `json:"owner" binding:"required"` Currency string `json:"currency" binding:"required,currency"` } @@ -25,14 +27,23 @@ func (server *Server) CreateAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authPayloadKey).(*token.Payload) + arg := db.CreateAccountParams{ - Owner: req.Owner, + Owner: authPayload.Username, Currency: req.Currency, Balance: 0, } account, err := server.store.CreateAccount(ctx, arg) if err != nil { + if pqErr, ok := err.(*pq.Error); ok { + switch pqErr.Code.Name() { + case "foreign_key_violation", "unique_violation": + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return + } + } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return } @@ -62,6 +73,13 @@ func (server *Server) GetAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authPayloadKey).(*token.Payload) + if account.Owner != authPayload.Username { + err := errors.New("account does not belong to this user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + ctx.JSON(http.StatusOK, account) } @@ -78,7 +96,10 @@ func (server *Server) ListAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authPayloadKey).(*token.Payload) + arg := db.ListAccountsParams{ + Owner: authPayload.Username, Limit: req.PageSize, Offset: (req.PageID - 1) * req.PageSize, } diff --git a/api/account_test.go b/api/account_test.go index 18e1c80..9b1ed53 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -9,26 +9,33 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" mockdb "github.com/viralparmarme/simple-bank/db/mock" db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/token" "github.com/viralparmarme/simple-bank/util" ) func TestGetAccountAPI(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccount(user.Username) testCases := []struct { name string accountID int64 + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) }{ { name: "OK", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -43,6 +50,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "NotFound", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -56,6 +66,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InternalError", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -69,6 +82,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InvalidID", accountID: 0, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Any()). @@ -78,6 +94,36 @@ func TestGetAccountAPI(t *testing.T) { require.Equal(t, http.StatusBadRequest, recorder.Code) }, }, + { + name: "UnauthorizedUser", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, "unauthorized_user", time.Minute) + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "NoAuthorization", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, } for i := range testCases { @@ -90,23 +136,24 @@ func TestGetAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() url := fmt.Sprintf("/accounts/%d", tc.accountID) request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) } } -func randomAccount() db.Account { +func randomAccount(owner string) db.Account { return db.Account{ ID: util.RandomInt(1, 1000), - Owner: util.RandomOwner(), + Owner: owner, Balance: util.RandomMoney(), Currency: util.RandomCurrency(), } diff --git a/api/main_test.go b/api/main_test.go index b7b6d3b..92051f4 100644 --- a/api/main_test.go +++ b/api/main_test.go @@ -3,10 +3,26 @@ package api import ( "os" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/util" ) +func newTestServer(t *testing.T, store db.Store) *Server { + config := util.Config{ + TokenSymmetricKey: util.RandomString(32), + AccessTokenDuration: time.Minute, + } + + server, err := NewServer(config, store) + require.NoError(t, err) + + return server +} + func TestMain(m *testing.M) { gin.SetMode(gin.TestMode) diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..e761f50 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,52 @@ +package api + +import ( + "errors" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/viralparmarme/simple-bank/token" +) + +const ( + authHeaderKey = "authorization" + authTypeBearer = "bearer" + authPayloadKey = "authorization_payload" +) + +func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authHeader := ctx.GetHeader(authHeaderKey) + if len(authHeader) == 0 { + err := errors.New("authorization header is not provided") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + fields := strings.Fields(authHeader) + if len(fields) < 2 { + err := errors.New("authorization header is invalid") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + authType := strings.ToLower(fields[0]) + if authType != authTypeBearer { + err := errors.New("authorization header is unsupported") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + err := errors.New("authorization header format is invalid") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authPayloadKey, payload) + ctx.Next() + } +} diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..29fb3c4 --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,106 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/viralparmarme/simple-bank/token" +) + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + username string, + duration time.Duration, +) { + token, err := tokenMaker.CreateToken(username, duration) + require.NoError(t, err) + + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) + request.Header.Set(authHeaderKey, authorizationHeader) +} + +func TestAuthMiddleware(t *testing.T) { + testCases := []struct { + name string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) + checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + name: "OK", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "NotAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "UnsupportedAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "InvalidAuthorizationFormat", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authTypeBearer, "user", -time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "ExpiredToken", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "", "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + server := newTestServer(t, nil) + + authPath := "/auth" + server.router.GET( + authPath, + authMiddleware(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.NoError(t, err) + + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + }) + } +} diff --git a/api/server.go b/api/server.go index f87f1fe..3295377 100644 --- a/api/server.go +++ b/api/server.go @@ -1,33 +1,61 @@ package api import ( + "fmt" + "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/token" + "github.com/viralparmarme/simple-bank/util" ) type Server struct { - store db.Store - router *gin.Engine + config util.Config + store db.Store + tokenMaker token.Maker + router *gin.Engine } -func NewServer(store db.Store) *Server { +func NewServer(config util.Config, store db.Store) (*Server, error) { server := &Server{store: store} - router := gin.Default() + + tokenMaker, err := token.NewPasetoMaker(config.TokenSymmetricKey) + if err != nil { + return nil, fmt.Errorf("can't create token maker: %w", err) + } + + server = &Server{ + config: config, + store: store, + tokenMaker: tokenMaker, + } if v, ok := binding.Validator.Engine().(*validator.Validate); ok { v.RegisterValidation("currency", validCurrency) } - router.POST("/accounts", server.CreateAccount) - router.GET("/accounts/:id", server.GetAccount) - router.GET("/accounts", server.ListAccount) + server.setupRouter() + + return server, nil +} + +func (server *Server) setupRouter() { + router := gin.Default() + + authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker)) + + authRoutes.POST("/accounts", server.CreateAccount) + authRoutes.GET("/accounts/:id", server.GetAccount) + authRoutes.GET("/accounts", server.ListAccount) + + authRoutes.POST("/transfers", server.CreateTransfer) - router.POST("/transfers", server.CreateTransfer) + router.POST("/users", server.CreateUser) + router.POST("/users/login", server.loginUser) server.router = router - return server } func (server *Server) Start(address string) error { diff --git a/api/transfer.go b/api/transfer.go index 3072881..b1f025e 100644 --- a/api/transfer.go +++ b/api/transfer.go @@ -2,11 +2,13 @@ package api import ( "database/sql" + "errors" "fmt" "net/http" "github.com/gin-gonic/gin" db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/token" ) type transferRequest struct { @@ -24,10 +26,20 @@ func (server *Server) CreateTransfer(ctx *gin.Context) { return } - if !server.validAccount(ctx, req.FromAccountId, req.Currency) { + fromAccount, valid := server.validAccount(ctx, req.FromAccountId, req.Currency) + if !valid { return } - if !server.validAccount(ctx, req.ToAccountId, req.Currency) { + + authPayload := ctx.MustGet(authPayloadKey).(*token.Payload) + if fromAccount.Owner != authPayload.Username { + err := errors.New("from account does not belong to authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + _, valid = server.validAccount(ctx, req.FromAccountId, req.Currency) + if !valid { return } @@ -46,22 +58,22 @@ func (server *Server) CreateTransfer(ctx *gin.Context) { ctx.JSON(http.StatusOK, result) } -func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) bool { +func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) (db.Account, bool) { account, err := server.store.GetAccount(ctx, accountID) if err != nil { if err == sql.ErrNoRows { ctx.JSON(http.StatusNotFound, errorResponse(err)) - return false + return account, false } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) - return false + return account, false } if account.Currency != currency { err := fmt.Errorf("account %d mismatch: %s vs %s", accountID, account.Currency, currency) ctx.JSON(http.StatusBadRequest, errorResponse(err)) - return false + return account, false } - return true + return account, true } diff --git a/api/user.go b/api/user.go new file mode 100644 index 0000000..af2f0a3 --- /dev/null +++ b/api/user.go @@ -0,0 +1,126 @@ +package api + +import ( + "database/sql" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/lib/pq" + db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/util" +) + +type createUserRequest struct { + Username string `json:"username" binding:"required,alphanum"` + Password string `json:"password" binding:"required,min=6"` + FullName string `json:"full_name" binding:"required"` + Email string `json:"email" binding:"required,email"` +} + +type userResponse struct { + Username string `json:"username"` + FullName string `json:"full_name"` + Email string `json:"email"` + PasswordChangedAt time.Time `json:"password_changed_at"` + CreatedAt time.Time `json:"created_at"` +} + +func newUserResponse(user db.User) userResponse { + return userResponse{ + Username: user.Username, + FullName: user.FullName, + Email: user.Email, + PasswordChangedAt: user.PasswordChangedAt, + CreatedAt: user.CreatedAt, + } +} + +func (server *Server) CreateUser(ctx *gin.Context) { + var req createUserRequest + err := ctx.ShouldBindJSON(&req) + if err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + hashedPassword, err := util.HashPassword(req.Password) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + arg := db.CreateUserParams{ + Username: req.Username, + HashedPassword: hashedPassword, + FullName: req.FullName, + Email: req.Email, + } + + user, err := server.store.CreateUser(ctx, arg) + if err != nil { + if pqErr, ok := err.(*pq.Error); ok { + switch pqErr.Code.Name() { + case "unique_violation": + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return + } + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + resp := newUserResponse(user) + ctx.JSON(http.StatusOK, resp) +} + +type loginUserRequest struct { + Username string `json:"username" binding:"required,alphanum"` + Password string `json:"password" binding:"required,min=6"` +} + +type loginUserResponse struct { + AccessToken string `json:"access_token"` + User userResponse `json:"user"` +} + +func (server *Server) loginUser(ctx *gin.Context) { + var req loginUserRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + user, err := server.store.GetUser(ctx, req.Username) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + } + + err = util.CheckPassword(req.Password, user.HashedPassword) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken, err := server.tokenMaker.CreateToken( + user.Username, + server.config.AccessTokenDuration, + ) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + resp := loginUserResponse{ + AccessToken: accessToken, + User: newUserResponse(user), + } + + ctx.JSON(http.StatusOK, resp) +} diff --git a/api/user_test.go b/api/user_test.go new file mode 100644 index 0000000..d010cc6 --- /dev/null +++ b/api/user_test.go @@ -0,0 +1,23 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" + db "github.com/viralparmarme/simple-bank/db/sqlc" + "github.com/viralparmarme/simple-bank/util" +) + +func randomUser(t *testing.T) (user db.User, password string) { + password = util.RandomString(6) + hashedPassword, err := util.HashPassword(password) + require.NoError(t, err) + + user = db.User{ + Username: util.RandomOwner(), + HashedPassword: hashedPassword, + FullName: util.RandomOwner(), + Email: util.RandomEmail(), + } + return +} diff --git a/app.env b/app.env index 3d2b7a2..e49452d 100644 --- a/app.env +++ b/app.env @@ -1,3 +1,5 @@ DB_DRIVER="postgres" DB_SOURCE="postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -SERVER_ADDRESS="0.0.0.0:8080" \ No newline at end of file +SERVER_ADDRESS="0.0.0.0:8080" +TOKEN_SYMMETRIC_KEY=12345678901234567890121234567890 +ACCESS_TOKEN_DURATION=15m \ No newline at end of file diff --git a/db/migration/000002_add_users.down.sql b/db/migration/000002_add_users.down.sql new file mode 100644 index 0000000..9c61760 --- /dev/null +++ b/db/migration/000002_add_users.down.sql @@ -0,0 +1,5 @@ +ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "owner_currency_key"; + +ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "accounts_owner_fkey"; + +DROP TABLE IF EXISTS "users"; \ No newline at end of file diff --git a/db/migration/000002_add_users.up.sql b/db/migration/000002_add_users.up.sql new file mode 100644 index 0000000..3b30a6b --- /dev/null +++ b/db/migration/000002_add_users.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE "users" ( + "username" varchar PRIMARY KEY, + "hashed_password" varchar NOT NULL, + "full_name" varchar NOT NULL, + "email" varchar UNIQUE NOT NULL, + "password_changed_at" timestamptz NOT NULL DEFAULT '0001-01-01 00:00:00Z', + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + +ALTER TABLE "accounts" ADD FOREIGN KEY ("owner") REFERENCES "users" ("username"); + +ALTER TABLE "accounts" ADD CONSTRAINT "owner_currency_key" UNIQUE ("owner", "currency"); \ No newline at end of file diff --git a/db/mock/store.go b/db/mock/store.go index 7379bf2..c721553 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -95,6 +95,21 @@ func (mr *MockStoreMockRecorder) CreateTransfer(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTransfer", reflect.TypeOf((*MockStore)(nil).CreateTransfer), arg0, arg1) } +// CreateUser mocks base method. +func (m *MockStore) CreateUser(arg0 context.Context, arg1 db.CreateUserParams) (db.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", arg0, arg1) + ret0, _ := ret[0].(db.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser. +func (mr *MockStoreMockRecorder) CreateUser(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockStore)(nil).CreateUser), arg0, arg1) +} + // DeleteAccount mocks base method. func (m *MockStore) DeleteAccount(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() @@ -169,6 +184,21 @@ func (mr *MockStoreMockRecorder) GetTransfer(arg0, arg1 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransfer", reflect.TypeOf((*MockStore)(nil).GetTransfer), arg0, arg1) } +// GetUser mocks base method. +func (m *MockStore) GetUser(arg0 context.Context, arg1 string) (db.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUser", arg0, arg1) + ret0, _ := ret[0].(db.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUser indicates an expected call of GetUser. +func (mr *MockStoreMockRecorder) GetUser(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockStore)(nil).GetUser), arg0, arg1) +} + // ListAccounts mocks base method. func (m *MockStore) ListAccounts(arg0 context.Context, arg1 db.ListAccountsParams) ([]db.Account, error) { m.ctrl.T.Helper() diff --git a/db/query/account.sql b/db/query/account.sql index 8ae578e..cbe7599 100644 --- a/db/query/account.sql +++ b/db/query/account.sql @@ -18,9 +18,10 @@ FOR NO KEY UPDATE; -- name: ListAccounts :many SELECT * FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2; +LIMIT $2 +OFFSET $3; -- name: UpdateAccount :one UPDATE accounts diff --git a/db/query/user.sql b/db/query/user.sql new file mode 100644 index 0000000..4c8518b --- /dev/null +++ b/db/query/user.sql @@ -0,0 +1,13 @@ +-- name: CreateUser :one +INSERT INTO users ( + username, + hashed_password, + full_name, + email +) VALUES ( + $1, $2, $3, $4 +) RETURNING *; + +-- name: GetUser :one +SELECT * from users +WHERE username = $1 LIMIT 1; \ No newline at end of file diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 8bf481b..fbfc09d 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -112,18 +112,20 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e const listAccounts = `-- name: ListAccounts :many SELECT id, owner, balance, currency, created_at FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2 +LIMIT $2 +OFFSET $3 ` type ListAccountsParams struct { - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + Owner string `json:"owner"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.QueryContext(ctx, listAccounts, arg.Limit, arg.Offset) + rows, err := q.db.QueryContext(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index daa91d9..f45864c 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -11,8 +11,10 @@ import ( ) func createRandomAccount(t *testing.T) Account { + user := createRandomUser(t) + arg := CreateAccountParams{ - Owner: util.RandomOwner(), + Owner: user.Username, Balance: util.RandomMoney(), Currency: util.RandomCurrency(), } @@ -90,20 +92,24 @@ func TestDeleteAccount(t *testing.T) { } func TestListAccounts(t *testing.T) { + var lastAccount Account + for i := 0; i < 10; i++ { - createRandomAccount(t) + lastAccount = createRandomAccount(t) } arg := ListAccountsParams{ + Owner: lastAccount.Owner, Limit: 5, - Offset: 5, + Offset: 0, } accounts, err := testQueries.ListAccounts(context.Background(), arg) require.NoError(t, err) - require.Len(t, accounts, 5) + require.NotEmpty(t, accounts) for _, account := range accounts { require.NotEmpty(t, account) + require.Equal(t, lastAccount.Owner, account.Owner) } } diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 19fe958..6412782 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -32,3 +32,12 @@ type Transfer struct { Amount int64 `json:"amount"` CreatedAt time.Time `json:"created_at"` } + +type User struct { + Username string `json:"username"` + HashedPassword string `json:"hashed_password"` + FullName string `json:"full_name"` + Email string `json:"email"` + PasswordChangedAt time.Time `json:"password_changed_at"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index e721b1f..ba338e4 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -13,11 +13,13 @@ type Querier interface { CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) + CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAccount(ctx context.Context, id int64) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) GetEntry(ctx context.Context, id int64) (Entry, error) GetTransfer(ctx context.Context, id int64) (Transfer, error) + GetUser(ctx context.Context, username string) (User, error) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) diff --git a/db/sqlc/user.sql.go b/db/sqlc/user.sql.go new file mode 100644 index 0000000..3e93a52 --- /dev/null +++ b/db/sqlc/user.sql.go @@ -0,0 +1,66 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 +// source: user.sql + +package db + +import ( + "context" +) + +const createUser = `-- name: CreateUser :one +INSERT INTO users ( + username, + hashed_password, + full_name, + email +) VALUES ( + $1, $2, $3, $4 +) RETURNING username, hashed_password, full_name, email, password_changed_at, created_at +` + +type CreateUserParams struct { + Username string `json:"username"` + HashedPassword string `json:"hashed_password"` + FullName string `json:"full_name"` + Email string `json:"email"` +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRowContext(ctx, createUser, + arg.Username, + arg.HashedPassword, + arg.FullName, + arg.Email, + ) + var i User + err := row.Scan( + &i.Username, + &i.HashedPassword, + &i.FullName, + &i.Email, + &i.PasswordChangedAt, + &i.CreatedAt, + ) + return i, err +} + +const getUser = `-- name: GetUser :one +SELECT username, hashed_password, full_name, email, password_changed_at, created_at from users +WHERE username = $1 LIMIT 1 +` + +func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { + row := q.db.QueryRowContext(ctx, getUser, username) + var i User + err := row.Scan( + &i.Username, + &i.HashedPassword, + &i.FullName, + &i.Email, + &i.PasswordChangedAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/db/sqlc/user_test.go b/db/sqlc/user_test.go new file mode 100644 index 0000000..e083293 --- /dev/null +++ b/db/sqlc/user_test.go @@ -0,0 +1,57 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/viralparmarme/simple-bank/util" +) + +func createRandomUser(t *testing.T) User { + hashedPassword, err := util.HashPassword(util.RandomString(6)) + require.NoError(t, err) + + arg := CreateUserParams{ + Username: util.RandomOwner(), + HashedPassword: hashedPassword, + FullName: util.RandomOwner(), + Email: util.RandomEmail(), + } + + user, err := testQueries.CreateUser(context.Background(), arg) + + require.NoError(t, err) + require.NotEmpty(t, user) + + require.Equal(t, arg.Username, user.Username) + require.Equal(t, arg.HashedPassword, user.HashedPassword) + require.Equal(t, arg.FullName, user.FullName) + require.Equal(t, arg.Email, user.Email) + + require.True(t, user.PasswordChangedAt.IsZero()) + require.NotZero(t, user.CreatedAt) + + return user +} + +func TestCreateUser(t *testing.T) { + createRandomUser(t) +} + +func TestGetUser(t *testing.T) { + user1 := createRandomUser(t) + user2, err := testQueries.GetUser(context.Background(), user1.Username) + + require.NoError(t, err) + require.NotEmpty(t, user2) + + require.Equal(t, user1.Username, user2.Username) + require.Equal(t, user1.HashedPassword, user2.HashedPassword) + require.Equal(t, user1.FullName, user2.FullName) + require.Equal(t, user1.Email, user2.Email) + + require.WithinDuration(t, user1.PasswordChangedAt, user2.PasswordChangedAt, time.Second) + require.WithinDuration(t, user1.CreatedAt, user2.CreatedAt, time.Second) +} diff --git a/go.mod b/go.mod index 86f1ee6..43ea453 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,22 @@ module github.com/viralparmarme/simple-bank go 1.21 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gin-gonic/gin v1.9.1 github.com/go-playground/validator/v10 v10.14.0 github.com/golang/mock v1.6.0 + github.com/google/uuid v1.3.1 github.com/lib/pq v1.10.9 + github.com/o1egl/paseto v1.0.0 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.9.0 ) require ( + github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb // indirect + github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -31,6 +38,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/afero v1.9.5 // indirect github.com/spf13/cast v1.5.1 // indirect @@ -40,7 +48,6 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.9.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 90a98d3..b15d0d9 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,12 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= +github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb h1:6Z/wqhPFZ7y5ksCEV/V5MXOazLaeu/EW97CU5rz8NWk= +github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb/go.mod h1:UzH9IX1MMqOcwhoNOIjmTQeAxrFgzs50j4golQtXXxU= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -55,6 +61,8 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -142,6 +150,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= @@ -182,8 +192,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/o1egl/paseto v1.0.0 h1:bwpvPu2au176w4IBlhbyUv/S5VPptERIA99Oap5qUd0= +github.com/o1egl/paseto v1.0.0/go.mod h1:5HxsZPmw/3RI2pAwGo1HhOOwSdvBpcuVzO7uDkm+CLU= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -237,6 +251,7 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20181025213731-e84da0312774/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -336,6 +351,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index b580b8e..6ead7ac 100644 --- a/main.go +++ b/main.go @@ -24,7 +24,10 @@ func main() { } store := db.NewStore(conn) - server := api.NewServer(store) + server, err := api.NewServer(config, store) + if err != nil { + log.Fatal("Can't start server:", err) + } err = server.Start(config.ServerAddress) if err != nil { diff --git a/token/jwt_maker.go b/token/jwt_maker.go new file mode 100644 index 0000000..bb298a0 --- /dev/null +++ b/token/jwt_maker.go @@ -0,0 +1,57 @@ +package token + +import ( + "errors" + "fmt" + "time" + + "github.com/dgrijalva/jwt-go" +) + +const minSecretKeySize = 32 + +type JWTMaker struct { + secretKey string +} + +func NewJWTMaker(secretKey string) (Maker, error) { + if len(secretKey) < minSecretKeySize { + return nil, fmt.Errorf("invalid key size: must be at least %d chars", minSecretKeySize) + } + + return &JWTMaker{secretKey}, nil +} + +func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (string, error) { + p, err := NewPayload(username, duration) + if err != nil { + return "", err + } + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, p) + return jwtToken.SignedString([]byte(maker.secretKey)) +} + +func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { + keyFunc := func(token *jwt.Token) (interface{}, error) { + _, ok := token.Method.(*jwt.SigningMethodHMAC) + if !ok { + return nil, ErrInvalidToken + } + return []byte(maker.secretKey), nil + } + + jwToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc) + if err != nil { + verr, ok := err.(*jwt.ValidationError) + if ok && errors.Is(verr.Inner, ErrExpiredToken) { + return nil, ErrExpiredToken + } + return nil, ErrInvalidToken + } + + payload, ok := jwToken.Claims.(*Payload) + if !ok { + return nil, ErrInvalidToken + } + return payload, nil +} diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go new file mode 100644 index 0000000..b5ba223 --- /dev/null +++ b/token/jwt_maker_test.go @@ -0,0 +1,68 @@ +package token + +import ( + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/require" + "github.com/viralparmarme/simple-bank/util" +) + +func TestJWTMaker(t *testing.T) { + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := time.Minute + + issuedAt := time.Now() + expiredAt := issuedAt.Add(duration) + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.NoError(t, err) + require.NotEmpty(t, payload) + + require.NotZero(t, payload.ID) + require.Equal(t, username, payload.Username) + require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) + require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) +} + +func TestExpiredJWT(t *testing.T) { + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := -time.Minute + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrExpiredToken.Error()) + require.Nil(t, payload) +} + +func TestInvalidJWTTokenAlgNone(t *testing.T) { + payload, err := NewPayload(util.RandomOwner(), time.Minute) + require.NoError(t, err) + + jwToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload) + token, err := jwToken.SignedString(jwt.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + payload, err = maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrInvalidToken.Error()) + require.Nil(t, payload) +} diff --git a/token/maker.go b/token/maker.go new file mode 100644 index 0000000..a5d2c1e --- /dev/null +++ b/token/maker.go @@ -0,0 +1,9 @@ +package token + +import "time" + +type Maker interface { + CreateToken(username string, duration time.Duration) (string, error) + + VerifyToken(token string) (*Payload, error) +} diff --git a/token/paseto_maker.go b/token/paseto_maker.go new file mode 100644 index 0000000..eb96b24 --- /dev/null +++ b/token/paseto_maker.go @@ -0,0 +1,52 @@ +package token + +import ( + "fmt" + "time" + + "github.com/o1egl/paseto" + "golang.org/x/crypto/chacha20poly1305" +) + +type PasetoMaker struct { + paseto *paseto.V2 + symmetricKey []byte +} + +func NewPasetoMaker(symmetricKey string) (Maker, error) { + if len(symmetricKey) != chacha20poly1305.KeySize { + return nil, fmt.Errorf("invalid key size, must be %d", chacha20poly1305.KeySize) + } + + maker := &PasetoMaker{ + paseto: paseto.NewV2(), + symmetricKey: []byte(symmetricKey), + } + return maker, nil +} + +func (maker *PasetoMaker) CreateToken(username string, duration time.Duration) (string, error) { + p, err := NewPayload(username, duration) + if err != nil { + return "", err + } + + return maker.paseto.Encrypt(maker.symmetricKey, p, nil) + +} + +func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { + payload := &Payload{} + + err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil) + if err != nil { + return nil, ErrInvalidToken + } + + err = payload.Valid() + if err != nil { + return nil, err + } + + return payload, nil +} diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go new file mode 100644 index 0000000..dae8fce --- /dev/null +++ b/token/paseto_maker_test.go @@ -0,0 +1,50 @@ +package token + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/viralparmarme/simple-bank/util" +) + +func TestPasetoMaker(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := time.Minute + + issuedAt := time.Now() + expiredAt := issuedAt.Add(duration) + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.NoError(t, err) + require.NotEmpty(t, payload) + + require.NotZero(t, payload.ID) + require.Equal(t, username, payload.Username) + require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) + require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) +} + +func TestExpiredPaseto(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := -time.Minute + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrExpiredToken.Error()) + require.Nil(t, payload) +} diff --git a/token/payload.go b/token/payload.go new file mode 100644 index 0000000..c4c2550 --- /dev/null +++ b/token/payload.go @@ -0,0 +1,43 @@ +package token + +import ( + "errors" + "time" + + "github.com/google/uuid" +) + +var ( + ErrExpiredToken = errors.New("token has expired") + ErrInvalidToken = errors.New("token is invalid") +) + +type Payload struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + IssuedAt time.Time `json:"issued_at"` + ExpiredAt time.Time `json:"expired_at"` +} + +func NewPayload(username string, duration time.Duration) (*Payload, error) { + tokenID, err := uuid.NewRandom() + if err != nil { + return nil, err + } + + payload := &Payload{ + ID: tokenID, + Username: username, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(duration), + } + + return payload, nil +} + +func (payload *Payload) Valid() error { + if time.Now().After(payload.ExpiredAt) { + return ErrExpiredToken + } + return nil +} diff --git a/util/config.go b/util/config.go index 06824fd..d709cdb 100644 --- a/util/config.go +++ b/util/config.go @@ -1,11 +1,17 @@ package util -import "github.com/spf13/viper" +import ( + "time" + + "github.com/spf13/viper" +) type Config struct { - DBDriver string `mapstructure:"DB_DRIVER"` - DBSource string `mapstructure:"DB_SOURCE"` - ServerAddress string `mapstructure:"SERVER_ADDRESS"` + DBDriver string `mapstructure:"DB_DRIVER"` + DBSource string `mapstructure:"DB_SOURCE"` + ServerAddress string `mapstructure:"SERVER_ADDRESS"` + TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` + AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` } func LoadConfig(path string) (config Config, err error) { diff --git a/util/password.go b/util/password.go new file mode 100644 index 0000000..e244a3a --- /dev/null +++ b/util/password.go @@ -0,0 +1,19 @@ +package util + +import ( + "fmt" + + "golang.org/x/crypto/bcrypt" +) + +func HashPassword(password string) (string, error) { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash password: %d", err) + } + return string(hashedPassword), nil +} + +func CheckPassword(password string, hashedPassword string) error { + return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) +} diff --git a/util/password_test.go b/util/password_test.go new file mode 100644 index 0000000..0877079 --- /dev/null +++ b/util/password_test.go @@ -0,0 +1,29 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestPassword(t *testing.T) { + password := RandomString(6) + + hashedPassword, err := HashPassword(password) + require.NoError(t, err) + require.NotEmpty(t, hashedPassword) + + err = CheckPassword(password, hashedPassword) + require.NoError(t, err) + + wrongPassword := RandomString(6) + err = CheckPassword(wrongPassword, hashedPassword) + require.Error(t, err) + require.EqualError(t, err, bcrypt.ErrMismatchedHashAndPassword.Error()) + + hashedPassword2, err := HashPassword(password) + require.NoError(t, err) + require.NotEmpty(t, hashedPassword2) + require.NotEqual(t, hashedPassword, hashedPassword2) +} diff --git a/util/random.go b/util/random.go index 28908db..76896f3 100644 --- a/util/random.go +++ b/util/random.go @@ -1,6 +1,7 @@ package util import ( + "fmt" "math/rand" "strings" "time" @@ -41,3 +42,7 @@ func RandomCurrency() string { n := len(currencies) return currencies[rand.Intn(n)] } + +func RandomEmail() string { + return fmt.Sprintf("%s@email.com", RandomString(6)) +}