Skip to content

Commit

Permalink
Merge pull request #22 from TYuan0816/feature/oauth2
Browse files Browse the repository at this point in the history
Feature/oauth2
  • Loading branch information
ianchen0119 authored Feb 7, 2024
2 parents e70dc73 + b68ecce commit 48a6f17
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 29 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ require (
github.com/antihax/optional v1.0.0
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/bronze1man/radius v0.0.0-20190516032554-afd8baec892d
github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6
github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693
github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808
github.com/gin-gonic/gin v1.9.1
github.com/google/gopacket v1.1.19
github.com/google/uuid v1.3.0
github.com/pkg/errors v0.9.1
github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.8.3
github.com/urfave/cli v1.22.5
Expand Down Expand Up @@ -39,7 +40,6 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.0.1 // indirect
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6 h1:8P/wOkTAQMgZJe9pUUNSTE5PWeAdlMrsU9kLsI+VAVE=
github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA=
github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693 h1:gFyYBsErQAkx4OVHXYqjO0efO9gPWydQavQcjU0CkHY=
github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA=
github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808 h1:8/IoWEgcO2DLlLCqbsxwduD7CzXdKe/BFJU2tcAqnxo=
github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808/go.mod h1:d+79g84a3YHhzvjJ2IhurrBOavOA8xWIQ/GCywPXqQk=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
Expand Down
22 changes: 19 additions & 3 deletions internal/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ func Init() {
InitAusfContext(&ausfContext)
}

type NFContext interface {
AuthorizationCheck(token string, serviceName models.ServiceName) error
}

var _ NFContext = &AUSFContext{}

func NewAusfUeContext(identifier string) (ausfUeContext *AusfUeContext) {
ausfUeContext = new(AusfUeContext)
ausfUeContext.Supi = identifier // supi
Expand Down Expand Up @@ -160,12 +166,22 @@ func (a *AUSFContext) GetSelfID() string {
return a.NfId
}

func (c *AUSFContext) GetTokenCtx(scope, targetNF string) (
func (c *AUSFContext) GetTokenCtx(serviceName models.ServiceName, targetNF models.NfType) (
context.Context, *models.ProblemDetails, error,
) {
if !c.OAuth2Required {
return context.TODO(), nil, nil
}
return oauth.GetTokenCtx(models.NfType_AUSF,
c.NfId, c.NrfUri, scope, targetNF)
return oauth.GetTokenCtx(models.NfType_AUSF, targetNF,
c.NfId, c.NrfUri, string(serviceName))
}

func (c *AUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error {
if !c.OAuth2Required {
logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: OAuth2 not required\n")
return nil
}

logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: token[%s] serviceName[%s]\n", token, serviceName)
return oauth.VerifyOAuth(token, string(serviceName), c.NrfCertPem)
}
2 changes: 2 additions & 0 deletions internal/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var (
UeAuthLog *logrus.Entry
Auth5gAkaLog *logrus.Entry
AuthELog *logrus.Entry
UtilLog *logrus.Entry
)

func init() {
Expand All @@ -37,4 +38,5 @@ func init() {
UeAuthLog = NfLog.WithField(logger_util.FieldCategory, "UeAuth")
Auth5gAkaLog = NfLog.WithField(logger_util.FieldCategory, "5gAka")
AuthELog = NfLog.WithField(logger_util.FieldCategory, "Eap")
UtilLog = NfLog.WithField(logger_util.FieldCategory, "Util")
}
3 changes: 2 additions & 1 deletion internal/sbi/consumer/nf_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfType,
param Nnrf_NFDiscovery.SearchNFInstancesParamOpts,
) (*models.SearchResult, error) {
ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nnrf-disc", "NRF")
ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_DISC, models.NfType_NRF)
if err != nil {
return nil, err
}
Expand All @@ -24,6 +24,7 @@ func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfT

result, rsp, rspErr := client.NFInstancesStoreApi.SearchNFInstances(ctx,
targetNfType, requestNfType, &param)

if rspErr != nil {
return nil, fmt.Errorf("NFInstancesStoreApi Response error: %+w", rspErr)
}
Expand Down
10 changes: 7 additions & 3 deletions internal/sbi/consumer/nf_management.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package consumer

import (
"context"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -40,9 +39,14 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil
configuration.SetBasePath(nrfUri)
client := Nnrf_NFManagement.NewAPIClient(configuration)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF)
if err != nil {
return "", "", err
}

var res *http.Response
for {
nf, resTmp, err := client.NFInstanceIDDocumentApi.RegisterNFInstance(context.TODO(), nfInstanceId, profile)
nf, resTmp, err := client.NFInstanceIDDocumentApi.RegisterNFInstance(ctx, nfInstanceId, profile)
if err != nil || resTmp == nil {
logger.ConsumerLog.Errorf("AUSF register to NRF Error[%v]", err)
time.Sleep(2 * time.Second)
Expand Down Expand Up @@ -90,7 +94,7 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil
func SendDeregisterNFInstance() (*models.ProblemDetails, error) {
logger.ConsumerLog.Infof("Send Deregister NFInstance")

ctx, pd, err := ausf_context.GetSelf().GetTokenCtx("nnrf-nfm", "NRF")
ctx, pd, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF)
if err != nil {
return pd, err
}
Expand Down
9 changes: 7 additions & 2 deletions internal/sbi/producer/functions.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package producer

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
Expand Down Expand Up @@ -372,7 +371,13 @@ func sendAuthResultToUDM(id string, authType models.AuthType, success bool, serv
authEvent.NfInstanceId = self.GetSelfID()

client := createClientToUdmUeau(udmUrl)
_, rsp, confirmAuthErr := client.ConfirmAuthApi.ConfirmAuth(context.Background(), id, authEvent)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM)
if err != nil {
return err
}

_, rsp, confirmAuthErr := client.ConfirmAuthApi.ConfirmAuth(ctx, id, authEvent)
defer func() {
if rspCloseErr := rsp.Body.Close(); rspCloseErr != nil {
logger.ConsumerLog.Errorf("ConfirmAuth Response cannot close: %v", rspCloseErr)
Expand Down
9 changes: 7 additions & 2 deletions internal/sbi/producer/ue_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package producer

import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
Expand Down Expand Up @@ -124,7 +123,13 @@ func UeAuthPostRequestProcedure(updateAuthenticationInfo models.AuthenticationIn

udmUrl := getUdmUrl(self.NrfUri)
client := createClientToUdmUeau(udmUrl)
authInfoResult, rsp, err := client.GenerateAuthDataApi.GenerateAuthData(context.Background(), supiOrSuci, authInfoReq)

ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM)
if err != nil {
return nil, "", nil
}

authInfoResult, rsp, err := client.GenerateAuthDataApi.GenerateAuthData(ctx, supiOrSuci, authInfoReq)
if err != nil {
logger.UeAuthLog.Infoln(err.Error())
var problemDetails models.ProblemDetails
Expand Down
8 changes: 8 additions & 0 deletions internal/sbi/sorprotection/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (

"github.com/gin-gonic/gin"

ausf_context "github.com/free5gc/ausf/internal/context"
"github.com/free5gc/ausf/internal/logger"
"github.com/free5gc/ausf/internal/util"
"github.com/free5gc/ausf/pkg/factory"
"github.com/free5gc/openapi/models"
logger_util "github.com/free5gc/util/logger"
)

Expand Down Expand Up @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfSorprotectionResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_SORPROTECTION)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})

for _, route := range routes {
switch route.Method {
case "GET":
Expand Down
8 changes: 8 additions & 0 deletions internal/sbi/ueauthentication/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (

"github.com/gin-gonic/gin"

ausf_context "github.com/free5gc/ausf/internal/context"
"github.com/free5gc/ausf/internal/logger"
"github.com/free5gc/ausf/internal/util"
"github.com/free5gc/ausf/pkg/factory"
"github.com/free5gc/openapi/models"
logger_util "github.com/free5gc/util/logger"
)

Expand Down Expand Up @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfAuthResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_AUTH)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})

for _, route := range routes {
switch route.Method {
case "GET":
Expand Down
8 changes: 8 additions & 0 deletions internal/sbi/upuprotection/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (

"github.com/gin-gonic/gin"

ausf_context "github.com/free5gc/ausf/internal/context"
"github.com/free5gc/ausf/internal/logger"
"github.com/free5gc/ausf/internal/util"
"github.com/free5gc/ausf/pkg/factory"
"github.com/free5gc/openapi/models"
logger_util "github.com/free5gc/util/logger"
)

Expand Down Expand Up @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine {
func AddService(engine *gin.Engine) *gin.RouterGroup {
group := engine.Group(factory.AusfAuthResUriPrefix)

routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_UPUPROTECTION)
group.Use(func(c *gin.Context) {
routerAuthorizationCheck.Check(c, ausf_context.GetSelf())
})

for _, route := range routes {
switch route.Method {
case "GET":
Expand Down
34 changes: 34 additions & 0 deletions internal/util/router_auth_check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package util

import (
"net/http"

"github.com/gin-gonic/gin"

ausf_context "github.com/free5gc/ausf/internal/context"
"github.com/free5gc/ausf/internal/logger"
"github.com/free5gc/openapi/models"
)

type RouterAuthorizationCheck struct {
serviceName models.ServiceName
}

func NewRouterAuthorizationCheck(serviceName models.ServiceName) *RouterAuthorizationCheck {
return &RouterAuthorizationCheck{
serviceName: serviceName,
}
}

func (rac *RouterAuthorizationCheck) Check(c *gin.Context, ausfContext ausf_context.NFContext) {
token := c.Request.Header.Get("Authorization")
err := ausfContext.AuthorizationCheck(token, rac.serviceName)
if err != nil {
logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Unauthorized: %s", err.Error())
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
c.Abort()
return
}

logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Authorized")
}
93 changes: 93 additions & 0 deletions internal/util/router_auth_check_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package util

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"

"github.com/free5gc/openapi/models"
)

const (
Valid = "valid"
Invalid = "invalid"
)

type mockAUSFContext struct{}

func newMockAUSFContext() *mockAUSFContext {
return &mockAUSFContext{}
}

func (m *mockAUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error {
if token == Valid {
return nil
}

return errors.New("invalid token")
}

func TestRouterAuthorizationCheck_Check(t *testing.T) {
// Mock gin.Context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)

var err error
c.Request, err = http.NewRequest("GET", "/", nil)
if err != nil {
t.Errorf("error on http request: %+v", err)
}

type Args struct {
token string
}
type Want struct {
statusCode int
}

tests := []struct {
name string
args Args
want Want
}{
{
name: "Valid Token",
args: Args{
token: Valid,
},
want: Want{
statusCode: http.StatusOK,
},
},
{
name: "Invalid Token",
args: Args{
token: Invalid,
},
want: Want{
statusCode: http.StatusUnauthorized,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w = httptest.NewRecorder()
c, _ = gin.CreateTestContext(w)
c.Request, err = http.NewRequest("GET", "/", nil)
if err != nil {
t.Errorf("error on http request: %+v", err)
}
c.Request.Header.Set("Authorization", tt.args.token)

rac := NewRouterAuthorizationCheck(models.ServiceName("testService"))
rac.Check(c, newMockAUSFContext())
if w.Code != tt.want.statusCode {
t.Errorf("StatusCode should be %d, but got %d", tt.want.statusCode, w.Code)
}
})
}
}
Loading

0 comments on commit 48a6f17

Please sign in to comment.