From 58db41accd2b47e25fe48008bbac27534a4e0e7c Mon Sep 17 00:00:00 2001 From: Bolek Kulbabinski <1416262+bolekk@users.noreply.github.com> Date: Mon, 16 Dec 2024 13:31:26 -0800 Subject: [PATCH] [KS-602] Fix for remote executable client not respecting context --- core/capabilities/remote/executable/client.go | 28 +++++++++++++++---- .../remote/executable/client_test.go | 26 +++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/core/capabilities/remote/executable/client.go b/core/capabilities/remote/executable/client.go index 776ddb692ad..ccffd81b557 100644 --- a/core/capabilities/remote/executable/client.go +++ b/core/capabilities/remote/executable/client.go @@ -2,11 +2,12 @@ package executable import ( "context" - "errors" "fmt" "sync" "time" + "github.com/pkg/errors" + commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -43,6 +44,11 @@ var _ services.Service = &client{} const expiryCheckInterval = 30 * time.Second +var ( + ErrRequestExpired = errors.New("request expired by executable client") + ErrContextDoneBeforeResponseQuorum = errors.New("context done before remote client received a quorum of responses") +) + func NewClient(remoteCapabilityInfo commoncap.CapabilityInfo, localDonInfo commoncap.DON, dispatcher types.Dispatcher, requestTimeout time.Duration, lggr logger.Logger) *client { return &client{ @@ -122,7 +128,7 @@ func (c *client) expireRequests() { for messageID, req := range c.requestIDToCallerRequest { if req.Expired() { - req.Cancel(errors.New("request expired by executable client")) + req.Cancel(ErrRequestExpired) delete(c.requestIDToCallerRequest, messageID) } @@ -164,12 +170,22 @@ func (c *client) Execute(ctx context.Context, capReq commoncap.CapabilityRequest return commoncap.CapabilityResponse{}, fmt.Errorf("failed to send request: %w", err) } - resp := 
<-req.ResponseChan() - if resp.Err != nil { - return commoncap.CapabilityResponse{}, fmt.Errorf("error executing request: %w", resp.Err) + var respResult []byte + var respErr error + select { + case resp := <-req.ResponseChan(): + respResult = resp.Result + respErr = resp.Err + case <-ctx.Done(): + // NOTE: ClientRequest will not block on sending to ResponseChan() because that channel is buffered (with size 1) + return commoncap.CapabilityResponse{}, errors.Wrap(ErrContextDoneBeforeResponseQuorum, ctx.Err().Error()) + } + + if respErr != nil { + return commoncap.CapabilityResponse{}, fmt.Errorf("error executing request: %w", respErr) } - capabilityResponse, err := pb.UnmarshalCapabilityResponse(resp.Result) + capabilityResponse, err := pb.UnmarshalCapabilityResponse(respResult) if err != nil { return commoncap.CapabilityResponse{}, fmt.Errorf("failed to unmarshal capability response: %w", err) } diff --git a/core/capabilities/remote/executable/client_test.go b/core/capabilities/remote/executable/client_test.go index f4e6add82b0..82229cb8ec2 100644 --- a/core/capabilities/remote/executable/client_test.go +++ b/core/capabilities/remote/executable/client_test.go @@ -148,6 +148,7 @@ func Test_Client_TimesOutIfInsufficientCapabilityPeerResponses(t *testing.T) { responseTest := func(t *testing.T, response commoncap.CapabilityResponse, responseError error) { assert.Error(t, responseError) + assert.ErrorIs(t, responseError, executable.ErrRequestExpired) } capability := &TestCapability{} @@ -169,6 +170,31 @@ func Test_Client_TimesOutIfInsufficientCapabilityPeerResponses(t *testing.T) { }) } +func Test_Client_ContextCanceledBeforeQuorumReached(t *testing.T) { + ctx, cancel := context.WithCancel(testutils.Context(t)) + + responseTest := func(t *testing.T, response commoncap.CapabilityResponse, responseError error) { + assert.Error(t, responseError) + assert.ErrorIs(t, responseError, executable.ErrContextDoneBeforeResponseQuorum) + } + + capability := &TestCapability{} + 
transmissionSchedule, err := values.NewMap(map[string]any{ + "schedule": transmission.Schedule_AllAtOnce, + "deltaStage": "20s", + }) + require.NoError(t, err) + + cancel() + testClient(t, 2, 20*time.Second, 2, 2, + capability, + func(caller commoncap.ExecutableCapability) { + executeInputs, err := values.NewMap(map[string]any{"executeValue1": "aValue1"}) + require.NoError(t, err) + executeMethod(ctx, caller, transmissionSchedule, executeInputs, responseTest, t) + }) +} + func testClient(t *testing.T, numWorkflowPeers int, workflowNodeResponseTimeout time.Duration, numCapabilityPeers int, capabilityDonF uint8, underlying commoncap.ExecutableCapability, method func(caller commoncap.ExecutableCapability)) {