diff --git a/pkg/controller/postgres/postgres_controller.go b/pkg/controller/postgres/postgres_controller.go index 8fda12a5..1d19dceb 100644 --- a/pkg/controller/postgres/postgres_controller.go +++ b/pkg/controller/postgres/postgres_controller.go @@ -220,17 +220,12 @@ func (r *ReconcilePostgres) Reconcile(request reconcile.Request) (_ reconcile.Re } // Set privileges on schema - err = r.pg.SetSchemaPrivileges(database, owner, reader, schema, readerPrivs, reqLogger) + err = r.pg.SetSchemaPrivileges(database, owner, reader, schema, readerPrivs, false, reqLogger) if err != nil { reqLogger.Error(err, fmt.Sprintf("Could not give %s permissions \"%s\"", reader, readerPrivs)) continue } - err = r.pg.SetSchemaPrivileges(database, owner, writer, schema, writerPrivs, reqLogger) - if err != nil { - reqLogger.Error(err, fmt.Sprintf("Could not give %s permissions \"%s\"", writer, writerPrivs)) - continue - } - err = r.pg.SetSchemaPrivilegesCreate(database, owner, writer, schema, writerPrivs, reqLogger) + err = r.pg.SetSchemaPrivileges(database, owner, writer, schema, writerPrivs, true, reqLogger) if err != nil { reqLogger.Error(err, fmt.Sprintf("Could not give %s permissions \"%s\"", writer, writerPrivs)) continue diff --git a/pkg/controller/postgres/postgres_controller_test.go b/pkg/controller/postgres/postgres_controller_test.go index 685aae41..97bb74a9 100644 --- a/pkg/controller/postgres/postgres_controller_test.go +++ b/pkg/controller/postgres/postgres_controller_test.go @@ -682,12 +682,10 @@ var _ = Describe("ReconcilePostgres", func() { // Expected method calls // customers schema pg.EXPECT().CreateSchema(name, name+"-group", "customers", gomock.Any()).Return(nil).Times(1) - pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any()).Return(nil).Times(2) - pg.EXPECT().SetSchemaPrivilegesCreate(name, name+"-group", name+"-writer", "customers", gomock.Any(), gomock.Any()).Return(nil).Times(1) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(2) // stores schema pg.EXPECT().CreateSchema(name, name+"-group", "stores", gomock.Any()).Return(nil).Times(1) - pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "stores", gomock.Any(), gomock.Any()).Return(nil).Times(2) - pg.EXPECT().SetSchemaPrivilegesCreate(name, name+"-group", name+"-writer", "stores", gomock.Any(), gomock.Any()).Return(nil).Times(1) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "stores", gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(2) }) It("should update status", func() { @@ -709,12 +707,11 @@ var _ = Describe("ReconcilePostgres", func() { // Expected method calls // customers schema errors pg.EXPECT().CreateSchema(name, name+"-group", "customers", gomock.Any()).Return(fmt.Errorf("Could not create schema")).Times(1) - pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any()).Return(nil).Times(0) - pg.EXPECT().SetSchemaPrivilegesCreate(name, name+"-group", name+"-writer", "customers", gomock.Any(), gomock.Any()).Return(nil).Times(0) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any() ,gomock.Any()).Return(nil).Times(0) // stores schema pg.EXPECT().CreateSchema(name, name+"-group", "stores", gomock.Any()).Return(nil).Times(1) - pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "stores", gomock.Any(), gomock.Any()).Return(nil).Times(2) - pg.EXPECT().SetSchemaPrivilegesCreate(name, name+"-group", name+"-writer", "stores", gomock.Any(), gomock.Any()).Return(nil).Times(1) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "stores", gomock.Any(), false, gomock.Any()).Return(nil).Times(1) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "stores", gomock.Any(), true, gomock.Any()).Return(nil).Times(1) }) It("should update status", func() { @@ -755,8 +752,7 @@ var _ = Describe("ReconcilePostgres", func() { // Expected method calls // customers schema pg.EXPECT().CreateSchema(name, name+"-group", "customers", gomock.Any()).Return(nil).Times(1) - pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any()).Return(nil).Times(2) - pg.EXPECT().SetSchemaPrivilegesCreate(name, name+"-group", name+"-writer", "customers", gomock.Any(), gomock.Any()).Return(nil).Times(1) + pg.EXPECT().SetSchemaPrivileges(name, name+"-group", gomock.Any(), "customers", gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(2) // stores schema already exists pg.EXPECT().CreateSchema(name, name+"-group", "stores", gomock.Any()).Times(0) // Call reconcile diff --git a/pkg/postgres/database.go b/pkg/postgres/database.go index da198605..4fc95bef 100644 --- a/pkg/postgres/database.go +++ b/pkg/postgres/database.go @@ -95,7 +95,7 @@ func (c *pg) CreateExtension(db, extension string, logger logr.Logger) error { return nil } -func (c *pg) SetSchemaPrivileges(db, creator, role, schema, privs string, logger logr.Logger) error { +func (c *pg) SetSchemaPrivileges(db, creator, role, schema, privs string, createSchema bool, logger logr.Logger) error { tmpDb, err := GetConnection(c.user, c.pass, c.host, db, c.args, logger) if err != nil { return err @@ -119,20 +119,14 @@ func (c *pg) SetSchemaPrivileges(db, creator, role, schema, privs string, logger if err != nil { return err } - return nil -} -func (c *pg) SetSchemaPrivilegesCreate(db, creator, role, schema, privs string, logger logr.Logger) error { - tmpDb, err := GetConnection(c.user, c.pass, c.host, db, c.args, logger) - if err != nil { - return err + // Grant role usage on schema if createSchema + if createSchema { + _, err = tmpDb.Exec(fmt.Sprintf(GRANT_CREATE_TABLE, schema, role)) + if err != nil { + return err + } } - defer tmpDb.Close() - // Grant role usage on schema - _, err = tmpDb.Exec(fmt.Sprintf(GRANT_CREATE_TABLE, schema, role)) - if err != nil { - return err - } return nil } diff --git a/pkg/postgres/mock/postgres.go b/pkg/postgres/mock/postgres.go index 47386eab..ce58f671 100644 --- a/pkg/postgres/mock/postgres.go +++ b/pkg/postgres/mock/postgres.go @@ -133,31 +133,17 @@ func (mr *MockPGMockRecorder) GrantRole(role, grantee interface{}) *gomock.Call } // SetSchemaPrivileges mocks base method -func (m *MockPG) SetSchemaPrivileges(db, creator, role, schema, privs string, logger logr.Logger) error { +func (m *MockPG) SetSchemaPrivileges(db, creator, role, schema, privs string, createSchema bool, logger logr.Logger) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetSchemaPrivileges", db, creator, role, schema, privs, logger) + ret := m.ctrl.Call(m, "SetSchemaPrivileges", db, creator, role, schema, privs, createSchema, logger) ret0, _ := ret[0].(error) return ret0 } // SetSchemaPrivileges indicates an expected call of SetSchemaPrivileges -func (mr *MockPGMockRecorder) SetSchemaPrivileges(db, creator, role, schema, privs, logger interface{}) *gomock.Call { +func (mr *MockPGMockRecorder) SetSchemaPrivileges(db, creator, role, schema, privs, createSchema, logger interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSchemaPrivileges", reflect.TypeOf((*MockPG)(nil).SetSchemaPrivileges), db, creator, role, schema, privs, logger) -} - -// SetSchemaPrivilegesCreate mocks base method -func (m *MockPG) SetSchemaPrivilegesCreate(db, creator, role, schema, privs string, logger logr.Logger) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetSchemaPrivilegesCreate", db, creator, role, schema, privs, logger) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetSchemaPrivilegesCreate indicates an expected call of SetSchemaPrivilegesCreate -func (mr *MockPGMockRecorder) SetSchemaPrivilegesCreate(db, creator, role, schema, privs, logger interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSchemaPrivilegesCreate", reflect.TypeOf((*MockPG)(nil).SetSchemaPrivilegesCreate), db, creator, role, schema, privs, logger) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSchemaPrivileges", reflect.TypeOf((*MockPG)(nil).SetSchemaPrivileges), db, creator, role, schema, privs, createSchema, logger) } // RevokeRole mocks base method diff --git a/pkg/postgres/postgres.go b/pkg/postgres/postgres.go index f8fa0f88..f596dd46 100644 --- a/pkg/postgres/postgres.go +++ b/pkg/postgres/postgres.go @@ -16,8 +16,7 @@ type PG interface { CreateUserRole(role, password string) (string, error) UpdatePassword(role, password string) error GrantRole(role, grantee string) error - SetSchemaPrivileges(db, creator, role, schema, privs string, logger logr.Logger) error - SetSchemaPrivilegesCreate(db, creator, role, schema, privs string, logger logr.Logger) error + SetSchemaPrivileges(db, creator, role, schema, privs string, createSchema bool, logger logr.Logger) error RevokeRole(role, revoked string) error AlterDefaultLoginRole(role, setRole string) error DropDatabase(db string, logger logr.Logger) error