diff --git a/go.mod b/go.mod index b0080ef..bd82daa 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/cel-go v0.21.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect @@ -81,11 +82,13 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect + golang.org/x/mod v0.21.0 // indirect golang.org/x/net v0.29.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.18.0 // indirect golang.org/x/time v0.6.0 // indirect + golang.org/x/tools v0.24.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 8c3c88f..d45a794 100644 --- a/go.sum +++ b/go.sum @@ -105,6 +105,7 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -240,6 +241,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -319,6 +322,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/server/grpc.go b/internal/server/grpc.go index 5f7ff87..463c7d2 100644 --- a/internal/server/grpc.go +++ b/internal/server/grpc.go @@ -18,6 +18,7 @@ import ( "go.opentelemetry.io/otel/metric" + kratosmiddleware "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/middleware/metrics" kesselMetrics "github.com/project-kessel/relations-api/internal/server/middleware/metrics" kesselRecovery "github.com/project-kessel/relations-api/internal/server/middleware/recovery" @@ -38,41 +39,50 @@ func NewGRPCServer(c *conf.Server, relations *service.RelationshipsService, heal if err != nil { return nil, err } - var opts = []grpc.ServerOption{ - grpc.Middleware( - recovery.Recovery(), - middleware.ValidationMiddleware(validator), - logging.Server(logger), - metrics.Server( - metrics.WithSeconds(seconds), - metrics.WithRequests(requests), - ), + + unaryMiddleware := []kratosmiddleware.Middleware{ + recovery.Recovery(), + middleware.ValidationMiddleware(validator), + logging.Server(logger), + metrics.Server( + metrics.WithSeconds(seconds), + metrics.WithRequests(requests), ), - grpc.Options(googlegrpc.ChainStreamInterceptor( - middleware.StreamLogInterceptor(logger), - middleware.StreamValidationInterceptor(validator), - kesselRecovery.StreamRecoveryInterceptor(logger), - kesselMetrics.StreamMetricsInterceptor( - kesselMetrics.WithSeconds(seconds), - kesselMetrics.WithRequests(requests), - ), - )), } + streamingMiddleware := []googlegrpc.StreamServerInterceptor{ + middleware.StreamLogInterceptor(logger), + middleware.StreamValidationInterceptor(validator), + kesselRecovery.StreamRecoveryInterceptor(logger), + kesselMetrics.StreamMetricsInterceptor( + kesselMetrics.WithSeconds(seconds), + kesselMetrics.WithRequests(requests), + ), + } + if c.Auth.EnableAuth { jwks, err := FetchJwks(c.Auth.JwksUrl) if err != nil { return nil, err } - opts = append(opts, grpc.Middleware( + + unaryMiddleware = append(unaryMiddleware, selector.Server(jwt.Server(jwks.Keyfunc, jwt.WithSigningMethod(jwtv5.SigningMethodRS256))). Match(NewWhiteListMatcher). Build(), - ), - grpc.Options(googlegrpc.ChainStreamInterceptor(auth.StreamAuthInterceptor( - jwks.Keyfunc, - auth.WithSigningMethod(jwtv5.SigningMethodRS256)))), ) + streamingMiddleware = append(streamingMiddleware, auth.StreamAuthInterceptor( + jwks.Keyfunc, + auth.WithSigningMethod(jwtv5.SigningMethodRS256))) + } + + var opts = []grpc.ServerOption{ + grpc.Middleware( + unaryMiddleware..., + ), + grpc.Options(googlegrpc.ChainStreamInterceptor( + streamingMiddleware..., + )), } if c.Grpc.Network != "" { opts = append(opts, grpc.Network(c.Grpc.Network)) @@ -83,6 +93,7 @@ func NewGRPCServer(c *conf.Server, relations *service.RelationshipsService, heal if c.Grpc.Timeout != nil { opts = append(opts, grpc.Timeout(c.Grpc.Timeout.AsDuration())) } + srv := grpc.NewServer(opts...) v1beta1.RegisterKesselTupleServiceServer(srv, relations) v1beta1.RegisterKesselCheckServiceServer(srv, check) diff --git a/test/kessel_test.go b/test/kessel_test.go index b3c2e94..1dfb1fe 100644 --- a/test/kessel_test.go +++ b/test/kessel_test.go @@ -15,8 +15,10 @@ import ( v1beta1 "github.com/project-kessel/relations-api/api/kessel/relations/v1beta1" "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" ) var localKesselContainer *LocalKesselContainer @@ -259,6 +261,36 @@ func TestKesselAPIGRPC_LookupResources(t *testing.T) { assert.NoError(t, err) } +func TestKesselAPIGRPC_LookupResourcesInvalid(t *testing.T) { + //Ensures that validation middleware is still active with authentication enabled + t.Parallel() + kcurl := fmt.Sprintf("http://localhost:%s", localKesselContainer.kccontainer.GetPort("8080/tcp")) + token, err := GetJWTToken(kcurl, "admin", "admin") + if err != nil { + fmt.Print(err) + } + conn, err := grpc.NewClient( + fmt.Sprintf("localhost:%s", localKesselContainer.gRPCport), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpcutil.WithInsecureBearerToken(token.AccessToken), + ) + if err != nil { + fmt.Print(err) + } + + client := v1beta1.NewKesselLookupServiceClient(conn) + + stream, err := client.LookupResources( + context.Background(), &v1beta1.LookupResourcesRequest{}) + assert.NoError(t, err) + + _, err = stream.Recv() //Errors are returned with the first response, not the initial request + + status, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, codes.InvalidArgument, status.Code()) +} + func pointerize(value string) *string { //Used to turn string literals into pointers return &value }