diff --git a/internal/biz/biz.go b/internal/biz/biz.go index 69cde28..d6f9747 100644 --- a/internal/biz/biz.go +++ b/internal/biz/biz.go @@ -5,4 +5,4 @@ import ( ) // ProviderSet is biz providers. -var ProviderSet = wire.NewSet(NewCreateRelationshipsUsecase) +var ProviderSet = wire.NewSet(NewCreateRelationshipsUsecase, NewReadRelationshipsUsecase) diff --git a/internal/biz/relationships.go b/internal/biz/relationships.go index 8a943c8..a67ab82 100644 --- a/internal/biz/relationships.go +++ b/internal/biz/relationships.go @@ -3,6 +3,7 @@ package biz import ( v1 "ciam-rebac/api/rebac/v1" "context" + "github.com/go-kratos/kratos/v2/log" ) @@ -11,7 +12,7 @@ type TouchSemantics bool type ZanzibarRepository interface { CreateRelationships(context.Context, []*v1.Relationship, TouchSemantics) error - ReadRelationships(context.Context, []*v1.RelationshipFilter) ([]*v1.Relationship, error) + ReadRelationships(context.Context, *v1.RelationshipFilter) ([]*v1.Relationship, error) DeleteRelationships(context.Context, []*v1.RelationshipFilter) ([]*v1.Relationship, error) } @@ -28,3 +29,17 @@ func (rc *CreateRelationshipsUsecase) CreateRelationships(ctx context.Context, r rc.log.WithContext(ctx).Infof("CreateRelationships: %v %s", r, touch) return rc.repo.CreateRelationships(ctx, r, TouchSemantics(touch)) } + +type ReadRelationshipsUsecase struct { + repo ZanzibarRepository + log *log.Helper +} + +func NewReadRelationshipsUsecase(repo ZanzibarRepository, logger log.Logger) *ReadRelationshipsUsecase { + return &ReadRelationshipsUsecase{repo: repo, log: log.NewHelper(logger)} +} + +func (rc *ReadRelationshipsUsecase) ReadRelationships(ctx context.Context, r *v1.RelationshipFilter) ([]*v1.Relationship, error) { + rc.log.WithContext(ctx).Infof("ReadRelationships: %v", r) + return rc.repo.ReadRelationships(ctx, r) +} diff --git a/internal/data/spicedb.go b/internal/data/spicedb.go index 7335755..0119ea7 100644 --- a/internal/data/spicedb.go +++ b/internal/data/spicedb.go @@ -5,7 +5,10 @@ import ( "ciam-rebac/internal/biz" "ciam-rebac/internal/conf" "context" + "errors" "fmt" + "io" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/authzed-go/v1" "github.com/authzed/grpcutil" @@ -86,9 +89,62 @@ func (s *SpiceDbRepository) CreateRelationships(ctx context.Context, rels []*api return err } -func (s *SpiceDbRepository) ReadRelationships(ctx context.Context, filter []*apiV1.RelationshipFilter) ([]*apiV1.Relationship, error) { - //TODO implement me - panic("implement me") +func (s *SpiceDbRepository) ReadRelationships(ctx context.Context, filter *apiV1.RelationshipFilter) ([]*apiV1.Relationship, error) { + req := &v1.ReadRelationshipsRequest{} + + if filter != nil { + req.RelationshipFilter = &v1.RelationshipFilter{ + ResourceType: filter.ObjectType, + OptionalResourceId: filter.ObjectId, + OptionalRelation: filter.Relation, + } + + if filter.SubjectFilter != nil { + req.RelationshipFilter.OptionalSubjectFilter = &v1.SubjectFilter{ + SubjectType: filter.SubjectFilter.SubjectType, + OptionalSubjectId: filter.SubjectFilter.SubjectId, + } + + if filter.SubjectFilter.Relation != "" { + req.RelationshipFilter.OptionalSubjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{ + Relation: filter.SubjectFilter.Relation, + } + } + } + } + + client, err := s.client.ReadRelationships(ctx, req) + + if err != nil { + return nil, err + } + + results := make([]*apiV1.Relationship, 0, 0) + resp, err := client.Recv() + for err == nil { + results = append(results, &apiV1.Relationship{ + Object: &apiV1.ObjectReference{ + Type: resp.Relationship.Resource.ObjectType, + Id: resp.Relationship.Resource.ObjectId, + }, + Relation: resp.Relationship.Relation, + Subject: &apiV1.SubjectReference{ + Relation: resp.Relationship.Subject.OptionalRelation, + Object: &apiV1.ObjectReference{ + Type: resp.Relationship.Subject.Object.ObjectType, + Id: resp.Relationship.Subject.Object.ObjectId, + }, + }, + }) + + resp, err = client.Recv() + } + + if !errors.Is(err, io.EOF) { + return nil, err + } + + return results, nil } func (s *SpiceDbRepository) DeleteRelationships(ctx context.Context, filter []*apiV1.RelationshipFilter) ([]*apiV1.Relationship, error) { diff --git a/internal/data/spicedb_test.go b/internal/data/spicedb_test.go index e4df01b..401088b 100644 --- a/internal/data/spicedb_test.go +++ b/internal/data/spicedb_test.go @@ -5,13 +5,14 @@ import ( "ciam-rebac/internal/biz" "context" "fmt" + "os" + "testing" + "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware/tracing" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "os" - "testing" ) var container *LocalSpiceDbContainer @@ -158,6 +159,42 @@ func TestCreateRelationshipFailsWithBadObjectType(t *testing.T) { assert.Contains(t, err.Error(), "object definition `"+badObjectType+"` not found") } +func TestWriteAndReadBackRelationships(t *testing.T) { + t.Parallel() + + ctx := context.Background() + spiceDbRepo, err := container.CreateSpiceDbRepository() + if !assert.NoError(t, err) { + return + } + + assert.NoError(t, err) + rels := []*apiV1.Relationship{ + createRelationship("bob", "user", "", "member", "group", "bob_club"), + } + + err = spiceDbRepo.CreateRelationships(ctx, rels, biz.TouchSemantics(true)) + if !assert.NoError(t, err) { + return + } + + readrels, err := spiceDbRepo.ReadRelationships(ctx, &apiV1.RelationshipFilter{ + ObjectId: "bob_club", + ObjectType: "group", + Relation: "member", + SubjectFilter: &apiV1.SubjectFilter{ + SubjectId: "bob", + SubjectType: "user", + }, + }) + + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, 1, len(readrels)) +} + func createRelationship(subjectId string, subjectType string, subjectRelationship string, relationship string, objectType string, objectId string) *apiV1.Relationship { subject := &apiV1.SubjectReference{ Object: &apiV1.ObjectReference{ diff --git a/internal/service/relationships.go b/internal/service/relationships.go index 6321a07..82f0113 100644 --- a/internal/service/relationships.go +++ b/internal/service/relationships.go @@ -12,6 +12,7 @@ import ( type RelationshipsService struct { pb.UnimplementedRelationshipsServer createUsecase *biz.CreateRelationshipsUsecase + readUsecase *biz.ReadRelationshipsUsecase log *log.Helper } @@ -26,10 +27,19 @@ func (s *RelationshipsService) CreateRelationships(ctx context.Context, req *pb. return &pb.CreateRelationshipsResponse{}, nil } + func (s *RelationshipsService) ReadRelationships(ctx context.Context, req *pb.ReadRelationshipsRequest) (*pb.ReadRelationshipsResponse, error) { s.log.Infof("Read relationships request: %v", req) - return &pb.ReadRelationshipsResponse{}, nil + + if relationships, err := s.readUsecase.ReadRelationships(ctx, req.GetFilter()); err != nil { + return nil, err + } else { + return &pb.ReadRelationshipsResponse{ + Relationships: relationships, + }, nil + } } + func (s *RelationshipsService) DeleteRelationships(ctx context.Context, req *pb.DeleteRelationshipsRequest) (*pb.DeleteRelationshipsResponse, error) { s.log.Infof("Delete relationships request: %v", req) return &pb.DeleteRelationshipsResponse{}, nil