diff --git a/.gitignore b/.gitignore index d685dbc3..e82ca76a 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,9 @@ # Ignore editor config .vscode +# for VS Code debug sessions +**/__debug_bin* +**/ginkgo.report # Ignore IntelliJ config .idea/ diff --git a/Makefile b/Makefile index 1266f05e..9ac298f1 100644 --- a/Makefile +++ b/Makefile @@ -361,3 +361,7 @@ crc/login: @crc console --credentials -ojson | jq -r .clusterConfig.adminCredentials.password | oc login --username kubeadmin --insecure-skip-tls-verify=true https://api.crc.testing:6443 @oc whoami --show-token | $(container_tool) login --username kubeadmin --password-stdin "$(external_image_registry)" --tls-verify=false .PHONY: crc/login + +.PHONY: fmt-imports +fmt-imports: $(GCI) + find . -name '*.go' -not -path './vendor/*' | xargs $(GCI) write -s standard -s default -s "prefix(k8s)" -s "prefix(sigs.k8s)" -s "prefix(github.com)" -s "prefix(gitlab)" -s "prefix(github.com/openshift-online/rh-trex)" --custom-order --skip-generated \ No newline at end of file diff --git a/cmd/trex/environments/service_types.go b/cmd/trex/environments/service_types.go index 67f802ed..0cfe6e5a 100644 --- a/cmd/trex/environments/service_types.go +++ b/cmd/trex/environments/service_types.go @@ -22,7 +22,9 @@ type GenericServiceLocator func() services.GenericService func NewGenericServiceLocator(env *Env) GenericServiceLocator { return func() services.GenericService { - return services.NewGenericService(dao.NewGenericDao(&env.Database.SessionFactory)) + return services.NewGenericService( + dao.NewGenericDao(&env.Database.SessionFactory), + env.Clients.OCM) } } diff --git a/go.mod b/go.mod index b3289446..310be9ad 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/jinzhu/inflection v1.0.0 github.com/lib/pq v1.10.5 github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 + github.com/onsi/ginkgo/v2 v2.8.1 github.com/onsi/gomega v1.27.1 github.com/openshift-online/ocm-sdk-go v0.1.334 github.com/prometheus/client_golang v1.16.0 @@ -25,6 +26,7 @@ require ( github.com/spf13/cobra v0.0.5 github.com/spf13/pflag v1.0.5 github.com/yaacov/tree-search-language v0.0.0-20190923184055-1c2dad2e354b + go.uber.org/mock v0.4.0 gopkg.in/resty.v1 v1.12.0 gorm.io/driver/postgres v1.0.5 gorm.io/gorm v1.20.5 @@ -38,6 +40,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/docker/distribution v2.8.1+incompatible // indirect + github.com/go-logr/logr v1.2.3 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/gorilla/css v1.0.0 // indirect diff --git a/go.sum b/go.sum index 6fedf786..e215018b 100644 --- a/go.sum +++ b/go.sum @@ -384,7 +384,6 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= @@ -511,6 +510,8 @@ go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= diff --git a/pkg/api/dinosaur_types.go b/pkg/api/dinosaur_types.go index 8e03bed7..531baad4 100644 --- a/pkg/api/dinosaur_types.go +++ b/pkg/api/dinosaur_types.go @@ -1,6 +1,13 @@ package api -import "gorm.io/gorm" +import ( + "github.com/openshift-online/rh-trex/pkg/util" + "gorm.io/gorm" +) + +var ( + DinosaurTypeName = util.GetBaseType(Dinosaur{}) +) type Dinosaur struct { Meta diff --git a/pkg/api/presenters/kind.go b/pkg/api/presenters/kind.go index e7db32c9..f15876ab 100644 --- a/pkg/api/presenters/kind.go +++ b/pkg/api/presenters/kind.go @@ -1,19 +1,12 @@ package presenters import ( - "github.com/openshift-online/rh-trex/pkg/api" "github.com/openshift-online/rh-trex/pkg/api/openapi" - "github.com/openshift-online/rh-trex/pkg/errors" + "github.com/openshift-online/rh-trex/pkg/util" ) func ObjectKind(i interface{}) *string { - result := "" - switch i.(type) { - case api.Dinosaur, *api.Dinosaur: - result = "Dinosaur" - case errors.ServiceError, *errors.ServiceError: - result = "Error" - } + result := util.GetBaseType(i) return openapi.PtrString(result) } diff --git a/pkg/api/presenters/object_reference_test.go b/pkg/api/presenters/object_reference_test.go new file mode 100644 index 00000000..d33bdbbe --- /dev/null +++ b/pkg/api/presenters/object_reference_test.go @@ -0,0 +1,21 @@ +package presenters + +import ( + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" + "github.com/openshift-online/rh-trex/pkg/api" +) + +var _ = Describe("Object Reference Presenter", func() { + It("Populates Kind", func() { + object := api.Dinosaur{ + Meta: api.Meta{ + ID: "123", + }, + } + presented := PresentReference("123", object) + Expect(*presented.Id).To(Equal(object.ID)) + Expect(*presented.Kind).To(Equal("Dinosaur")) + Expect(*presented.Href).To(Equal("/api/rh-trex/v1/dinosaurs/123")) + }) +}) diff --git a/pkg/api/presenters/path.go b/pkg/api/presenters/path.go index fead9647..6a99cb03 100644 --- a/pkg/api/presenters/path.go +++ b/pkg/api/presenters/path.go @@ -4,9 +4,7 @@ import ( "fmt" "github.com/openshift-online/rh-trex/pkg/api/openapi" - - "github.com/openshift-online/rh-trex/pkg/api" - "github.com/openshift-online/rh-trex/pkg/errors" + "github.com/openshift-online/rh-trex/pkg/util" ) const ( @@ -18,12 +16,5 @@ func ObjectPath(id string, obj interface{}) *string { } func path(i interface{}) string { - switch i.(type) { - case api.Dinosaur, *api.Dinosaur: - return "dinosaurs" - case errors.ServiceError, *errors.ServiceError: - return "errors" - default: - return "" - } + return fmt.Sprintf("%ss", util.ToSnakeCase(util.GetBaseType(i))) } diff --git a/pkg/api/presenters/suite_test.go b/pkg/api/presenters/suite_test.go new file mode 100644 index 00000000..4edf2278 --- /dev/null +++ b/pkg/api/presenters/suite_test.go @@ -0,0 +1,13 @@ +package presenters + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +func TestAccessProtection(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Presenters Suite") +} diff --git a/pkg/auth/actions.go b/pkg/auth/actions.go new file mode 100644 index 00000000..d725ffa0 --- /dev/null +++ b/pkg/auth/actions.go @@ -0,0 +1,9 @@ +package auth + +const ( + GetAction = "get" + ListAction = "list" + UpdateAction = "update" + CreateAction = "create" + DeleteAction = "delete" +) diff --git a/pkg/client/ocm/authorization.go b/pkg/client/ocm/authorization.go index bde9a06e..f65a8458 100644 --- a/pkg/client/ocm/authorization.go +++ b/pkg/client/ocm/authorization.go @@ -7,9 +7,11 @@ import ( azv1 "github.com/openshift-online/ocm-sdk-go/authorizations/v1" ) +//go:generate mockgen -source=authorization.go -package=ocm -destination=mock_authorization.go type OCMAuthorization interface { SelfAccessReview(ctx context.Context, action, resourceType, organizationID, subscriptionID, clusterID string) (allowed bool, err error) AccessReview(ctx context.Context, username, action, resourceType, organizationID, subscriptionID, clusterID string) (allowed bool, err error) + ResourceReview(ctx context.Context, username string, action string, resource string) (*azv1.ResourceReview, error) } type authorization service @@ -75,3 +77,18 @@ func (a authorization) AccessReview(ctx context.Context, username, action, resou return response.Allowed(), nil } + +func (a authorization) ResourceReview(ctx context.Context, username string, action string, resource string) (*azv1.ResourceReview, error) { + con := a.client.connection + resourceReviewClient := con.Authorizations().V1().ResourceReview() + + request, err := azv1.NewResourceReviewRequest().AccountUsername(username).Action(action).ResourceType(resource).Build() + if err != nil { + return nil, err + } + response, err := resourceReviewClient.Post().Request(request).SendContext(ctx) + if err != nil { + return nil, err + } + return response.Review(), nil +} diff --git a/pkg/client/ocm/authorization_mock.go b/pkg/client/ocm/authorization_mock.go index 56470ed1..c14320d2 100644 --- a/pkg/client/ocm/authorization_mock.go +++ b/pkg/client/ocm/authorization_mock.go @@ -2,6 +2,8 @@ package ocm import ( "context" + + azv1 "github.com/openshift-online/ocm-sdk-go/authorizations/v1" ) // authorizationMock returns allowed=true for every request @@ -16,3 +18,11 @@ func (a authorizationMock) SelfAccessReview(ctx context.Context, action, resourc func (a authorizationMock) AccessReview(ctx context.Context, username, action, resourceType, organizationID, subscriptionID, clusterID string) (allowed bool, err error) { return true, nil } + +func (a authorizationMock) ResourceReview(ctx context.Context, username string, action string, resource string) (*azv1.ResourceReview, error) { + response, err := azv1.NewResourceReview().AccountUsername(username).Action(action).ResourceType(resource).OrganizationIDs("*").Build() + if err != nil { + return nil, err + } + return response, nil +} diff --git a/pkg/client/ocm/mock_authorization.go b/pkg/client/ocm/mock_authorization.go new file mode 100644 index 00000000..94f5253b --- /dev/null +++ b/pkg/client/ocm/mock_authorization.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: authorization.go +// +// Generated by this command: +// +// mockgen -source=authorization.go -package=ocm -destination=mock_authorization.go +// +// Package ocm is a generated GoMock package. +package ocm + +import ( + context "context" + reflect "reflect" + + v1 "github.com/openshift-online/ocm-sdk-go/authorizations/v1" + gomock "go.uber.org/mock/gomock" +) + +// MockOCMAuthorization is a mock of OCMAuthorization interface. +type MockOCMAuthorization struct { + ctrl *gomock.Controller + recorder *MockOCMAuthorizationMockRecorder +} + +// MockOCMAuthorizationMockRecorder is the mock recorder for MockOCMAuthorization. +type MockOCMAuthorizationMockRecorder struct { + mock *MockOCMAuthorization +} + +// NewMockOCMAuthorization creates a new mock instance. +func NewMockOCMAuthorization(ctrl *gomock.Controller) *MockOCMAuthorization { + mock := &MockOCMAuthorization{ctrl: ctrl} + mock.recorder = &MockOCMAuthorizationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOCMAuthorization) EXPECT() *MockOCMAuthorizationMockRecorder { + return m.recorder +} + +// AccessReview mocks base method. +func (m *MockOCMAuthorization) AccessReview(ctx context.Context, username, action, resourceType, organizationID, subscriptionID, clusterID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessReview", ctx, username, action, resourceType, organizationID, subscriptionID, clusterID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AccessReview indicates an expected call of AccessReview. +func (mr *MockOCMAuthorizationMockRecorder) AccessReview(ctx, username, action, resourceType, organizationID, subscriptionID, clusterID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessReview", reflect.TypeOf((*MockOCMAuthorization)(nil).AccessReview), ctx, username, action, resourceType, organizationID, subscriptionID, clusterID) +} + +// ResourceReview mocks base method. +func (m *MockOCMAuthorization) ResourceReview(ctx context.Context, username, action, resource string) (*v1.ResourceReview, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResourceReview", ctx, username, action, resource) + ret0, _ := ret[0].(*v1.ResourceReview) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResourceReview indicates an expected call of ResourceReview. +func (mr *MockOCMAuthorizationMockRecorder) ResourceReview(ctx, username, action, resource any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceReview", reflect.TypeOf((*MockOCMAuthorization)(nil).ResourceReview), ctx, username, action, resource) +} + +// SelfAccessReview mocks base method. +func (m *MockOCMAuthorization) SelfAccessReview(ctx context.Context, action, resourceType, organizationID, subscriptionID, clusterID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SelfAccessReview", ctx, action, resourceType, organizationID, subscriptionID, clusterID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SelfAccessReview indicates an expected call of SelfAccessReview. +func (mr *MockOCMAuthorizationMockRecorder) SelfAccessReview(ctx, action, resourceType, organizationID, subscriptionID, clusterID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SelfAccessReview", reflect.TypeOf((*MockOCMAuthorization)(nil).SelfAccessReview), ctx, action, resourceType, organizationID, subscriptionID, clusterID) +} diff --git a/pkg/client/ocm/resource_types.go b/pkg/client/ocm/resource_types.go new file mode 100644 index 00000000..102dcd6e --- /dev/null +++ b/pkg/client/ocm/resource_types.go @@ -0,0 +1,9 @@ +package ocm + +type Resource string + +const ( + ClusterResource Resource = "Cluster" + SubscriptionResource Resource = "Subscription" + OrganizationResource Resource = "Organization" +) diff --git a/pkg/dao/dinosaur.go b/pkg/dao/dinosaur.go index e100c9dc..61e2741c 100644 --- a/pkg/dao/dinosaur.go +++ b/pkg/dao/dinosaur.go @@ -7,8 +7,28 @@ import ( "github.com/openshift-online/rh-trex/pkg/api" "github.com/openshift-online/rh-trex/pkg/db" + "github.com/openshift-online/rh-trex/pkg/util" ) +var ( + dinosaurTableName = util.ToSnakeCase(api.DinosaurTypeName) + "s" + dinosaurColumns = []string{ + "id", + "created_at", + "updated_at", + "species", + } +) + +func DinosaurApiToModel() TableMappingRelation { + result := map[string]string{} + applyBaseMapping(result, dinosaurColumns, dinosaurTableName) + return TableMappingRelation{ + Mapping: result, + relationTableName: dinosaurTableName, + } +} + type DinosaurDao interface { Get(ctx context.Context, id string) (*api.Dinosaur, error) Create(ctx context.Context, dinosaur *api.Dinosaur) (*api.Dinosaur, error) diff --git a/pkg/dao/generic.go b/pkg/dao/generic.go index c7680e50..3ef71835 100644 --- a/pkg/dao/generic.go +++ b/pkg/dao/generic.go @@ -2,6 +2,7 @@ package dao import ( "context" + "fmt" "strings" "github.com/jinzhu/inflection" @@ -10,6 +11,41 @@ import ( "github.com/openshift-online/rh-trex/pkg/db" ) +type TableMappingRelation struct { + Mapping map[string]string + relationTableName string +} + +type relationMapping func() TableMappingRelation + +func applyBaseMapping(result map[string]string, columns []string, tableName string) { + for _, c := range columns { + mappingKey := c + mappingValue := fmt.Sprintf("%s.%s", tableName, c) + columnParts := strings.Split(c, ".") + if len(columnParts) == 1 { + mappingKey = mappingValue + } + if len(columnParts) == 2 { + mappingValue = strings.Split(mappingKey, ".")[1] + } + result[mappingKey] = mappingValue + } +} + +func applyRelationMapping(result map[string]string, relations []relationMapping) { + for _, relation := range relations { + tableMappingRelation := relation() + for k, v := range tableMappingRelation.Mapping { + if _, ok := result[k]; ok { + result[tableMappingRelation.relationTableName+"."+k] = v + } else { + result[k] = v + } + } + } +} + type Where struct { sql string values []any diff --git a/pkg/dao/generic_test.go b/pkg/dao/generic_test.go new file mode 100644 index 00000000..59346b75 --- /dev/null +++ b/pkg/dao/generic_test.go @@ -0,0 +1,63 @@ +package dao + +import ( + "fmt" + "strings" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +var _ = Describe("applyBaseMapping", func() { + It("generates base mapping", func() { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "test_table") + for k, v := range result { + if strings.HasPrefix(k, "test_table") { + Expect(k).To(Equal(v)) + continue + } + // nested fields from table + i := strings.Index(k, ".") + Expect(k[i+1:]).To(Equal(v)) + } + }) +}) + +var _ = Describe("applyRelationMapping", func() { + It("generates relation mapping", func() { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "base_table") + applyRelationMapping(result, []relationMapping{ + func() TableMappingRelation { + result := map[string]string{} + applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "relation_table") + return TableMappingRelation{ + relationTableName: "relation_table", + Mapping: result, + } + }, + }) + for k, v := range result { + if strings.HasPrefix(k, "base_table") { + Expect(k).To(Equal(v)) + continue + } + if strings.HasPrefix(k, "relation_table") { + if c := strings.Count(k, "."); c > 1 { + i := strings.Index(k, ".") + i = strings.Index(k[i+1:], ".") + i + Expect(k[i+2:]).To(Equal(v)) + continue + } + Expect(k).To(Equal(v)) + continue + } + + // nested fields from base table + i := strings.Index(k, ".") + Expect(k[i+1:]).To(Equal(v)) + fmt.Println(k, v) + } + }) +}) diff --git a/pkg/dao/suite_test.go b/pkg/dao/suite_test.go new file mode 100644 index 00000000..523b0755 --- /dev/null +++ b/pkg/dao/suite_test.go @@ -0,0 +1,13 @@ +package dao + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +func TestAccessProtection(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Dao Suite") +} diff --git a/pkg/db/db_session/default.go b/pkg/db/db_session/default.go index 9e9fd9b0..40acc7d5 100644 --- a/pkg/db/db_session/default.go +++ b/pkg/db/db_session/default.go @@ -158,6 +158,9 @@ func (f *Default) CheckConnection() error { // THIS MUST **NOT** BE CALLED UNTIL THE SERVER/PROCESS IS EXITING!! // This should only ever be called once for the entire duration of the application and only at the end. func (f *Default) Close() error { + if f.db == nil { + return nil + } return f.db.Close() } diff --git a/pkg/db/db_session/test.go b/pkg/db/db_session/test.go index 57726953..5ed59155 100644 --- a/pkg/db/db_session/test.go +++ b/pkg/db/db_session/test.go @@ -209,6 +209,9 @@ func (f *Test) CheckConnection() error { } func (f *Test) Close() error { + if f.db == nil { + return nil + } return f.db.Close() } diff --git a/pkg/db/migrations/migrations_test.go b/pkg/db/migrations/migrations_test.go new file mode 100644 index 00000000..3537871a --- /dev/null +++ b/pkg/db/migrations/migrations_test.go @@ -0,0 +1,43 @@ +package migrations + +import ( + "fmt" + "os" + "strings" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMigration(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Migration Suite") +} + +var _ = Describe("Migrate", func() { + It("Expects same amount of files and migrationList", func() { + cwd, err := os.Getwd() + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + files, err := os.ReadDir(cwd) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + amountGoFiles := []string{} + for _, file := range files { + if !strings.Contains(file.Name(), ".go") { + continue + } + if strings.Contains(file.Name(), "migration_structs") || strings.Contains(file.Name(), "migrations_test") { + continue + } + amountGoFiles = append(amountGoFiles, file.Name()) + } + // Disconsiders migration_structs.go and test files + Expect(amountGoFiles).To(HaveLen(len(MigrationList))) + }) +}) diff --git a/pkg/db/sql_helpers.go b/pkg/db/sql_helpers.go index cc867489..ad0497cc 100644 --- a/pkg/db/sql_helpers.go +++ b/pkg/db/sql_helpers.go @@ -3,6 +3,7 @@ package db import ( "fmt" "reflect" + "slices" "strings" "github.com/jinzhu/inflection" @@ -11,6 +12,11 @@ import ( "gorm.io/gorm" ) +const ( + invalidFieldNameMsg = "%s is not a valid field name" + disallowedFieldNameMsg = "%s is a disallowed field name" +) + // Check if a field name starts with properties. func startsWithProperties(s string) bool { return strings.HasPrefix(s, "properties.") @@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool { } // getField gets the sql field associated with a name. -func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) { +func getField( + name string, + disallowedFields []string, + apiToModel map[string]string, +) (field string, err *errors.ServiceError) { // We want to accept names with trailing and leading spaces trimmedName := strings.Trim(name, " ") - // Check for properties ->> '' - if strings.HasPrefix(trimmedName, "properties ->>") { - field = trimmedName - return + mappedField, ok := apiToModel[trimmedName] + if !ok { + return "", errors.BadRequest(invalidFieldNameMsg, name) } // Check for nested field, e.g., subscription_labels.key - checkName := trimmedName - fieldParts := strings.Split(trimmedName, ".") + checkName := mappedField + fieldParts := strings.Split(checkName, ".") if len(fieldParts) > 2 { - err = errors.BadRequest("%s is not a valid field name", name) + err = errors.BadRequest(invalidFieldNameMsg, name) return } - if len(fieldParts) > 1 { - checkName = fieldParts[1] - } // Check for allowed fields - _, ok := disallowedFields[checkName] - if ok { - err = errors.BadRequest("%s is not a valid field name", name) + if slices.Contains(disallowedFields, checkName) { + err = errors.BadRequest(disallowedFieldNameMsg, name) return } - field = trimmedName + field = checkName return } @@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node { // b. replace the field name with the SQL column name. func FieldNameWalk( n tsl.Node, - disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) { + disallowedFields []string, + apiToModel map[string]string) (newNode tsl.Node, err *errors.ServiceError) { var field string var l, r tsl.Node @@ -124,7 +130,7 @@ func FieldNameWalk( } // Check field name in the disallowedFields field names. - field, err = getField(userFieldName, disallowedFields) + field, err = getField(userFieldName, disallowedFields, apiToModel) if err != nil { return } @@ -137,7 +143,7 @@ func FieldNameWalk( default: // o/w continue walking the tree. if n.Left != nil { - l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields) + l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields, apiToModel) if err != nil { return } @@ -148,7 +154,7 @@ func FieldNameWalk( switch v := n.Right.(type) { case tsl.Node: // It's a regular node, just add it. - r, err = FieldNameWalk(v, disallowedFields) + r, err = FieldNameWalk(v, disallowedFields, apiToModel) if err != nil { return } @@ -162,7 +168,7 @@ func FieldNameWalk( // Add all nodes in the right side array. for _, e := range v { - r, err = FieldNameWalk(e, disallowedFields) + r, err = FieldNameWalk(e, disallowedFields, apiToModel) if err != nil { return } @@ -189,7 +195,10 @@ func FieldNameWalk( } // cleanOrderBy takes the orderBy arg and cleans it. -func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) { +func cleanOrderBy(userArg string, + disallowedFields []string, + apiToModel map[string]string, + tableName string) (orderBy string, err *errors.ServiceError) { var orderField string // We want to accept user params with trailing and leading spaces @@ -197,15 +206,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s // Each OrderBy can be a "" or a " asc|desc" order := strings.Split(trimedName, " ") - direction := "none valid" - - if len(order) == 1 { - orderField, err = getField(order[0], disallowedFields) - direction = "asc" - } else if len(order) == 2 { - orderField, err = getField(order[0], disallowedFields) + direction := "asc" + if len(order) == 2 { direction = order[1] } + field := order[0] + if orderParts := strings.Split(order[0], "."); len(orderParts) == 1 { + field = fmt.Sprintf("%s.%s", tableName, field) + } + orderField, err = getField(field, disallowedFields, apiToModel) if err != nil || (direction != "asc" && direction != "desc") { err = errors.BadRequest("bad order value '%s'", userArg) return @@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s // ArgsToOrderBy returns cleaned orderBy list. func ArgsToOrderBy( orderByArgs []string, - disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) { + disallowedFields []string, + apiToModel map[string]string, + tableName string) (orderBy []string, err *errors.ServiceError) { var order string if len(orderByArgs) != 0 { orderBy = []string{} for _, o := range orderByArgs { - order, err = cleanOrderBy(o, disallowedFields) + order, err = cleanOrderBy(o, disallowedFields, apiToModel, tableName) if err != nil { return } diff --git a/pkg/handlers/dinosaur.go b/pkg/handlers/dinosaur.go index 552dc4b7..c6b1a08f 100644 --- a/pkg/handlers/dinosaur.go +++ b/pkg/handlers/dinosaur.go @@ -82,7 +82,7 @@ func (h dinosaurHandler) List(w http.ResponseWriter, r *http.Request) { listArgs := services.NewListArguments(r.URL.Query()) var dinosaurs = []api.Dinosaur{} - paging, err := h.generic.List(ctx, "username", listArgs, &dinosaurs) + paging, err := h.generic.List(ctx, listArgs, &dinosaurs) if err != nil { return nil, err } diff --git a/pkg/services/generic.go b/pkg/services/generic.go index b2d8bf55..1076bbdf 100644 --- a/pkg/services/generic.go +++ b/pkg/services/generic.go @@ -15,6 +15,8 @@ import ( sqlFilter "github.com/yaacov/tree-search-language/pkg/walkers/sql" "github.com/openshift-online/rh-trex/pkg/api" + "github.com/openshift-online/rh-trex/pkg/auth" + "github.com/openshift-online/rh-trex/pkg/client/ocm" "github.com/openshift-online/rh-trex/pkg/dao" "github.com/openshift-online/rh-trex/pkg/db" "github.com/openshift-online/rh-trex/pkg/errors" @@ -22,22 +24,33 @@ import ( ) type GenericService interface { - List(ctx context.Context, username string, args *ListArguments, resourceList interface{}) (*api.PagingMeta, *errors.ServiceError) + List(ctx context.Context, args *ListArguments, resourceList interface{}) (*api.PagingMeta, *errors.ServiceError) } -func NewGenericService(genericDao dao.GenericDao) GenericService { - return &sqlGenericService{genericDao: genericDao} +func NewGenericService(genericDao dao.GenericDao, ocmClient *ocm.Client) GenericService { + return &sqlGenericService{genericDao: genericDao, ocmClient: ocmClient} } var _ GenericService = &sqlGenericService{} type sqlGenericService struct { genericDao dao.GenericDao + ocmClient *ocm.Client } var ( - SearchDisallowedFields = map[string]map[string]string{} - allFieldsAllowed = map[string]string{} + searchDisallowedFields = map[string][]string{} + allFieldsAllowed = []string{} + // Some mappings are not required as they match AMS resource 1:1 + // Such as Organization + modelToAmsResource = map[string]string{} + + // TODO: This should be more dynamic + // prefarably utilizing the openapi json via reflect + // and the column names from the model + openapiToModelFields = map[string]dao.TableMappingRelation{ + api.DinosaurTypeName: dao.DinosaurApiToModel(), + } ) // wrap all needed pieces for the LIST funciton @@ -48,24 +61,31 @@ type listContext struct { pagingMeta *api.PagingMeta ulog *logger.OCMLogger resourceList interface{} - disallowedFields *map[string]string + disallowedFields []string + openapiToModel map[string]string resourceType string joins map[string]dao.TableRelation groupBy []string set map[string]bool } -func (s *sqlGenericService) newListContext(ctx context.Context, username string, args *ListArguments, resourceList interface{}) (*listContext, interface{}, *errors.ServiceError) { +func newListContext( + ctx context.Context, + args *ListArguments, + resourceList interface{}, +) (*listContext, interface{}, *errors.ServiceError) { + username := auth.GetUsernameFromContext(ctx) log := logger.NewOCMLogger(ctx) resourceModel := reflect.TypeOf(resourceList).Elem().Elem() resourceTypeStr := resourceModel.Name() if resourceTypeStr == "" { return nil, nil, errors.GeneralError("Could not determine resource type") } - disallowedFields := SearchDisallowedFields[resourceTypeStr] + disallowedFields := searchDisallowedFields[resourceTypeStr] if disallowedFields == nil { disallowedFields = allFieldsAllowed } + openapiToModel := openapiToModelFields[resourceTypeStr] args.Search = strings.Trim(args.Search, " ") return &listContext{ ctx: ctx, @@ -74,18 +94,75 @@ func (s *sqlGenericService) newListContext(ctx context.Context, username string, pagingMeta: &api.PagingMeta{Page: args.Page}, ulog: &log, resourceList: resourceList, - disallowedFields: &disallowedFields, + disallowedFields: disallowedFields, + openapiToModel: openapiToModel.Mapping, resourceType: resourceTypeStr, }, reflect.New(resourceModel).Interface(), nil } +func resourceIncludesOrgId(model interface{}) bool { + resourceModel := reflect.TypeOf(model).Elem() + _, found := resourceModel.FieldByName("OrganizationId") + return found +} + +func isAllowedToAllOrgs(allowedOrgs []string) bool { + return len(allowedOrgs) == 1 && allowedOrgs[0] == "*" +} + +func (s *sqlGenericService) populateSearchRestriction(listCtx *listContext, model any) *errors.ServiceError { + ctx := listCtx.ctx + resourceName := listCtx.resourceType + if name, ok := modelToAmsResource[resourceName]; ok { + resourceName = string(name) + } + if resourceIncludesOrgId(model) { + resourceReview, err := s.ocmClient.Authorization.ResourceReview( + ctx, + listCtx.username, + auth.GetAction, + resourceName, + ) + if err != nil { + return errors.GeneralError( + "Failed to verify resource review for user '%s' on resource '%s': %v", + listCtx.username, + listCtx.resourceType, + err, + ) + } + + // TODO setup a search builder + allowedOrgs := resourceReview.OrganizationIDs() + // If user doesn't have access to all orgs include search for allowed only + if !isAllowedToAllOrgs(allowedOrgs) { + if listCtx.args.Search != "" { + listCtx.args.Search += " and " + } + for i := range allowedOrgs { + allowedOrgs[i] = fmt.Sprintf("'%s'", allowedOrgs[i]) + } + listCtx.args.Search += fmt.Sprintf("organization_id in (%s)", strings.Join(allowedOrgs, ",")) + } + } + return nil +} + // resourceList must be a pointer to a slice of database resource objects -func (s *sqlGenericService) List(ctx context.Context, username string, args *ListArguments, resourceList interface{}) (*api.PagingMeta, *errors.ServiceError) { - listCtx, model, err := s.newListContext(ctx, username, args, resourceList) +func (s *sqlGenericService) List( + ctx context.Context, + args *ListArguments, + resourceList interface{}, +) (*api.PagingMeta, *errors.ServiceError) { + listCtx, model, err := newListContext(ctx, args, resourceList) if err != nil { return nil, err } + if err = s.populateSearchRestriction(listCtx, model); err != nil { + return nil, err + } + // the ordering for the sub functions matters. builders := []listBuilder{ // build SQL to load related resource. for now, it delegates to gorm.preload. @@ -100,7 +177,7 @@ func (s *sqlGenericService) List(ctx context.Context, username string, args *Lis // TODO: add any custom builder functions } - d := s.genericDao.GetInstanceDao(ctx, model) + d := s.genericDao.GetInstanceDao(listCtx.ctx, model) // run all the "builders". they cumulatively add constructs to gorm by the context. // it stops when a builder function raises error or signals finished. @@ -137,7 +214,8 @@ func (s *sqlGenericService) buildPreload(listCtx *listContext, d *dao.GenericDao func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao) (bool, *errors.ServiceError) { if len(listCtx.args.OrderBy) != 0 { - orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields) + orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, listCtx.disallowedFields, + listCtx.openapiToModel, (*d).GetTableName()) if serviceErr != nil { return false, serviceErr } @@ -148,7 +226,10 @@ func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao return false, nil } -func (s *sqlGenericService) buildSearchValues(listCtx *listContext, d *dao.GenericDao) (string, []any, *errors.ServiceError) { +func (s *sqlGenericService) buildSearchValues( + listCtx *listContext, + d *dao.GenericDao, +) (string, []any, *errors.ServiceError) { if listCtx.args.Search == "" { s.addJoins(listCtx, d) return "", nil, nil @@ -274,7 +355,11 @@ func zeroSlice(i interface{}, cap int64) *errors.ServiceError { // walk the TSL tree looking for fields like, e.g., creator.username, and then: // (1) look up the related table by its 1st part - creator // (2) replace it by table name - creator.username -> accounts.username -func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTree tsl.Node, genericDao *dao.GenericDao) (tsl.Node, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForRelatedTables( + listCtx *listContext, + tslTree tsl.Node, + genericDao *dao.GenericDao, +) (tsl.Node, *errors.ServiceError) { resourceTable := (*genericDao).GetTableName() if listCtx.joins == nil { listCtx.joins = map[string]dao.TableRelation{} @@ -282,17 +367,21 @@ func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTr walkFn := func(field string) (string, error) { fieldParts := strings.Split(field, ".") if len(fieldParts) > 1 && fieldParts[0] != resourceTable { - fieldName := fieldParts[0] - _, exists := listCtx.joins[fieldName] + nestedResource := fieldParts[0] + _, exists := listCtx.joins[nestedResource] if !exists { - if relation, ok := (*genericDao).GetTableRelation(fieldName); ok { - listCtx.joins[fieldName] = relation - } else { - return field, fmt.Errorf("%s is not a related resource of %s", fieldName, listCtx.resourceType) + // Populates relation if join exists + if relation, ok := (*genericDao).GetTableRelation(nestedResource); ok { + listCtx.joins[nestedResource] = relation + } else if _, ok := listCtx.openapiToModel[field]; !ok { + // If also not exposed as a nested resource consider this is an error + return field, fmt.Errorf("%s is not a related resource of %s", strings.Join(fieldParts, "."), listCtx.resourceType) } } - //replace by table name - fieldParts[0] = listCtx.joins[fieldName].ForeignTableName + // replace by table name if coming from join + if value, ok := listCtx.joins[nestedResource]; ok { + fieldParts[0] = value.ForeignTableName + } return strings.Join(fieldParts, "."), nil } return field, nil @@ -307,7 +396,11 @@ func (s *sqlGenericService) treeWalkForRelatedTables(listCtx *listContext, tslTr } // prepend table name to these "free" identifiers since they could cause "ambiguous" errors -func (s *sqlGenericService) treeWalkForAddingTableName(listCtx *listContext, tslTree tsl.Node, dao *dao.GenericDao) (tsl.Node, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForAddingTableName( + listCtx *listContext, + tslTree tsl.Node, + dao *dao.GenericDao, +) (tsl.Node, *errors.ServiceError) { resourceTable := (*dao).GetTableName() walkFn := func(field string) (string, error) { @@ -329,9 +422,12 @@ func (s *sqlGenericService) treeWalkForAddingTableName(listCtx *listContext, tsl return tslTree, nil } -func (s *sqlGenericService) treeWalkForSqlizer(listCtx *listContext, tslTree tsl.Node) (tsl.Node, squirrel.Sqlizer, *errors.ServiceError) { +func (s *sqlGenericService) treeWalkForSqlizer( + listCtx *listContext, + tslTree tsl.Node, +) (tsl.Node, squirrel.Sqlizer, *errors.ServiceError) { // Check field names in tree - tslTree, serviceErr := db.FieldNameWalk(tslTree, *listCtx.disallowedFields) + tslTree, serviceErr := db.FieldNameWalk(tslTree, listCtx.disallowedFields, listCtx.openapiToModel) if serviceErr != nil { return tslTree, nil, serviceErr } diff --git a/pkg/services/generic_test.go b/pkg/services/generic_test.go index 618c73a2..1ec6f02f 100644 --- a/pkg/services/generic_test.go +++ b/pkg/services/generic_test.go @@ -2,80 +2,202 @@ package services import ( "context" - "testing" + "net/url" + "reflect" + "github.com/openshift-online/rh-trex/pkg/auth" + "github.com/openshift-online/rh-trex/pkg/client/ocm" "github.com/openshift-online/rh-trex/pkg/dao" "github.com/openshift-online/rh-trex/pkg/db" + "go.uber.org/mock/gomock" "github.com/onsi/gomega/types" "github.com/yaacov/tree-search-language/pkg/tsl" + azv1 "github.com/openshift-online/ocm-sdk-go/authorizations/v1" "github.com/openshift-online/rh-trex/pkg/api" "github.com/openshift-online/rh-trex/pkg/config" "github.com/openshift-online/rh-trex/pkg/db/db_session" "github.com/openshift-online/rh-trex/pkg/errors" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -func TestSQLTranslation(t *testing.T) { - RegisterTestingT(t) - dbConfig := config.NewDatabaseConfig() - err := dbConfig.ReadFiles() - Expect(err).ToNot(HaveOccurred()) - var dbFactory db.SessionFactory = db_session.NewProdFactory(dbConfig) - defer dbFactory.Close() +type GenericTestDinosaur struct { + api.Meta + Species string + // This is to illustrate resource review in action + // It passes integration tests as it's mocked + // does not work for local envs pointing to integration AMS via proxy + OrganizationId string +} - g := dao.NewGenericDao(&dbFactory) - genericService := sqlGenericService{genericDao: g} +var _ = Describe("populates search restriction", func() { + var ctx context.Context + var ctrl *gomock.Controller + var genericService sqlGenericService + var genericDao dao.GenericDao + var authorizationMock *ocm.MockOCMAuthorization + var ocmClientMock *ocm.Client + username := "test-user" + BeforeEach(func() { + ctx = context.Background() + ctx = auth.SetUsernameContext(ctx, username) + ctrl = gomock.NewController(GinkgoT()) + dbConfig := config.NewDatabaseConfig() + err := dbConfig.ReadFiles() + Expect(err).ToNot(HaveOccurred()) + var dbFactory db.SessionFactory = db_session.NewTestFactory(dbConfig) + defer dbFactory.Close() - // ill-formatted search or disallowed fields should be rejected - tests := []map[string]interface{}{ - { - "search": "garbage", - "error": "rh-trex-21: Failed to parse search query: garbage", - }, - { - "search": "id in ('123')", - "error": "rh-trex-21: dinosaurs.id is not a valid field name", + authorizationMock = ocm.NewMockOCMAuthorization(ctrl) + ocmClientMock = &ocm.Client{ + Authorization: authorizationMock, + } + genericDao = dao.NewGenericDao(&dbFactory) + genericService = sqlGenericService{ + genericDao: genericDao, + ocmClient: ocmClientMock, + } + }) + Context("Resource includes organization ID field", func() { + When("Auth allows all orgs", func() { + It("Allows all orgs", func() { + args := NewListArguments(url.Values{}) + listCtx, model, serviceErr := newListContext(ctx, args, &[]GenericTestDinosaur{}) + resourceModel := reflect.TypeOf(&GenericTestDinosaur{}).Elem() + Expect(model).To(Equal(reflect.New(resourceModel).Interface())) + Expect(serviceErr).ToNot(HaveOccurred()) + response, err := azv1.NewResourceReview(). + AccountUsername(listCtx.username). + Action(auth.GetAction). + ResourceType("GenericTestDinosaur"). + OrganizationIDs("*"). + Build() + Expect(err).ToNot(HaveOccurred()) + authorizationMock.EXPECT(). + ResourceReview(listCtx.ctx, listCtx.username, auth.GetAction, "GenericTestDinosaur"). + Return(response, nil) + serviceErr = genericService.populateSearchRestriction(listCtx, model) + Expect(serviceErr).ToNot(HaveOccurred()) + Expect(listCtx.args.Search).To(BeEmpty()) + }) + }) + When("Auth restricts orgs", func() { + It("Allows only returned orgs", func() { + args := NewListArguments(url.Values{}) + listCtx, model, serviceErr := newListContext(ctx, args, &[]GenericTestDinosaur{}) + resourceModel := reflect.TypeOf(&GenericTestDinosaur{}).Elem() + Expect(model).To(Equal(reflect.New(resourceModel).Interface())) + Expect(serviceErr).ToNot(HaveOccurred()) + response, err := azv1.NewResourceReview(). + AccountUsername(listCtx.username). + Action(auth.GetAction). + ResourceType("GenericTestDinosaur"). + OrganizationIDs("123", "124"). + Build() + Expect(err).ToNot(HaveOccurred()) + authorizationMock.EXPECT(). + ResourceReview(listCtx.ctx, listCtx.username, auth.GetAction, "GenericTestDinosaur"). + Return(response, nil) + serviceErr = genericService.populateSearchRestriction(listCtx, model) + Expect(serviceErr).ToNot(HaveOccurred()) + Expect(listCtx.args.Search).ToNot(BeEmpty()) + Expect(listCtx.args.Search).To(Equal("organization_id in ('123','124')")) + }) + It("Includes pre existing search", func() { + args := NewListArguments(url.Values{}) + args.Search = "justification like '%test%'" + listCtx, model, serviceErr := newListContext(ctx, args, &[]GenericTestDinosaur{}) + resourceModel := reflect.TypeOf(&GenericTestDinosaur{}).Elem() + Expect(model).To(Equal(reflect.New(resourceModel).Interface())) + Expect(serviceErr).ToNot(HaveOccurred()) + response, err := azv1.NewResourceReview(). + AccountUsername(listCtx.username). + Action(auth.GetAction). + ResourceType("GenericTestDinosaur"). + OrganizationIDs("123", "124"). + Build() + Expect(err).ToNot(HaveOccurred()) + authorizationMock.EXPECT(). + ResourceReview(listCtx.ctx, listCtx.username, auth.GetAction, "GenericTestDinosaur"). + Return(response, nil) + serviceErr = genericService.populateSearchRestriction(listCtx, model) + Expect(serviceErr).ToNot(HaveOccurred()) + Expect(listCtx.args.Search).ToNot(BeEmpty()) + Expect( + listCtx.args.Search, + ).To(Equal("justification like '%test%' and organization_id in ('123','124')")) + }) + }) + }) +}) + +var _ = Describe("Sql Translation", func() { + var genericService sqlGenericService + var genericDao dao.GenericDao + BeforeEach(func() { + dbConfig := config.NewDatabaseConfig() + err := dbConfig.ReadFiles() + Expect(err).ToNot(HaveOccurred()) + var dbFactory db.SessionFactory = db_session.NewTestFactory(dbConfig) + defer dbFactory.Close() + + genericDao = dao.NewGenericDao(&dbFactory) + genericService = sqlGenericService{genericDao: genericDao} + }) + DescribeTable( + "Errors", + func( + search string, errorMsg string) { + listCtx, model, serviceErr := newListContext( + context.Background(), + &ListArguments{Search: search}, + &[]api.Dinosaur{}, + ) + Expect(serviceErr).ToNot(HaveOccurred()) + d := genericDao.GetInstanceDao(context.Background(), model) + listCtx.disallowedFields = []string{"dinosaurs.id"} + _, serviceErr = genericService.buildSearch(listCtx, &d) + Expect(serviceErr).To(HaveOccurred()) + Expect(serviceErr.Code).To(Equal(errors.ErrorBadRequest)) + Expect(serviceErr.Error()).To(Equal(errorMsg)) }, - } - for _, test := range tests { - list := []api.Dinosaur{} - search := test["search"].(string) - errorMsg := test["error"].(string) - listCtx, model, serviceErr := genericService.newListContext(context.Background(), "", &ListArguments{Search: search}, &list) - Expect(serviceErr).ToNot(HaveOccurred()) - d := g.GetInstanceDao(context.Background(), model) - (*listCtx.disallowedFields)["id"] = "id" - _, serviceErr = genericService.buildSearch(listCtx, &d) - Expect(serviceErr).To(HaveOccurred()) - Expect(serviceErr.Code).To(Equal(errors.ErrorBadRequest)) - Expect(serviceErr.Error()).To(Equal(errorMsg)) - } + Entry("Garbage", "garbage", "rh-trex-21: Failed to parse search query: garbage"), + Entry("Disallowed field name", "id in ('123')", "rh-trex-21: dinosaurs.id is a disallowed field name"), + Entry("Unknown field name", "bike = '123'", "rh-trex-21: dinosaurs.bike is not a valid field name"), + Entry( + "Unknown relation field", + "status.bike = '123'", + "rh-trex-21: status.bike is not a related resource of Dinosaur", + ), + ) - // tests for sql parsing - tests = []map[string]interface{}{ - { - "search": "username in ('ooo.openshift')", - "sql": "username IN (?)", - "values": ConsistOf("ooo.openshift"), + DescribeTable( + "Sql Parsing", + func( + search string, sqlReal string, valuesReal types.GomegaMatcher) { + listCtx, _, serviceErr := newListContext( + context.Background(), + &ListArguments{Search: search}, + &[]api.Dinosaur{}, + ) + Expect(serviceErr).ToNot(HaveOccurred()) + tslTree, err := tsl.ParseTSL(search) + Expect(err).ToNot(HaveOccurred()) + _, sqlizer, serviceErr := genericService.treeWalkForSqlizer(listCtx, tslTree) + Expect(serviceErr).ToNot(HaveOccurred()) + sql, values, err := sqlizer.ToSql() + Expect(err).ToNot(HaveOccurred()) + Expect(sql).To(Equal(sqlReal)) + Expect(values).To(valuesReal) }, - } - for _, test := range tests { - list := []api.Dinosaur{} - search := test["search"].(string) - sqlReal := test["sql"].(string) - valuesReal := test["values"].(types.GomegaMatcher) - listCtx, _, serviceErr := genericService.newListContext(context.Background(), "", &ListArguments{Search: search}, &list) - Expect(serviceErr).ToNot(HaveOccurred()) - tslTree, err := tsl.ParseTSL(search) - Expect(err).ToNot(HaveOccurred()) - _, sqlizer, serviceErr := genericService.treeWalkForSqlizer(listCtx, tslTree) - Expect(serviceErr).ToNot(HaveOccurred()) - sql, values, err := sqlizer.ToSql() - Expect(err).ToNot(HaveOccurred()) - Expect(sql).To(Equal(sqlReal)) - Expect(values).To(valuesReal) - } -} + Entry( + "Valid search", + "dinosaurs.species like '%test%'", + "dinosaurs.species LIKE ?", + ConsistOf("%test%"), + ), + ) +}) diff --git a/pkg/services/suite_test.go b/pkg/services/suite_test.go new file mode 100644 index 00000000..ce99c236 --- /dev/null +++ b/pkg/services/suite_test.go @@ -0,0 +1,13 @@ +package services + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +func TestServices(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Services Suite") +} diff --git a/pkg/util/strings.go b/pkg/util/strings.go new file mode 100644 index 00000000..ee0597ca --- /dev/null +++ b/pkg/util/strings.go @@ -0,0 +1,16 @@ +package util + +import ( + "regexp" + "strings" +) + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func ToSnakeCase(str string) string { + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + snake = strings.ReplaceAll(snake, " ", "") + return strings.ToLower(snake) +} diff --git a/pkg/util/strings_test.go b/pkg/util/strings_test.go new file mode 100644 index 00000000..a869ab0b --- /dev/null +++ b/pkg/util/strings_test.go @@ -0,0 +1,17 @@ +package util + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Strings util", func() { + Describe("ToSnakeCase", func() { + It("transforms", func() { + Expect(ToSnakeCase("asd")).To(Equal("asd")) + Expect(ToSnakeCase("AsdAsd")).To(Equal("asd_asd")) + Expect(ToSnakeCase("asdAsd")).To(Equal("asd_asd")) + Expect(ToSnakeCase("Asd Asd")).To(Equal("asd_asd")) + }) + }) +}) diff --git a/pkg/util/suite_test.go b/pkg/util/suite_test.go new file mode 100644 index 00000000..d164965a --- /dev/null +++ b/pkg/util/suite_test.go @@ -0,0 +1,13 @@ +package util + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +func TestAccessProtection(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils Suite") +} diff --git a/pkg/util/test/types.go b/pkg/util/test/types.go new file mode 100644 index 00000000..f433ee06 --- /dev/null +++ b/pkg/util/test/types.go @@ -0,0 +1,3 @@ +package test + +type TestSomeType struct{} diff --git a/pkg/util/types.go b/pkg/util/types.go new file mode 100644 index 00000000..1ee03bf9 --- /dev/null +++ b/pkg/util/types.go @@ -0,0 +1,24 @@ +package util + +import ( + "reflect" + "strings" +) + +// Removes package from result +// Removes pointer from result +func GetBaseType(myvar any) string { + result := GetType(myvar) + if i := strings.Index(result, "."); i != -1 { + result = result[i+1:] + } + if strings.HasPrefix("*", result) { + result = result[1:] + } + return result +} + +func GetType(myvar any) string { + result := reflect.TypeOf(myvar).String() + return result +} diff --git a/pkg/util/types_test.go b/pkg/util/types_test.go new file mode 100644 index 00000000..b066cd21 --- /dev/null +++ b/pkg/util/types_test.go @@ -0,0 +1,29 @@ +package util + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/openshift-online/rh-trex/pkg/util/test" +) + +type SomeType struct{} + +var _ = Describe("Strings util", func() { + Describe("GetBaseType", func() { + It("computes", func() { + Expect(GetBaseType("asd")).To(Equal("string")) + Expect(GetBaseType(SomeType{})).To(Equal("SomeType")) + Expect(GetBaseType(&SomeType{})).To(Equal("SomeType")) + Expect(GetBaseType(test.TestSomeType{})).To(Equal("TestSomeType")) + }) + }) + + Describe("GetType", func() { + It("computes", func() { + Expect(GetType("asd")).To(Equal("string")) + Expect(GetType(SomeType{})).To(Equal("util.SomeType")) + Expect(GetType(&SomeType{})).To(Equal("*util.SomeType")) + Expect(GetType(test.TestSomeType{})).To(Equal("test.TestSomeType")) + }) + }) +}) diff --git a/templates/generate-handlers.txt b/templates/generate-handlers.txt index 441c17ad..19148ae6 100755 --- a/templates/generate-handlers.txt +++ b/templates/generate-handlers.txt @@ -83,7 +83,7 @@ func (h {{.KindLowerSingular}}Handler) List(w http.ResponseWriter, r *http.Reque listArgs := services.NewListArguments(r.URL.Query()) var {{.KindLowerPlural}} = []api.{{.Kind}}{} - paging, err := h.generic.List(ctx, "username", listArgs, &{{.KindLowerPlural}}) + paging, err := h.generic.List(ctx, listArgs, &{{.KindLowerPlural}}) if err != nil { return nil, err } diff --git a/test/mocks/mocks.go b/test/mocks/mocks.go deleted file mode 100644 index c9dcdc32..00000000 --- a/test/mocks/mocks.go +++ /dev/null @@ -1,19 +0,0 @@ -package mocks - -import ( - "net/http" - "net/http/httptest" - "time" -) - -// Returns a server that will wait waitTime when hit at endpoint -func NewMockServerTimeout(endpoint string, waitTime time.Duration) (*httptest.Server, func()) { - apiHandler := http.NewServeMux() - apiHandler.HandleFunc(endpoint, - func(w http.ResponseWriter, r *http.Request) { - time.Sleep(waitTime) - }, - ) - server := httptest.NewServer(apiHandler) - return server, server.Close -} diff --git a/test/mocks/ocm.go b/test/mocks/ocm.go deleted file mode 100644 index 3b405416..00000000 --- a/test/mocks/ocm.go +++ /dev/null @@ -1,65 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/openshift-online/rh-trex/pkg/client/ocm" -) - -/* -The OCM Validator Mock will simply return true to all access_review requests instead -of reaching out to the AMS system or using the built-in OCM mock. It will record -the action and resourceType sent to it in the struct itself. This can be used -to validate that the expected action/resourceType for a particular endpoint was -determined in the authorization middleware - -Use: - h, client := test.RegisterIntegration(t) - authzMock, ocmMock := mocks.NewOCMAuthzValidatorMockClient() - // Use the OCM client mock, re-load services so they pick up the mock - h.Env().Clients.OCM = ocmMock - // The built-in mock has to be disabled or the server will use it instead - h.Env().Config.OCM.EnableMock = false - // Services and the server should be re-loaded to pick up the client with this mock - h.Env().LoadServices() - h.RestartServer() - - // Make a request, then validate the action and resourceType - Expect(authzMock.Action).To(Equal("get")) - Expect(authzMock.ResourceType).To(Equal("JQSJobQueue")) - authzMock.Reset() -*/ - -var _ ocm.OCMAuthorization = &OCMAuthzValidatorMock{} - -type OCMAuthzValidatorMock struct { - Action string - ResourceType string -} - -func NewOCMAuthzValidatorMockClient() (*OCMAuthzValidatorMock, *ocm.Client) { - authz := &OCMAuthzValidatorMock{ - Action: "", - ResourceType: "", - } - client := &ocm.Client{} - client.Authorization = authz - return authz, client -} - -func (m *OCMAuthzValidatorMock) SelfAccessReview(ctx context.Context, action, resourceType, organizationID, subscriptionID, clusterID string) (allowed bool, err error) { - m.Action = action - m.ResourceType = resourceType - return true, nil -} - -func (m *OCMAuthzValidatorMock) AccessReview(ctx context.Context, username, action, resourceType, organizationID, subscriptionID, clusterID string) (allowed bool, err error) { - m.Action = action - m.ResourceType = resourceType - return true, nil -} - -func (m OCMAuthzValidatorMock) Reset() { - m.Action = "" - m.ResourceType = "" -}