diff --git a/internal/data/spicedb.go b/internal/data/spicedb.go index 799dfe7..3c1abc0 100644 --- a/internal/data/spicedb.go +++ b/internal/data/spicedb.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "io" "os" "strings" @@ -250,15 +248,27 @@ func (s *SpiceDbRepository) ImportBulkTuples(stream grpc.ClientStreamingServer[a } var totalImported uint64 + client, err := s.client.ImportBulkRelationships(context.Background()) + if err != nil { + return fmt.Errorf("failed to create SpiceDB client: %w", err) + } + for { - req, err := stream.Recv() - if err != nil { - if !errors.Is(err, io.EOF) { - if err := stream.SendAndClose(&apiV1beta1.ImportBulkTuplesResponse{NumImported: totalImported}); err != nil { - return status.Errorf(codes.Internal, "failed to send response: %v", err) + req, streamErr := stream.Recv() + if streamErr != nil { + if req == nil && errors.Is(streamErr, io.EOF) { + if res, closeErr := client.CloseAndRecv(); closeErr != nil { + return fmt.Errorf("error receiving response from Spicedb for bulkimport request: %w", closeErr) + } else { + log.Infof("total number of relationships loaded: %d", res.NumLoaded) + totalImported = res.NumLoaded + return stream.SendAndClose(&apiV1beta1.ImportBulkTuplesResponse{NumImported: totalImported}) } } - return err + if !errors.Is(streamErr, io.EOF) { + return streamErr + } + return streamErr } inputRelationships := (*req).Tuples batch := []*v1.Relationship{} @@ -266,10 +276,6 @@ func (s *SpiceDbRepository) ImportBulkTuples(stream grpc.ClientStreamingServer[a tuple.Relation = addRelationPrefix(tuple.Relation, relationPrefix) batch = append(batch, createSpiceDbRelationship(tuple)) } - client, err := s.client.ImportBulkRelationships(context.Background()) - if err != nil { - return fmt.Errorf("failed to create SpiceDB client: %w", err) - } if err = client.Send((*v1.ImportBulkRelationshipsRequest)(&v1.BulkImportRelationshipsRequest{ Relationships: batch, })); err != nil { @@ -278,14 +284,8 @@ func (s *SpiceDbRepository) ImportBulkTuples(stream grpc.ClientStreamingServer[a } return err } - if res, err := client.CloseAndRecv(); err != nil { - return fmt.Errorf("error receiving response from Spicedb for bulkimport request: %w", err) - } else { - log.Infof("total number of relationships loaded: %d", res.NumLoaded) - totalImported = res.NumLoaded - return stream.SendAndClose(&apiV1beta1.ImportBulkTuplesResponse{NumImported: totalImported}) - } } + } func (s *SpiceDbRepository) CreateRelationships(ctx context.Context, rels []*apiV1beta1.Relationship, touch biz.TouchSemantics) error {