From 36cc0c5419ca3ae68009a7462e4b746cbd887e60 Mon Sep 17 00:00:00 2001 From: souvik03-136 <66234771+souvik03-136@users.noreply.github.com> Date: Fri, 16 Aug 2024 00:54:57 +0530 Subject: [PATCH 1/3] feat: add oauth anonymous auth --- go.mod | 2 + go.sum | 4 + internal/controllers/v1/auth/handlers.go | 178 ++++++++++++++++++ internal/database/auth.sql.go | 46 +++++ internal/database/db.go | 2 +- .../000004_create_table_users.down.sql | 1 + .../000004_create_table_users.up.sql | 8 + internal/database/models.go | 9 +- internal/database/pastebin.sql.go | 2 +- internal/database/queries/auth.sql | 8 + internal/database/url.sql.go | 2 +- internal/server/routes.go | 10 +- internal/server/server.go | 21 +-- 13 files changed, 277 insertions(+), 16 deletions(-) create mode 100644 internal/database/auth.sql.go create mode 100644 internal/database/migrations/000004_create_table_users.down.sql create mode 100644 internal/database/migrations/000004_create_table_users.up.sql create mode 100644 internal/database/queries/auth.sql diff --git a/go.mod b/go.mod index 7be21b0..c5d2505 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx v3.6.2+incompatible // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect @@ -32,6 +33,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/go.sum b/go.sum index 684bd41..e49a530 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= +github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= @@ -65,6 +67,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= diff --git a/internal/controllers/v1/auth/handlers.go b/internal/controllers/v1/auth/handlers.go index 8832b06..a4871a5 100644 --- a/internal/controllers/v1/auth/handlers.go +++ b/internal/controllers/v1/auth/handlers.go @@ -1 +1,179 @@ package auth + +import ( + "errors" + "net/http" + "triton-backend/internal/database" + "triton-backend/internal/merrors" + "triton-backend/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +type AuthHandler struct { + db *pgxpool.Pool +} + +func Handler(db *pgxpool.Pool) *AuthHandler { + return &AuthHandler{ + db: db, + } +} + +func (a *AuthHandler) RegisterOAuthUser(c *gin.Context) { + var input struct { + OAuthID string `json:"oauth_id" binding:"required"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + merrors.Validation(c, err.Error()) + return + } + + tx, err := a.db.Begin(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + defer tx.Rollback(c) + + qtx := database.New(a.db).WithTx(tx) + + // Create a new user UUID + userUUID := uuid.New() + + // Try to create a new user in the database + err = qtx.CreateUser(c, database.CreateUserParams{ + Uuid: userUUID, + AuthType: "oauth", + OauthID: input.OAuthID, + }) + var e *pgconn.PgError + if errors.As(err, &e) && e.Code == pgerrcode.UniqueViolation { + merrors.Validation(c, "user already exists with this OAuth ID!") + return + } else if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + err = tx.Commit(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "OAuth user successfully registered", + StatusCode: http.StatusOK, + }) +} + +func (a *AuthHandler) GetUserByOAuthID(c *gin.Context) { + var input struct { + OAuthID string `json:"oauth_id" binding:"required"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + merrors.Validation(c, err.Error()) + return + } + + q := database.New(a.db) + userUUID, err := q.GetUserByOAuthID(c, database.GetUserByOAuthIDParams{ + AuthType: "oauth", + OauthID: input.OAuthID, + }) + if errors.Is(err, pgx.ErrNoRows) { + merrors.NotFound(c, "user not found!") + return + } else if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "OAuth user successfully retrieved", + Data: userUUID, + StatusCode: http.StatusOK, + }) +} + +func (a *AuthHandler) RegisterAnonymousUser(c *gin.Context) { + // For anonymous auth, we generate a UUID and register it as a user. + userUUID := uuid.New() + + tx, err := a.db.Begin(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + defer tx.Rollback(c) + + qtx := database.New(a.db).WithTx(tx) + + // Try to create a new user in the database + err = qtx.CreateUser(c, database.CreateUserParams{ + Uuid: userUUID, + AuthType: "anonymous", + }) + var e *pgconn.PgError + if errors.As(err, &e) && e.Code == pgerrcode.UniqueViolation { + merrors.Validation(c, "user already exists with this ID!") + return + } else if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + err = tx.Commit(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "Anonymous user successfully registered", + Data: userUUID, + StatusCode: http.StatusOK, + }) +} + +func (a *AuthHandler) GetUserByAnonymousID(c *gin.Context) { + var input struct { + UserID uuid.UUID `json:"user_id" binding:"required"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + merrors.Validation(c, err.Error()) + return + } + + q := database.New(a.db) + userUUID, err := q.GetUserByOAuthID(c, database.GetUserByOAuthIDParams{ + AuthType: "anonymous", + OauthID: input.UserID.String(), // You might need to adjust this part based on your actual query setup. + }) + if errors.Is(err, pgx.ErrNoRows) { + merrors.NotFound(c, "user not found!") + return + } else if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "Anonymous user successfully retrieved", + Data: userUUID, + StatusCode: http.StatusOK, + }) +} diff --git a/internal/database/auth.sql.go b/internal/database/auth.sql.go new file mode 100644 index 0000000..3212685 --- /dev/null +++ b/internal/database/auth.sql.go @@ -0,0 +1,46 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: auth.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const createUser = `-- name: CreateUser :exec +INSERT INTO users (uuid, auth_type, oauth_id) +VALUES ($1, $2, $3) +` + +type CreateUserParams struct { + Uuid uuid.UUID `json:"uuid"` + AuthType string `json:"auth_type"` + OauthID string `json:"oauth_id"` +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) error { + _, err := q.db.Exec(ctx, createUser, arg.Uuid, arg.AuthType, arg.OauthID) + return err +} + +const getUserByOAuthID = `-- name: GetUserByOAuthID :one +SELECT uuid +FROM users +WHERE auth_type = $1 AND oauth_id = $2 +` + +type GetUserByOAuthIDParams struct { + AuthType string `json:"auth_type"` + OauthID string `json:"oauth_id"` +} + +func (q *Queries) GetUserByOAuthID(ctx context.Context, arg GetUserByOAuthIDParams) (uuid.UUID, error) { + row := q.db.QueryRow(ctx, getUserByOAuthID, arg.AuthType, arg.OauthID) + var uuid uuid.UUID + err := row.Scan(&uuid) + return uuid, err +} diff --git a/internal/database/db.go b/internal/database/db.go index 1d02744..8187a2b 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 package database diff --git a/internal/database/migrations/000004_create_table_users.down.sql b/internal/database/migrations/000004_create_table_users.down.sql new file mode 100644 index 0000000..6b17475 --- /dev/null +++ b/internal/database/migrations/000004_create_table_users.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS public.users; diff --git a/internal/database/migrations/000004_create_table_users.up.sql b/internal/database/migrations/000004_create_table_users.up.sql new file mode 100644 index 0000000..cfad536 --- /dev/null +++ b/internal/database/migrations/000004_create_table_users.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS public.users ( + uuid uuid NOT NULL, + created_at timestamptz DEFAULT now() NOT NULL, + auth_type text NOT NULL, + oauth_id text NOT NULL, + CONSTRAINT users_unique UNIQUE (auth_type, oauth_id), + CONSTRAINT users_pkey PRIMARY KEY (uuid) +); diff --git a/internal/database/models.go b/internal/database/models.go index 6269ae8..b7a1052 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 package database @@ -22,3 +22,10 @@ type Url struct { UrlUuid uuid.UUID `json:"url_uuid"` UrlName interface{} `json:"url_name"` } + +type User struct { + Uuid uuid.UUID `json:"uuid"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + AuthType string `json:"auth_type"` + OauthID string `json:"oauth_id"` +} diff --git a/internal/database/pastebin.sql.go b/internal/database/pastebin.sql.go index 855e4f1..6629168 100644 --- a/internal/database/pastebin.sql.go +++ b/internal/database/pastebin.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 // source: pastebin.sql package database diff --git a/internal/database/queries/auth.sql b/internal/database/queries/auth.sql new file mode 100644 index 0000000..1561867 --- /dev/null +++ b/internal/database/queries/auth.sql @@ -0,0 +1,8 @@ +-- name: CreateUser :exec +INSERT INTO users (uuid, auth_type, oauth_id) +VALUES (@uuid, @auth_type, @oauth_id); + +-- name: GetUserByOAuthID :one +SELECT uuid +FROM users +WHERE auth_type = @auth_type AND oauth_id = @oauth_id; diff --git a/internal/database/url.sql.go b/internal/database/url.sql.go index 473c96d..a7333af 100644 --- a/internal/database/url.sql.go +++ b/internal/database/url.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.26.0 +// sqlc v1.27.0 // source: url.sql package database diff --git a/internal/server/routes.go b/internal/server/routes.go index 8393066..6931e58 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -15,13 +15,20 @@ func (s *Server) RegisterRoutes() http.Handler { r.Use(CORSMiddleware()) r.GET("/", s.HelloWorldHandler) - r.GET("/health", s.healthHandler) v1 := r.Group("/v1") v1.POST("/pastebin/create", s.PastebinHandler.CreatePastebin) v1.GET("/pastebin/:url", s.PastebinHandler.GetPastebin) + authGroup := v1.Group("/auth") + { + authGroup.POST("/register/oauth", s.AuthHandler.RegisterOAuthUser) + authGroup.POST("/register/anonymous", s.AuthHandler.RegisterAnonymousUser) + authGroup.POST("/get/oauth", s.AuthHandler.GetUserByOAuthID) + authGroup.POST("/get/anonymous", s.AuthHandler.GetUserByAnonymousID) + } + return r } @@ -44,6 +51,7 @@ func (s *Server) healthHandler(c *gin.Context) { stats["error"] = fmt.Sprintf("db down: %v", err) log.Fatalf(fmt.Sprintf("db down: %v", err)) // Log the error and terminate the program c.JSON(http.StatusInternalServerError, stats) + return } // Database is up, add more statistics diff --git a/internal/server/server.go b/internal/server/server.go index ed71a35..4b50e18 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,35 +10,34 @@ import ( "github.com/jackc/pgx/v5/pgxpool" _ "github.com/joho/godotenv/autoload" + "triton-backend/internal/controllers/v1/auth" "triton-backend/internal/controllers/v1/pastebin" "triton-backend/internal/database" ) type Server struct { - port int - + port int db *pgxpool.Pool PastebinHandler *pastebin.PastebinHandler + AuthHandler *auth.AuthHandler } func NewServer() *http.Server { port, _ := strconv.Atoi(os.Getenv("PORT")) db := database.NewService() - NewServer := &Server{ - port: port, - + server := &Server{ + port: port, db: db, PastebinHandler: pastebin.Handler(db), + AuthHandler: auth.Handler(db), } - // Declare Server config - server := &http.Server{ - Addr: fmt.Sprintf("localhost:%d", NewServer.port), - Handler: NewServer.RegisterRoutes(), + // Create a new HTTP server + return &http.Server{ + Addr: fmt.Sprintf("localhost:%d", server.port), + Handler: server.RegisterRoutes(), IdleTimeout: time.Minute, ReadTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, } - - return server } From 8a17c891a4143ae8000a4978fde1792b2d39fb2b Mon Sep 17 00:00:00 2001 From: souvik03-136 <66234771+souvik03-136@users.noreply.github.com> Date: Sat, 17 Aug 2024 01:01:41 +0530 Subject: [PATCH 2/3] update --- .env.example | 5 + go.mod | 3 + go.sum | 6 + internal/controllers/v1/auth/handlers.go | 191 +++++++++++++++++- .../000005_create_table_tokens.down.sql | 1 + .../000005_create_table_tokens.up.sql | 9 + internal/database/models.go | 10 + internal/database/queries/token.sql | 8 + internal/database/token.sql.go | 33 +++ internal/server/routes.go | 19 ++ 10 files changed, 282 insertions(+), 3 deletions(-) create mode 100644 internal/database/migrations/000005_create_table_tokens.down.sql create mode 100644 internal/database/migrations/000005_create_table_tokens.up.sql create mode 100644 internal/database/queries/token.sql create mode 100644 internal/database/token.sql.go diff --git a/.env.example b/.env.example index bae3fec..884f3cf 100644 --- a/.env.example +++ b/.env.example @@ -7,3 +7,8 @@ DB_DATABASE=blueprint DB_USERNAME=melkey DB_PASSWORD=password1234 DB_SCHEMA=public + +OAUTH_CLIENT_ID=your_oauth_client_id_here +GOOGLE_CLIENT_ID=your_google_client_id_here +GOOGLE_CLIENT_SECRET=your_google_client_secret_here +GOOGLE_REDIRECT_URI=your_google_redirect_uri_here diff --git a/go.mod b/go.mod index c5d2505..5bc93f4 100644 --- a/go.mod +++ b/go.mod @@ -11,10 +11,12 @@ require ( ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/bytedance/sonic v1.12.1 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/gabriel-vasile/mimetype v1.4.5 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -40,6 +42,7 @@ require ( golang.org/x/arch v0.9.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/oauth2 v0.22.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/go.sum b/go.sum index e49a530..f4a1cc4 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -11,6 +13,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/gabriel-vasile/mimetype v1.4.5 h1:J7wGKdGu33ocBOhGy0z653k/lFKLFDPJMG8Gql0kxn4= github.com/gabriel-vasile/mimetype v1.4.5/go.mod h1:ibHel+/kbxn9x2407k1izTA1S81ku1z/DlgOW2QE0M4= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -95,6 +99,8 @@ golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= +golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/controllers/v1/auth/handlers.go b/internal/controllers/v1/auth/handlers.go index a4871a5..36a55cd 100644 --- a/internal/controllers/v1/auth/handlers.go +++ b/internal/controllers/v1/auth/handlers.go @@ -1,8 +1,12 @@ package auth import ( + "context" + "encoding/json" "errors" + "io" "net/http" + "os" "triton-backend/internal/database" "triton-backend/internal/merrors" "triton-backend/internal/utils" @@ -13,18 +17,126 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" ) type AuthHandler struct { - db *pgxpool.Pool + db *pgxpool.Pool + googleOauthConfig *oauth2.Config + oauthStateString string } func Handler(db *pgxpool.Pool) *AuthHandler { return &AuthHandler{ db: db, + googleOauthConfig: &oauth2.Config{ + ClientID: os.Getenv("GOOGLE_CLIENT_ID"), + ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"), + RedirectURL: os.Getenv("GOOGLE_REDIRECT_URI"), + Scopes: []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}, + Endpoint: google.Endpoint, + }, + oauthStateString: "random_state_string", // Ideally, generate dynamically } } +// GoogleLoginHandler initiates the OAuth2 login process +func (a *AuthHandler) GoogleLoginHandler(c *gin.Context) { + url := a.googleOauthConfig.AuthCodeURL(a.oauthStateString, oauth2.AccessTypeOffline) + c.Redirect(http.StatusTemporaryRedirect, url) +} + +// GoogleCallbackHandler handles the callback from Google +func (a *AuthHandler) GoogleCallbackHandler(c *gin.Context) { + state := c.Request.FormValue("state") + if state != a.oauthStateString { + merrors.Validation(c, "Invalid OAuth state") + return + } + + code := c.Request.FormValue("code") + if code == "" { + merrors.Validation(c, "Code not found") + return + } + + token, err := a.googleOauthConfig.Exchange(context.Background(), code) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + client := a.googleOauthConfig.Client(context.Background(), token) + response, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + defer response.Body.Close() + + userInfo, err := io.ReadAll(response.Body) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + var user map[string]interface{} + err = json.Unmarshal(userInfo, &user) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + oauthID := user["id"].(string) + + // Check if the user exists + qtx := database.New(a.db) + userUUID, err := qtx.GetUserByOAuthID(c, database.GetUserByOAuthIDParams{ + AuthType: "oauth", + OauthID: oauthID, + }) + + if errors.Is(err, pgx.ErrNoRows) { + // If not, register a new user + userUUID = uuid.New() + tx, err := a.db.Begin(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + defer tx.Rollback(c) + + qtx = qtx.WithTx(tx) + err = qtx.CreateUser(c, database.CreateUserParams{ + Uuid: userUUID, + AuthType: "oauth", + OauthID: oauthID, + }) + var e *pgconn.PgError + if errors.As(err, &e) && e.Code == pgerrcode.UniqueViolation { + merrors.Validation(c, "User already exists with this OAuth ID!") + return + } else if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + err = tx.Commit(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "OAuth user successfully authenticated", + Data: userUUID, + StatusCode: http.StatusOK, + }) +} + func (a *AuthHandler) RegisterOAuthUser(c *gin.Context) { var input struct { OAuthID string `json:"oauth_id" binding:"required"` @@ -55,7 +167,7 @@ func (a *AuthHandler) RegisterOAuthUser(c *gin.Context) { }) var e *pgconn.PgError if errors.As(err, &e) && e.Code == pgerrcode.UniqueViolation { - merrors.Validation(c, "user already exists with this OAuth ID!") + merrors.Validation(c, "User already exists with this OAuth ID!") return } else if err != nil { merrors.InternalServer(c, err.Error()) @@ -91,7 +203,7 @@ func (a *AuthHandler) GetUserByOAuthID(c *gin.Context) { OauthID: input.OAuthID, }) if errors.Is(err, pgx.ErrNoRows) { - merrors.NotFound(c, "user not found!") + merrors.NotFound(c, "User not found!") return } else if err != nil { merrors.InternalServer(c, err.Error()) @@ -177,3 +289,76 @@ func (a *AuthHandler) GetUserByAnonymousID(c *gin.Context) { StatusCode: http.StatusOK, }) } +func (a *AuthHandler) LogoutHandler(c *gin.Context) { + // Assuming you use a token-based authentication mechanism like JWT + + // Invalidate the token (You might need to remove the token from a store or mark it as invalid in the DB) + token := c.Request.Header.Get("Authorization") + if token == "" { + merrors.Validation(c, "Authorization token required") + return + } + + // Example: remove token from the database + err := a.invalidateToken(c, token) + if err != nil { + merrors.InternalServer(c, "Error invalidating token") + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "Successfully logged out", + StatusCode: http.StatusOK, + }) +} +func (a *AuthHandler) invalidateToken(ctx context.Context, token string) error { + // Create a new instance of database.Queries + q := database.New(a.db) + + // Invalidate the token (e.g., delete it from the database) + err := q.DeleteTokenByToken(ctx, token) + if err != nil { + return err + } + + return nil +} + +func (a *AuthHandler) RefreshTokenHandler(c *gin.Context) { + var input struct { + RefreshToken string `json:"refresh_token" binding:"required"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + merrors.Validation(c, err.Error()) + return + } + + // Validate and refresh the token + newToken, err := a.refreshAccessToken(c, input.RefreshToken) + if err != nil { + merrors.Unauthorized(c, "Invalid or expired refresh token") + return + } + + c.JSON(http.StatusOK, utils.BaseResponse{ + Success: true, + Message: "Token refreshed successfully", + Data: newToken, + StatusCode: http.StatusOK, + }) +} + +func (a *AuthHandler) refreshAccessToken(ctx context.Context, refreshToken string) (string, error) { + // Create a new instance of database.Queries + q := database.New(a.db) + + // Get the new access token using the refresh token + newToken, err := q.GetNewAccessTokenByRefreshToken(ctx, refreshToken) + if err != nil { + return "", err + } + + return newToken, nil +} diff --git a/internal/database/migrations/000005_create_table_tokens.down.sql b/internal/database/migrations/000005_create_table_tokens.down.sql new file mode 100644 index 0000000..1029218 --- /dev/null +++ b/internal/database/migrations/000005_create_table_tokens.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS tokens; diff --git a/internal/database/migrations/000005_create_table_tokens.up.sql b/internal/database/migrations/000005_create_table_tokens.up.sql new file mode 100644 index 0000000..33e2197 --- /dev/null +++ b/internal/database/migrations/000005_create_table_tokens.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE tokens ( + id SERIAL PRIMARY KEY, + token TEXT NOT NULL UNIQUE, + user_uuid UUID REFERENCES users(uuid) ON DELETE CASCADE, + refresh_token TEXT NOT NULL, + new_access_token TEXT NOT NULL, + is_valid BOOLEAN DEFAULT true, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); diff --git a/internal/database/models.go b/internal/database/models.go index b7a1052..65afb75 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -18,6 +18,16 @@ type Pastebin struct { Extension string `json:"extension"` } +type Token struct { + ID int32 `json:"id"` + Token string `json:"token"` + UserUuid pgtype.UUID `json:"user_uuid"` + RefreshToken string `json:"refresh_token"` + NewAccessToken string `json:"new_access_token"` + IsValid *bool `json:"is_valid"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type Url struct { UrlUuid uuid.UUID `json:"url_uuid"` UrlName interface{} `json:"url_name"` diff --git a/internal/database/queries/token.sql b/internal/database/queries/token.sql new file mode 100644 index 0000000..7d5313a --- /dev/null +++ b/internal/database/queries/token.sql @@ -0,0 +1,8 @@ +-- name: DeleteTokenByToken :exec +DELETE FROM tokens +WHERE token = @token; + +-- name: GetNewAccessTokenByRefreshToken :one +SELECT new_access_token +FROM tokens +WHERE refresh_token = @refresh_token AND is_valid = true; diff --git a/internal/database/token.sql.go b/internal/database/token.sql.go new file mode 100644 index 0000000..fd9885a --- /dev/null +++ b/internal/database/token.sql.go @@ -0,0 +1,33 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: token.sql + +package database + +import ( + "context" +) + +const deleteTokenByToken = `-- name: DeleteTokenByToken :exec +DELETE FROM tokens +WHERE token = $1 +` + +func (q *Queries) DeleteTokenByToken(ctx context.Context, token string) error { + _, err := q.db.Exec(ctx, deleteTokenByToken, token) + return err +} + +const getNewAccessTokenByRefreshToken = `-- name: GetNewAccessTokenByRefreshToken :one +SELECT new_access_token +FROM tokens +WHERE refresh_token = $1 AND is_valid = true +` + +func (q *Queries) GetNewAccessTokenByRefreshToken(ctx context.Context, refreshToken string) (string, error) { + row := q.db.QueryRow(ctx, getNewAccessTokenByRefreshToken, refreshToken) + var new_access_token string + err := row.Scan(&new_access_token) + return new_access_token, err +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 6931e58..c7e731d 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -23,10 +23,29 @@ func (s *Server) RegisterRoutes() http.Handler { authGroup := v1.Group("/auth") { + // OAuth Login + authGroup.GET("/login/oauth", s.AuthHandler.GoogleLoginHandler) + + // OAuth Callback + authGroup.GET("/callback/oauth", s.AuthHandler.GoogleCallbackHandler) + + // Register User via OAuth authGroup.POST("/register/oauth", s.AuthHandler.RegisterOAuthUser) + + // Register User Anonymously authGroup.POST("/register/anonymous", s.AuthHandler.RegisterAnonymousUser) + + // Get User by OAuth ID authGroup.POST("/get/oauth", s.AuthHandler.GetUserByOAuthID) + + // Get User by Anonymous ID authGroup.POST("/get/anonymous", s.AuthHandler.GetUserByAnonymousID) + + // Logout + authGroup.POST("/logout", s.AuthHandler.LogoutHandler) // Needs implementation + + // Refresh OAuth Token + authGroup.POST("/token/refresh", s.AuthHandler.RefreshTokenHandler) // Needs implementation } return r From 4a8d00e444bff853499af7c06ad4d8b3f31e255e Mon Sep 17 00:00:00 2001 From: Aditya-Chowdhary Date: Mon, 19 Aug 2024 01:58:10 +0530 Subject: [PATCH 3/3] feat: auth middleware --- go.mod | 5 +- go.sum | 12 +-- internal/controllers/v1/auth/handlers.go | 96 +++++++++++-------- internal/controllers/v1/auth/tokens.go | 52 ++++++++++ internal/controllers/v1/pastebin/handlers.go | 31 ++++-- internal/database/auth.sql.go | 52 ++++++++-- internal/database/db.go | 2 +- ...sql => 000002_create_table_users.down.sql} | 0 ...p.sql => 000002_create_table_users.up.sql} | 2 +- ...n.sql => 000003_create_table_url.down.sql} | 0 ....up.sql => 000003_create_table_url.up.sql} | 0 ... => 000004_create_table_pastebin.down.sql} | 0 ...ql => 000004_create_table_pastebin.up.sql} | 3 +- .../000005_create_table_tokens.up.sql | 23 +++-- internal/database/models.go | 15 ++- internal/database/pastebin.sql.go | 18 ++-- internal/database/queries/auth.sql | 16 +++- internal/database/queries/pastebin.sql | 2 +- internal/database/queries/token.sql | 18 +++- internal/database/token.sql.go | 53 +++++++--- internal/database/url.sql.go | 2 +- internal/server/middleware.go | 83 ++++++++++++++++ internal/server/routes.go | 19 +--- sqlc.yaml | 5 + 24 files changed, 375 insertions(+), 134 deletions(-) create mode 100644 internal/controllers/v1/auth/tokens.go rename internal/database/migrations/{000004_create_table_users.down.sql => 000002_create_table_users.down.sql} (100%) rename internal/database/migrations/{000004_create_table_users.up.sql => 000002_create_table_users.up.sql} (83%) rename internal/database/migrations/{000002_create_table_url.down.sql => 000003_create_table_url.down.sql} (100%) rename internal/database/migrations/{000002_create_table_url.up.sql => 000003_create_table_url.up.sql} (100%) rename internal/database/migrations/{000003_create_table_pastebin.down.sql => 000004_create_table_pastebin.down.sql} (100%) rename internal/database/migrations/{000003_create_table_pastebin.up.sql => 000004_create_table_pastebin.up.sql} (65%) create mode 100644 internal/server/middleware.go diff --git a/go.mod b/go.mod index 5bc93f4..7a867c1 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx/v5 v5.6.0 github.com/joho/godotenv v1.5.1 + golang.org/x/oauth2 v0.22.0 ) require ( @@ -16,7 +17,6 @@ require ( github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/gabriel-vasile/mimetype v1.4.5 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -25,7 +25,6 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx v3.6.2+incompatible // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect @@ -35,14 +34,12 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.9.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/net v0.28.0 // indirect - golang.org/x/oauth2 v0.22.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/go.sum b/go.sum index f4a1cc4..1f7c3b1 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/gabriel-vasile/mimetype v1.4.5 h1:J7wGKdGu33ocBOhGy0z653k/lFKLFDPJMG8Gql0kxn4= github.com/gabriel-vasile/mimetype v1.4.5/go.mod h1:ibHel+/kbxn9x2407k1izTA1S81ku1z/DlgOW2QE0M4= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -31,8 +29,8 @@ github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4 github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -42,8 +40,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= -github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= @@ -71,8 +67,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= @@ -109,8 +103,6 @@ golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/controllers/v1/auth/handlers.go b/internal/controllers/v1/auth/handlers.go index 36a55cd..c689a65 100644 --- a/internal/controllers/v1/auth/handlers.go +++ b/internal/controllers/v1/auth/handlers.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "time" "triton-backend/internal/database" "triton-backend/internal/merrors" "triton-backend/internal/utils" @@ -14,8 +15,9 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/jackc/pgerrcode" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -61,13 +63,13 @@ func (a *AuthHandler) GoogleCallbackHandler(c *gin.Context) { return } - token, err := a.googleOauthConfig.Exchange(context.Background(), code) + token, err := a.googleOauthConfig.Exchange(c, code) if err != nil { merrors.InternalServer(c, err.Error()) return } - client := a.googleOauthConfig.Client(context.Background(), token) + client := a.googleOauthConfig.Client(c, token) response, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") if err != nil { merrors.InternalServer(c, err.Error()) @@ -81,35 +83,34 @@ func (a *AuthHandler) GoogleCallbackHandler(c *gin.Context) { return } - var user map[string]interface{} - err = json.Unmarshal(userInfo, &user) + var googleUser map[string]interface{} + err = json.Unmarshal(userInfo, &googleUser) if err != nil { merrors.InternalServer(c, err.Error()) return } - oauthID := user["id"].(string) + oauthID := googleUser["id"].(string) + + var userUUID uuid.UUID + + tx, err := a.db.Begin(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + defer tx.Rollback(c) + qtx := database.New(a.db).WithTx(tx) // Check if the user exists - qtx := database.New(a.db) - userUUID, err := qtx.GetUserByOAuthID(c, database.GetUserByOAuthIDParams{ + userUUID, err = qtx.GetUserByOAuthID(c, database.GetUserByOAuthIDParams{ AuthType: "oauth", OauthID: oauthID, }) if errors.Is(err, pgx.ErrNoRows) { // If not, register a new user - userUUID = uuid.New() - tx, err := a.db.Begin(c) - if err != nil { - merrors.InternalServer(c, err.Error()) - return - } - defer tx.Rollback(c) - - qtx = qtx.WithTx(tx) - err = qtx.CreateUser(c, database.CreateUserParams{ - Uuid: userUUID, + userUUID, err = qtx.CreateUser(c, database.CreateUserParams{ AuthType: "oauth", OauthID: oauthID, }) @@ -121,18 +122,35 @@ func (a *AuthHandler) GoogleCallbackHandler(c *gin.Context) { merrors.InternalServer(c, err.Error()) return } + } - err = tx.Commit(c) - if err != nil { - merrors.InternalServer(c, err.Error()) - return - } + tok, err := generateToken(userUUID, 24*time.Hour, "authentication") + if err != nil { + merrors.InternalServer(c, "error in token generation") + return + } + + err = qtx.CreateNewToken(c, database.CreateNewTokenParams{ + Hash: tok.Hash, + UserUuid: tok.UserID, + Expiry: pgtype.Timestamptz{Time: tok.Expiry, Valid: true}, + Scope: tok.Scope, + }) + if err != nil { + merrors.InternalServer(c, err.Error()) + return + } + + err = tx.Commit(c) + if err != nil { + merrors.InternalServer(c, err.Error()) + return } c.JSON(http.StatusOK, utils.BaseResponse{ Success: true, Message: "OAuth user successfully authenticated", - Data: userUUID, + Data: tok, StatusCode: http.StatusOK, }) } @@ -156,12 +174,8 @@ func (a *AuthHandler) RegisterOAuthUser(c *gin.Context) { qtx := database.New(a.db).WithTx(tx) - // Create a new user UUID - userUUID := uuid.New() - // Try to create a new user in the database - err = qtx.CreateUser(c, database.CreateUserParams{ - Uuid: userUUID, + _, err = qtx.CreateUser(c, database.CreateUserParams{ AuthType: "oauth", OauthID: input.OAuthID, }) @@ -232,8 +246,7 @@ func (a *AuthHandler) RegisterAnonymousUser(c *gin.Context) { qtx := database.New(a.db).WithTx(tx) // Try to create a new user in the database - err = qtx.CreateUser(c, database.CreateUserParams{ - Uuid: userUUID, + _, err = qtx.CreateUser(c, database.CreateUserParams{ AuthType: "anonymous", }) var e *pgconn.PgError @@ -289,18 +302,20 @@ func (a *AuthHandler) GetUserByAnonymousID(c *gin.Context) { StatusCode: http.StatusOK, }) } + func (a *AuthHandler) LogoutHandler(c *gin.Context) { // Assuming you use a token-based authentication mechanism like JWT // Invalidate the token (You might need to remove the token from a store or mark it as invalid in the DB) - token := c.Request.Header.Get("Authorization") - if token == "" { - merrors.Validation(c, "Authorization token required") + u, _ := c.Get("user") + user, _ := u.(*database.GetUserByTokenRow) + if user == AnonymousUser { + merrors.Unauthorized(c, "You are not logged in") return } // Example: remove token from the database - err := a.invalidateToken(c, token) + err := a.invalidateToken(c, user.Uuid) if err != nil { merrors.InternalServer(c, "Error invalidating token") return @@ -312,12 +327,13 @@ func (a *AuthHandler) LogoutHandler(c *gin.Context) { StatusCode: http.StatusOK, }) } -func (a *AuthHandler) invalidateToken(ctx context.Context, token string) error { + +func (a *AuthHandler) invalidateToken(ctx context.Context, userUUID uuid.UUID) error { // Create a new instance of database.Queries q := database.New(a.db) // Invalidate the token (e.g., delete it from the database) - err := q.DeleteTokenByToken(ctx, token) + err := q.DeleteTokenForUser(ctx, userUUID) if err != nil { return err } @@ -325,6 +341,8 @@ func (a *AuthHandler) invalidateToken(ctx context.Context, token string) error { return nil } +/* + func (a *AuthHandler) RefreshTokenHandler(c *gin.Context) { var input struct { RefreshToken string `json:"refresh_token" binding:"required"` @@ -362,3 +380,5 @@ func (a *AuthHandler) refreshAccessToken(ctx context.Context, refreshToken strin return newToken, nil } + +*/ diff --git a/internal/controllers/v1/auth/tokens.go b/internal/controllers/v1/auth/tokens.go new file mode 100644 index 0000000..58ec41a --- /dev/null +++ b/internal/controllers/v1/auth/tokens.go @@ -0,0 +1,52 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base32" + "time" + "triton-backend/internal/database" + + "github.com/google/uuid" +) + +type Token struct { + Plaintext string `json:"token"` + Hash []byte `json:"-"` + UserID uuid.UUID `json:"-"` + Expiry time.Time `json:"expiry"` + Scope string `json:"-"` +} + +var AnonymousUser = &database.GetUserByTokenRow{} + +func ValidateTokenPlaintext(tokenPlaintext string) (bool, string) { + if tokenPlaintext == "" { + return false, "token must be provided" + } + if len(tokenPlaintext) != 26 { + return false, "token must be 26 bytes long" + } + return true, "" +} + +func generateToken(userID uuid.UUID, ttl time.Duration, scope string) (*Token, error) { + token := &Token{ + UserID: userID, + Expiry: time.Now().Add(ttl), + Scope: scope, + } + + randomBytes := make([]byte, 16) + + _, err := rand.Read(randomBytes) + if err != nil { + return nil, err + } + + token.Plaintext = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes) + + hash := sha256.Sum256([]byte(token.Plaintext)) + token.Hash = hash[:] + return token, nil +} diff --git a/internal/controllers/v1/pastebin/handlers.go b/internal/controllers/v1/pastebin/handlers.go index e1eeb36..1f21fea 100644 --- a/internal/controllers/v1/pastebin/handlers.go +++ b/internal/controllers/v1/pastebin/handlers.go @@ -3,13 +3,13 @@ package pastebin import ( "errors" "net/http" + "triton-backend/internal/controllers/v1/auth" "triton-backend/internal/database" "triton-backend/internal/merrors" "triton-backend/internal/utils" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" - "github.com/google/uuid" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -28,11 +28,10 @@ func Handler(db *pgxpool.Pool) *PastebinHandler { func (p *PastebinHandler) CreatePastebin(c *gin.Context) { var input struct { - UserUUID uuid.UUID `json:"user_uuid" binding:"required,uuid"` - Title string `json:"title" binding:"required"` - Content string `json:"content" binding:"required"` - Extension string `json:"extension" binding:"required"` - URL string `json:"url" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Extension string `json:"extension" binding:"required"` + URL string `json:"url" binding:"required"` } binding.EnableDecoderDisallowUnknownFields = true err := c.ShouldBindJSON(&input) @@ -41,6 +40,13 @@ func (p *PastebinHandler) CreatePastebin(c *gin.Context) { return } + u, _ := c.Get("user") + user, _ := u.(*database.GetUserByTokenRow) + if user == auth.AnonymousUser { + merrors.Unauthorized(c, "You are not logged in") + return + } + tx, err := p.db.Begin(c) if err != nil { merrors.InternalServer(c, err.Error()) @@ -60,7 +66,7 @@ func (p *PastebinHandler) CreatePastebin(c *gin.Context) { } pastebin, err := qtx.CreatePastebin(c, database.CreatePastebinParams{ - UserUuid: input.UserUUID, + UserUuid: user.Uuid, Title: input.Title, Content: input.Content, UrlUuid: url_uuid, @@ -105,10 +111,19 @@ func (p *PastebinHandler) GetPastebin(c *gin.Context) { return } + owner := false + u, _ := c.Get("user") + user, _ := u.(*database.GetUserByTokenRow) + if user == auth.AnonymousUser { + owner = false + } else if pastebin.UserUuid == user.Uuid { + owner = true + } + c.JSON(http.StatusOK, utils.BaseResponse{ Success: true, Message: "Pastebin successfully retrieved", - Data: pastebin, + Data: gin.H{"pastebin": pastebin, "owner": owner}, StatusCode: http.StatusOK, }) diff --git a/internal/database/auth.sql.go b/internal/database/auth.sql.go index 3212685..b411d57 100644 --- a/internal/database/auth.sql.go +++ b/internal/database/auth.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 // source: auth.sql package database @@ -9,22 +9,25 @@ import ( "context" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) -const createUser = `-- name: CreateUser :exec -INSERT INTO users (uuid, auth_type, oauth_id) -VALUES ($1, $2, $3) +const createUser = `-- name: CreateUser :one +INSERT INTO users (auth_type, oauth_id) +VALUES ($1, $2) +RETURNING uuid ` type CreateUserParams struct { - Uuid uuid.UUID `json:"uuid"` - AuthType string `json:"auth_type"` - OauthID string `json:"oauth_id"` + AuthType string `json:"auth_type"` + OauthID string `json:"oauth_id"` } -func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) error { - _, err := q.db.Exec(ctx, createUser, arg.Uuid, arg.AuthType, arg.OauthID) - return err +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (uuid.UUID, error) { + row := q.db.QueryRow(ctx, createUser, arg.AuthType, arg.OauthID) + var uuid uuid.UUID + err := row.Scan(&uuid) + return uuid, err } const getUserByOAuthID = `-- name: GetUserByOAuthID :one @@ -44,3 +47,32 @@ func (q *Queries) GetUserByOAuthID(ctx context.Context, arg GetUserByOAuthIDPara err := row.Scan(&uuid) return uuid, err } + +const getUserByToken = `-- name: GetUserByToken :one +SELECT users.uuid, users.created_at, users.oauth_id +FROM users +INNER JOIN tokens +ON users.uuid = tokens.user_uuid +WHERE tokens.hash = $1 +AND tokens.scope = $2 +AND tokens.expiry > $3 +` + +type GetUserByTokenParams struct { + Hash []byte `json:"hash"` + Scope string `json:"scope"` + Expiry pgtype.Timestamptz `json:"expiry"` +} + +type GetUserByTokenRow struct { + Uuid uuid.UUID `json:"uuid"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + OauthID string `json:"oauth_id"` +} + +func (q *Queries) GetUserByToken(ctx context.Context, arg GetUserByTokenParams) (GetUserByTokenRow, error) { + row := q.db.QueryRow(ctx, getUserByToken, arg.Hash, arg.Scope, arg.Expiry) + var i GetUserByTokenRow + err := row.Scan(&i.Uuid, &i.CreatedAt, &i.OauthID) + return i, err +} diff --git a/internal/database/db.go b/internal/database/db.go index 8187a2b..1d02744 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 package database diff --git a/internal/database/migrations/000004_create_table_users.down.sql b/internal/database/migrations/000002_create_table_users.down.sql similarity index 100% rename from internal/database/migrations/000004_create_table_users.down.sql rename to internal/database/migrations/000002_create_table_users.down.sql diff --git a/internal/database/migrations/000004_create_table_users.up.sql b/internal/database/migrations/000002_create_table_users.up.sql similarity index 83% rename from internal/database/migrations/000004_create_table_users.up.sql rename to internal/database/migrations/000002_create_table_users.up.sql index cfad536..1511d30 100644 --- a/internal/database/migrations/000004_create_table_users.up.sql +++ b/internal/database/migrations/000002_create_table_users.up.sql @@ -1,5 +1,5 @@ CREATE TABLE IF NOT EXISTS public.users ( - uuid uuid NOT NULL, + uuid uuid DEFAULT uuid_generate_v4() NOT NULL, created_at timestamptz DEFAULT now() NOT NULL, auth_type text NOT NULL, oauth_id text NOT NULL, diff --git a/internal/database/migrations/000002_create_table_url.down.sql b/internal/database/migrations/000003_create_table_url.down.sql similarity index 100% rename from internal/database/migrations/000002_create_table_url.down.sql rename to internal/database/migrations/000003_create_table_url.down.sql diff --git a/internal/database/migrations/000002_create_table_url.up.sql b/internal/database/migrations/000003_create_table_url.up.sql similarity index 100% rename from internal/database/migrations/000002_create_table_url.up.sql rename to internal/database/migrations/000003_create_table_url.up.sql diff --git a/internal/database/migrations/000003_create_table_pastebin.down.sql b/internal/database/migrations/000004_create_table_pastebin.down.sql similarity index 100% rename from internal/database/migrations/000003_create_table_pastebin.down.sql rename to internal/database/migrations/000004_create_table_pastebin.down.sql diff --git a/internal/database/migrations/000003_create_table_pastebin.up.sql b/internal/database/migrations/000004_create_table_pastebin.up.sql similarity index 65% rename from internal/database/migrations/000003_create_table_pastebin.up.sql rename to internal/database/migrations/000004_create_table_pastebin.up.sql index 0d5e231..840c23f 100644 --- a/internal/database/migrations/000003_create_table_pastebin.up.sql +++ b/internal/database/migrations/000004_create_table_pastebin.up.sql @@ -6,5 +6,6 @@ CREATE TABLE IF NOT EXISTS public.pastebin ( url_uuid uuid NOT NULL, extension text NOT NULL, CONSTRAINT pastebin_unique UNIQUE (url_uuid), - CONSTRAINT pastebin_url_fk FOREIGN KEY (url_uuid) REFERENCES public.url(url_uuid) ON DELETE CASCADE ON UPDATE CASCADE + CONSTRAINT pastebin_url_fk FOREIGN KEY (url_uuid) REFERENCES public.url(url_uuid) ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT pastebin_user_fk FOREIGN KEY (user_uuid) REFERENCES public.users(uuid) ON DELETE CASCADE ON UPDATE CASCADE ); \ No newline at end of file diff --git a/internal/database/migrations/000005_create_table_tokens.up.sql b/internal/database/migrations/000005_create_table_tokens.up.sql index 33e2197..3d3efd0 100644 --- a/internal/database/migrations/000005_create_table_tokens.up.sql +++ b/internal/database/migrations/000005_create_table_tokens.up.sql @@ -1,9 +1,18 @@ -CREATE TABLE tokens ( +-- CREATE TABLE tokens ( +-- id SERIAL PRIMARY KEY, +-- token TEXT NOT NULL UNIQUE, +-- user_uuid UUID REFERENCES users(uuid) ON DELETE CASCADE, +-- refresh_token TEXT NOT NULL, +-- new_access_token TEXT NOT NULL, +-- is_valid BOOLEAN DEFAULT true, +-- created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +-- ); + +CREATE TABLE IF NOT EXISTS tokens ( id SERIAL PRIMARY KEY, - token TEXT NOT NULL UNIQUE, - user_uuid UUID REFERENCES users(uuid) ON DELETE CASCADE, - refresh_token TEXT NOT NULL, - new_access_token TEXT NOT NULL, - is_valid BOOLEAN DEFAULT true, + hash bytea NOT NULL, + user_uuid uuid REFERENCES users(uuid) ON DELETE CASCADE, + expiry timestamp(0) with time zone NOT NULL, + scope text NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP -); +); \ No newline at end of file diff --git a/internal/database/models.go b/internal/database/models.go index 65afb75..af9bdfe 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 package database @@ -19,13 +19,12 @@ type Pastebin struct { } type Token struct { - ID int32 `json:"id"` - Token string `json:"token"` - UserUuid pgtype.UUID `json:"user_uuid"` - RefreshToken string `json:"refresh_token"` - NewAccessToken string `json:"new_access_token"` - IsValid *bool `json:"is_valid"` - CreatedAt pgtype.Timestamptz `json:"created_at"` + ID int32 `json:"id"` + Hash []byte `json:"hash"` + UserUuid uuid.UUID `json:"user_uuid"` + Expiry pgtype.Timestamptz `json:"expiry"` + Scope string `json:"scope"` + CreatedAt pgtype.Timestamptz `json:"created_at"` } type Url struct { diff --git a/internal/database/pastebin.sql.go b/internal/database/pastebin.sql.go index 6629168..abd2e68 100644 --- a/internal/database/pastebin.sql.go +++ b/internal/database/pastebin.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 // source: pastebin.sql package database @@ -46,7 +46,7 @@ func (q *Queries) CreatePastebin(ctx context.Context, arg CreatePastebinParams) } const getPastebin = `-- name: GetPastebin :one -SELECT p.title, p.content, p.extension +SELECT p.user_uuid, p.title, p.content, p.extension FROM pastebin p INNER JOIN url u ON p.url_uuid = u.url_uuid @@ -54,14 +54,20 @@ WHERE u.url_name = $1::text ` type GetPastebinRow struct { - Title string `json:"title"` - Content string `json:"content"` - Extension string `json:"extension"` + UserUuid uuid.UUID `json:"user_uuid"` + Title string `json:"title"` + Content string `json:"content"` + Extension string `json:"extension"` } func (q *Queries) GetPastebin(ctx context.Context, url string) (GetPastebinRow, error) { row := q.db.QueryRow(ctx, getPastebin, url) var i GetPastebinRow - err := row.Scan(&i.Title, &i.Content, &i.Extension) + err := row.Scan( + &i.UserUuid, + &i.Title, + &i.Content, + &i.Extension, + ) return i, err } diff --git a/internal/database/queries/auth.sql b/internal/database/queries/auth.sql index 1561867..ef9d662 100644 --- a/internal/database/queries/auth.sql +++ b/internal/database/queries/auth.sql @@ -1,8 +1,18 @@ --- name: CreateUser :exec -INSERT INTO users (uuid, auth_type, oauth_id) -VALUES (@uuid, @auth_type, @oauth_id); +-- name: CreateUser :one +INSERT INTO users (auth_type, oauth_id) +VALUES (@auth_type, @oauth_id) +RETURNING uuid; -- name: GetUserByOAuthID :one SELECT uuid FROM users WHERE auth_type = @auth_type AND oauth_id = @oauth_id; + +-- name: GetUserByToken :one +SELECT users.uuid, users.created_at, users.oauth_id +FROM users +INNER JOIN tokens +ON users.uuid = tokens.user_uuid +WHERE tokens.hash = $1 +AND tokens.scope = $2 +AND tokens.expiry > $3; \ No newline at end of file diff --git a/internal/database/queries/pastebin.sql b/internal/database/queries/pastebin.sql index 28e760a..459d9cc 100644 --- a/internal/database/queries/pastebin.sql +++ b/internal/database/queries/pastebin.sql @@ -4,7 +4,7 @@ VALUES (@user_uuid, @title, @content, @url_uuid, @extension) RETURNING *; -- name: GetPastebin :one -SELECT p.title, p.content, p.extension +SELECT p.user_uuid, p.title, p.content, p.extension FROM pastebin p INNER JOIN url u ON p.url_uuid = u.url_uuid diff --git a/internal/database/queries/token.sql b/internal/database/queries/token.sql index 7d5313a..abe6e52 100644 --- a/internal/database/queries/token.sql +++ b/internal/database/queries/token.sql @@ -1,8 +1,16 @@ -- name: DeleteTokenByToken :exec DELETE FROM tokens -WHERE token = @token; +WHERE hash = @hash; --- name: GetNewAccessTokenByRefreshToken :one -SELECT new_access_token -FROM tokens -WHERE refresh_token = @refresh_token AND is_valid = true; +-- -- name: GetNewAccessTokenByRefreshToken :one +-- SELECT new_access_token +-- FROM tokens +-- WHERE refresh_token = @refresh_token AND is_valid = true; + +-- name: CreateNewToken :exec +INSERT INTO tokens (hash,user_uuid,expiry,scope) +VALUES (@hash, @user_uuid, @expiry, @scope); + +-- name: DeleteTokenForUser :exec +DELETE FROM tokens +WHERE user_uuid = @user_uuid; \ No newline at end of file diff --git a/internal/database/token.sql.go b/internal/database/token.sql.go index fd9885a..7018f49 100644 --- a/internal/database/token.sql.go +++ b/internal/database/token.sql.go @@ -1,33 +1,60 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 // source: token.sql package database import ( "context" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) +const createNewToken = `-- name: CreateNewToken :exec + +INSERT INTO tokens (hash,user_uuid,expiry,scope) +VALUES ($1, $2, $3, $4) +` + +type CreateNewTokenParams struct { + Hash []byte `json:"hash"` + UserUuid uuid.UUID `json:"user_uuid"` + Expiry pgtype.Timestamptz `json:"expiry"` + Scope string `json:"scope"` +} + +// -- name: GetNewAccessTokenByRefreshToken :one +// SELECT new_access_token +// FROM tokens +// WHERE refresh_token = @refresh_token AND is_valid = true; +func (q *Queries) CreateNewToken(ctx context.Context, arg CreateNewTokenParams) error { + _, err := q.db.Exec(ctx, createNewToken, + arg.Hash, + arg.UserUuid, + arg.Expiry, + arg.Scope, + ) + return err +} + const deleteTokenByToken = `-- name: DeleteTokenByToken :exec DELETE FROM tokens -WHERE token = $1 +WHERE hash = $1 ` -func (q *Queries) DeleteTokenByToken(ctx context.Context, token string) error { - _, err := q.db.Exec(ctx, deleteTokenByToken, token) +func (q *Queries) DeleteTokenByToken(ctx context.Context, hash []byte) error { + _, err := q.db.Exec(ctx, deleteTokenByToken, hash) return err } -const getNewAccessTokenByRefreshToken = `-- name: GetNewAccessTokenByRefreshToken :one -SELECT new_access_token -FROM tokens -WHERE refresh_token = $1 AND is_valid = true +const deleteTokenForUser = `-- name: DeleteTokenForUser :exec +DELETE FROM tokens +WHERE user_uuid = $1 ` -func (q *Queries) GetNewAccessTokenByRefreshToken(ctx context.Context, refreshToken string) (string, error) { - row := q.db.QueryRow(ctx, getNewAccessTokenByRefreshToken, refreshToken) - var new_access_token string - err := row.Scan(&new_access_token) - return new_access_token, err +func (q *Queries) DeleteTokenForUser(ctx context.Context, userUuid uuid.UUID) error { + _, err := q.db.Exec(ctx, deleteTokenForUser, userUuid) + return err } diff --git a/internal/database/url.sql.go b/internal/database/url.sql.go index a7333af..473c96d 100644 --- a/internal/database/url.sql.go +++ b/internal/database/url.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.26.0 // source: url.sql package database diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..046669e --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,83 @@ +package server + +import ( + "crypto/sha256" + "errors" + "strings" + "time" + "triton-backend/internal/controllers/v1/auth" + "triton-backend/internal/database" + "triton-backend/internal/merrors" + + "github.com/gin-gonic/gin" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" +) + +func CORSMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } +} + +func (s *Server) authenticate() gin.HandlerFunc { + return func(ctx *gin.Context) { + ctx.Writer.Header().Add("Vary", "Authorization") + + authorizationHeader := ctx.Request.Header.Get("Authorization") + + if authorizationHeader == "" { + ctx.Set("user", auth.AnonymousUser) + ctx.Next() + return + } + + headerParts := strings.Split(authorizationHeader, " ") + if len(headerParts) != 2 || headerParts[0] != "Bearer" { + ctx.Writer.Header().Set("WWW-Authenticate", "Bearer") + merrors.Unauthorized(ctx, "invalid or missing authentication token") + return + } + + token := headerParts[1] + + if v, err := auth.ValidateTokenPlaintext(token); !v { + ctx.Writer.Header().Set("WWW-Authenticate", "Bearer") + merrors.Unauthorized(ctx, err) + return + } + hash := sha256.Sum256([]byte(token)) + + q := database.New(s.db) + + user, err := q.GetUserByToken(ctx, database.GetUserByTokenParams{ + Hash: hash[:], + Scope: "authentication", + Expiry: pgtype.Timestamptz{Time: time.Now(), Valid: true}, + }) + if err != nil { + switch { + case errors.Is(err, pgx.ErrNoRows): + ctx.Writer.Header().Set("WWW-Authenticate", "Bearer") + merrors.Unauthorized(ctx, "invalid or missing authentication token") + default: + merrors.InternalServer(ctx, err.Error()) + } + return + } + + ctx.Set("user", &user) + + ctx.Next() + } +} diff --git a/internal/server/routes.go b/internal/server/routes.go index c7e731d..6815947 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -18,6 +18,7 @@ func (s *Server) RegisterRoutes() http.Handler { r.GET("/health", s.healthHandler) v1 := r.Group("/v1") + v1.Use(s.authenticate()) v1.POST("/pastebin/create", s.PastebinHandler.CreatePastebin) v1.GET("/pastebin/:url", s.PastebinHandler.GetPastebin) @@ -45,7 +46,7 @@ func (s *Server) RegisterRoutes() http.Handler { authGroup.POST("/logout", s.AuthHandler.LogoutHandler) // Needs implementation // Refresh OAuth Token - authGroup.POST("/token/refresh", s.AuthHandler.RefreshTokenHandler) // Needs implementation + // authGroup.POST("/token/refresh", s.AuthHandler.RefreshTokenHandler) // Needs implementation } return r @@ -78,19 +79,3 @@ func (s *Server) healthHandler(c *gin.Context) { stats["message"] = "It's healthy" c.JSON(http.StatusOK, stats) } - -func CORSMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") - - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) - return - } - - c.Next() - } -} diff --git a/sqlc.yaml b/sqlc.yaml index 8982aa3..b771897 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -18,6 +18,11 @@ sql: json_tags_id_uppercase: true overrides: - db_type: "uuid" + go_type: + import: "github.com/google/uuid" + type: "UUID" + - db_type: "uuid" + nullable: true go_type: import: "github.com/google/uuid" type: "UUID" \ No newline at end of file