Skip to content

Commit

Permalink
RBAC custom privilege group ut coverage
Browse files Browse the repository at this point in the history
Signed-off-by: shaoting-huang <[email protected]>
  • Loading branch information
shaoting-huang committed Nov 9, 2024
1 parent 70605cf commit 7b17f35
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 13 deletions.
18 changes: 10 additions & 8 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
router.POST(RoleCategory+RevokePrivilegeAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.removePrivilegeFromRole))))

// privilege group
router.POST(PrivilegeGroupCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.createPrivilegeGroup))))
router.POST(PrivilegeGroupCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.dropPrivilegeGroup))))
router.POST(PrivilegeGroupCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.listPrivilegeGroups))))
router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.addPrivilegesToGroup))))
router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup))))
router.POST(PrivilegeGroupCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.createPrivilegeGroup))))
router.POST(PrivilegeGroupCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.dropPrivilegeGroup))))
router.POST(PrivilegeGroupCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.listPrivilegeGroups))))
router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.addPrivilegesToGroup))))
router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup))))

router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes)))))
router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex)))))
Expand Down Expand Up @@ -1857,9 +1857,11 @@ func (h *HandlersV2) removePrivilegesFromGroup(ctx context.Context, c *gin.Conte
func (h *HandlersV2) operatePrivilegeGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string, operateType milvuspb.OperatePrivilegeGroupType) (interface{}, error) {
httpReq := anyReq.(*PrivilegeGroupReq)
req := &milvuspb.OperatePrivilegeGroupRequest{
GroupName: httpReq.PrivilegeGroupName,
Privileges: httpReq.Privileges,
Type: operateType,
GroupName: httpReq.PrivilegeGroupName,
Privileges: lo.Map(httpReq.Privileges, func(p string, _ int) *milvuspb.PrivilegeEntity {
return &milvuspb.PrivilegeEntity{Name: p}
}),
Type: operateType,
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperatePrivilegeGroup", func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.OperatePrivilegeGroup(reqCtx, req.(*milvuspb.OperatePrivilegeGroupRequest))
Expand Down
28 changes: 26 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,10 @@ func TestMethodGet(t *testing.T) {
Status: &StatusSuccess,
Alias: DefaultAliasName,
}, nil).Once()
mp.EXPECT().ListPrivilegeGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListPrivilegeGroupsResponse{
Status: &StatusSuccess,
PrivilegeGroups: []*milvuspb.PrivilegeGroupInfo{{GroupName: "group1", Privileges: []*milvuspb.PrivilegeEntity{{Name: "*"}}}},
}, nil).Once()

testEngine := initHTTPServerV2(mp, false)
queryTestCases := []rawTestCase{}
Expand Down Expand Up @@ -1320,6 +1324,9 @@ func TestMethodGet(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(AliasCategory, DescribeAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PrivilegeGroupCategory, ListAction),
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
Expand All @@ -1329,7 +1336,8 @@ func TestMethodGet(t *testing.T) {
`"indexName": "` + DefaultIndexName + `",` +
`"userName": "` + util.UserRoot + `",` +
`"roleName": "` + util.RoleAdmin + `",` +
`"aliasName": "` + DefaultAliasName + `"` +
`"aliasName": "` + DefaultAliasName + `",` +
`"privilegeGroupName": "pg"` +
`}`))
req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader)
w := httptest.NewRecorder()
Expand Down Expand Up @@ -1369,6 +1377,7 @@ func TestMethodDelete(t *testing.T) {
mp.EXPECT().DropRole(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().DropPrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []rawTestCase{}
queryTestCases = append(queryTestCases, rawTestCase{
Expand All @@ -1389,10 +1398,13 @@ func TestMethodDelete(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(AliasCategory, DropAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PrivilegeGroupCategory, DropAction),
})
for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "partitionName": "` + DefaultPartitionName +
`", "userName": "` + util.UserRoot + `", "roleName": "` + util.RoleAdmin + `", "indexName": "` + DefaultIndexName + `", "aliasName": "` + DefaultAliasName + `"}`))
`", "userName": "` + util.UserRoot + `", "roleName": "` + util.RoleAdmin + `", "indexName": "` + DefaultIndexName + `", "aliasName": "` + DefaultAliasName + `", "privilegeGroupName": "pg"}`))
req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader)
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
Expand Down Expand Up @@ -1431,6 +1443,8 @@ func TestMethodPost(t *testing.T) {
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
mp.EXPECT().CreateAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().AlterAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().CreatePrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once()
mp.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
mp.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(&internalpb.ImportResponse{
Status: commonSuccessStatus, JobID: "1234567890",
}, nil).Once()
Expand Down Expand Up @@ -1523,6 +1537,15 @@ func TestMethodPost(t *testing.T) {
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(ImportJobCategory, DescribeAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PrivilegeGroupCategory, CreateAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PrivilegeGroupCategory, AddPrivilegesToGroupAction),
})
queryTestCases = append(queryTestCases, rawTestCase{
path: versionalV2(PrivilegeGroupCategory, RemovePrivilegesFromGroupAction),
})

