diff --git a/Makefile b/Makefile index 50694b6b..9b362610 100644 --- a/Makefile +++ b/Makefile @@ -135,7 +135,7 @@ test: copy-helpers generatorcheck: printf $(COLOR) "Check generated code is not stale..." - #(cd ./cmd/proxygenerator && go mod tidy && go run ./ -verifyOnly) + (cd ./cmd/proxygenerator && go mod tidy && go run ./ -verifyOnly) check: generatorcheck diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index db36a2a8..f31286b9 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -508,11 +508,18 @@ func generateInterceptor(cfg config) error { // UnimplementedWorkflowServiceServer is auto-generated via our API package // The methods on this type refer to all possible Request/Response types so we can use this to walk through all of our protobuf types - serviceTypes, err := lookupTypes("go.temporal.io/api/workflowservice/v1", []string{"UnimplementedWorkflowServiceServer"}) + workflowServiceTypes, err := lookupTypes("go.temporal.io/api/workflowservice/v1", []string{"UnimplementedWorkflowServiceServer"}) if err != nil { return err } - service := serviceTypes[0] + workflowService := workflowServiceTypes[0] + + // UnimplementedOperatorServiceServer is auto-generated via our API package + operatorServiceTypes, err := lookupTypes("go.temporal.io/api/operatorservice/v1", []string{"UnimplementedOperatorServiceServer"}) + if err != nil { + return err + } + operatorService := operatorServiceTypes[0] exportTypes, err := lookupTypes("go.temporal.io/api/export/v1", []string{"WorkflowExecutions"}) if err != nil { @@ -523,7 +530,19 @@ func generateInterceptor(cfg config) error { payloadRecords := map[string]*TypeRecord{} failureRecords := map[string]*TypeRecord{} - for _, meth := range typeutil.IntuitiveMethodSet(service, nil) { + for _, meth := range typeutil.IntuitiveMethodSet(workflowService, nil) { + if !meth.Obj().Exported() { + continue + } + + sig := meth.Obj().Type().(*types.Signature) + walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords) + walk(failureTypes, sig.Params().At(1).Type(), &failureRecords) + walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords) + walk(failureTypes, sig.Results().At(0).Type(), &failureRecords) + } + + for _, meth := range typeutil.IntuitiveMethodSet(operatorService, nil) { if !meth.Obj().Exported() { continue } diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 1c601a0c..3819d327 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -35,6 +35,7 @@ import ( "go.temporal.io/api/failure/v1" "go.temporal.io/api/history/v1" "go.temporal.io/api/nexus/v1" + "go.temporal.io/api/operatorservice/v1" "go.temporal.io/api/protocol/v1" "go.temporal.io/api/query/v1" "go.temporal.io/api/schedule/v1" @@ -1195,6 +1196,41 @@ func visitPayloads( return err } + case []*nexus.Endpoint: + for _, x := range o { + if err := visitPayloads(ctx, options, parent, x); err != nil { + return err + } + } + + case *nexus.Endpoint: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetSpec(), + ); err != nil { + return err + } + + case *nexus.EndpointSpec: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetDescription(), + ); err != nil { + return err + } + case *nexus.Request: if o == nil { @@ -1265,6 +1301,90 @@ func visitPayloads( return err } + case *operatorservice.CreateNexusEndpointRequest: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetSpec(), + ); err != nil { + return err + } + + case *operatorservice.CreateNexusEndpointResponse: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetEndpoint(), + ); err != nil { + return err + } + + case *operatorservice.GetNexusEndpointResponse: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetEndpoint(), + ); err != nil { + return err + } + + case *operatorservice.ListNexusEndpointsResponse: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetEndpoints(), + ); err != nil { + return err + } + + case *operatorservice.UpdateNexusEndpointRequest: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetSpec(), + ); err != nil { + return err + } + + case *operatorservice.UpdateNexusEndpointResponse: + + if o == nil { + continue + } + if err := visitPayloads( + ctx, + options, + o, + o.GetEndpoint(), + ); err != nil { + return err + } + case []*protocol.Message: for _, x := range o { if err := visitPayloads(ctx, options, parent, x); err != nil {