for _, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
Expand All @@ -1533,6 +1556,7 @@ func TestMethodPost(t *testing.T) {
`"indexParams": [{"indexName": "` + DefaultIndexName + `", "fieldName": "book_intro", "metricType": "L2", "params": {"nlist": 30, "index_type": "IVF_FLAT"}}],` +
`"userName": "` + util.UserRoot + `", "password": "Milvus", "newPassword": "milvus", "roleName": "` + util.RoleAdmin + `",` +
`"roleName": "` + util.RoleAdmin + `", "objectType": "Global", "objectName": "*", "privilege": "*",` +
`"privilegeGroupName": "pg", "privileges": ["create", "drop"],` +
`"aliasName": "` + DefaultAliasName + `",` +
`"jobId": "1234567890",` +
`"files": [["book.json"]]` +
Expand Down
5 changes: 2 additions & 3 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/gin-gonic/gin"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
)

Expand Down Expand Up @@ -275,8 +274,8 @@ func (req *RoleReq) GetRoleName() string {
}

type PrivilegeGroupReq struct {
PrivilegeGroupName string `json:"privilegeGroupName" binding:"required"`
Privileges []*milvuspb.PrivilegeEntity `json:"privileges"`
PrivilegeGroupName string `json:"privilegeGroupName" binding:"required"`
Privileges []string `json:"privileges"`
}

type GrantReq struct {
Expand Down
12 changes: 12 additions & 0 deletions internal/rootcoord/meta_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2093,10 +2093,22 @@ func TestMetaTable_PrivilegeGroup(t *testing.T) {
}
err := mt.CreatePrivilegeGroup("pg1")
assert.Error(t, err)
err = mt.CreatePrivilegeGroup("")
assert.Error(t, err)
err = mt.CreatePrivilegeGroup("Insert")
assert.Error(t, err)
err = mt.CreatePrivilegeGroup("pg2")
assert.NoError(t, err)
err = mt.DropPrivilegeGroup("")
assert.Error(t, err)
err = mt.DropPrivilegeGroup("pg1")
assert.NoError(t, err)
err = mt.OperatePrivilegeGroup("", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup)
assert.Error(t, err)
err = mt.OperatePrivilegeGroup("pg3", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup)
assert.Error(t, err)
_, err = mt.GetPrivilegeGroupRoles("")
assert.Error(t, err)
_, err = mt.ListPrivilegeGroups()
assert.NoError(t, err)
}
30 changes: 30 additions & 0 deletions internal/rootcoord/root_coord_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,36 @@ func TestRootCoord_RBACError(t *testing.T) {
}
})

t.Run("operate privilege group failed", func(t *testing.T) {
{
resp, err := c.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
}
mockMeta := c.meta.(*mockMetaTable)
mockMeta.ListPrivilegeGroupsFunc = func() ([]*milvuspb.PrivilegeGroupInfo, error) {
return nil, errors.New("mock error")
}
mockMeta.CreatePrivilegeGroupFunc = func(groupName string) error {
return errors.New("mock error")
}
{
resp, err := c.ListPrivilegeGroups(ctx, &milvuspb.ListPrivilegeGroupsRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
}
{
resp, err := c.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
}
{
resp, err := c.CreatePrivilegeGroup(ctx, &milvuspb.CreatePrivilegeGroupRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
}
})

t.Run("select grant failed", func(t *testing.T) {
{
resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{})
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/rbac/privilege_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,41 @@ func (s *PrivilegeGroupTestSuite) TestCustomPrivilegeGroup() {
s.True(merr.Ok(dropRoleResp))
}

func (s *PrivilegeGroupTestSuite) TestInvalidPrivilegeGroup() {
ctx := GetContext(context.Background(), "root:123456")

createResp, err := s.Cluster.Proxy.CreatePrivilegeGroup(ctx, &milvuspb.CreatePrivilegeGroupRequest{
GroupName: "",
})
s.NoError(err)
s.False(merr.Ok(createResp))

dropResp, err := s.Cluster.Proxy.DropPrivilegeGroup(ctx, &milvuspb.DropPrivilegeGroupRequest{
GroupName: "group1",
})
s.NoError(err)
s.True(merr.Ok(dropResp))

dropResp, err = s.Cluster.Proxy.DropPrivilegeGroup(ctx, &milvuspb.DropPrivilegeGroupRequest{
GroupName: "",
})
s.NoError(err)
s.False(merr.Ok(dropResp))

operateResp, err := s.Cluster.Proxy.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{
GroupName: "",
})
s.NoError(err)
s.False(merr.Ok(operateResp))

operateResp, err = s.Cluster.Proxy.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{
GroupName: "group1",
Privileges: []*milvuspb.PrivilegeEntity{{Name: "123"}},
})
s.NoError(err)
s.False(merr.Ok(operateResp))
}

func (s *PrivilegeGroupTestSuite) operatePrivilege(ctx context.Context, role, privilege, objectType string, operateType milvuspb.OperatePrivilegeType) {
resp, err := s.Cluster.Proxy.OperatePrivilege(ctx, &milvuspb.OperatePrivilegeRequest{
Type: operateType,
Expand Down

0 comments on commit 7b17f35

Please sign in to comment.