diff --git a/core/capabilities/webapi/trigger.go b/core/capabilities/webapi/trigger.go index db0df7d1410..611879c7a0a 100644 --- a/core/capabilities/webapi/trigger.go +++ b/core/capabilities/webapi/trigger.go @@ -1,18 +1,279 @@ -package webapi +package trigger import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + + ethCommon "github.com/ethereum/go-ethereum/common" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types/core" + "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" - "github.com/smartcontractkit/chainlink/v2/core/services/job" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" ) -func NewTrigger(config string, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, lggr logger.Logger) (job.ServiceCtx, error) { - // TODO (CAPPL-22, CAPPL-24): - // - decode config - // - create an implementation of the capability API and add it to the Registry - // - create a handler and register it with Gateway Connector - // - manage trigger subscriptions - // - process incoming trigger events and related metadata - return nil, nil +const defaultSendChannelBufferSize = 1000 + +const TriggerType = "web-trigger@1.0.0" + +var webapiTriggerInfo = capabilities.MustNewCapabilityInfo( + TriggerType, + capabilities.CapabilityTypeTrigger, + "A trigger to start workflow execution from a web api call", +) + +type Input struct { +} +type Config struct { + AllowedSenders []string `toml:"allowedSenders"` + AllowedTopics []string `toml:"allowedTopics"` + RateLimiter common.RateLimiterConfig `toml:"rateLimiter"` + // RequiredParams is advisory to the web trigger message sender it is not enforced. + RequiredParams []string `toml:"requiredParams"` +} + +type webapiTrigger struct { + allowedSenders map[string]bool + allowedTopics map[string]bool + ch chan<- capabilities.TriggerResponse + config Config + rateLimiter *common.RateLimiter +} + +type triggerConnectorHandler struct { + services.StateMachine + + capabilities.CapabilityInfo + capabilities.Validator[Config, Input, capabilities.TriggerResponse] + connector connector.GatewayConnector + lggr logger.Logger + mu sync.Mutex + registeredWorkflows map[string]webapiTrigger +} + +var _ capabilities.TriggerCapability = (*triggerConnectorHandler)(nil) +var _ services.Service = &triggerConnectorHandler{} + +func NewTrigger(config string, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, lggr logger.Logger) (*triggerConnectorHandler, error) { + if connector == nil { + return nil, errors.New("missing connector") + } + handler := &triggerConnectorHandler{ + Validator: capabilities.NewValidator[Config, Input, capabilities.TriggerResponse](capabilities.ValidatorArgs{Info: webapiTriggerInfo}), + connector: connector, + registeredWorkflows: map[string]webapiTrigger{}, + lggr: lggr.Named("WorkflowConnectorHandler"), + } + + return handler, nil +} + +// processTrigger iterates over each topic, checking against senders and rateLimits, then starting event processing and responding +func (h *triggerConnectorHandler) processTrigger(ctx context.Context, gatewayID string, body *api.MessageBody, sender ethCommon.Address, payload webapicapabilities.TriggerRequestPayload) error { + // Pass on the payload with the expectation that it's in an acceptable format for the executor + wrappedPayload, err := values.WrapMap(payload) + if err != nil { + return fmt.Errorf("error wrapping payload %s", err) + } + topics := payload.Topics + + // empty topics is error for V1 + if len(topics) == 0 { + return fmt.Errorf("empty Workflow Topics") + } + + // workflows that have matched topics + matchedWorkflows := 0 + // workflows that have matched topic and passed all checks + fullyMatchedWorkflows := 0 + for _, trigger := range h.registeredWorkflows { + for _, topic := range topics { + if trigger.allowedTopics[topic] { + matchedWorkflows++ + if !trigger.allowedSenders[sender.String()] { + err = fmt.Errorf("unauthorized Sender %s, messageID %s", sender.String(), body.MessageId) + h.lggr.Debugw(err.Error()) + continue + } + if !trigger.rateLimiter.Allow(body.Sender) { + err = fmt.Errorf("request rate-limited for sender %s, messageID %s", sender.String(), body.MessageId) + continue + } + fullyMatchedWorkflows++ + TriggerEventID := body.Sender + payload.TriggerEventID + tr := capabilities.TriggerResponse{ + Event: capabilities.TriggerEvent{ + TriggerType: TriggerType, + ID: TriggerEventID, + Outputs: wrappedPayload, + }, + } + select { + case <-ctx.Done(): + return nil + case trigger.ch <- tr: + // Sending n topics that match a workflow with n allowedTopics, can only be triggered once. + break + } + } + } + } + if matchedWorkflows == 0 { + return fmt.Errorf("no Matching Workflow Topics") + } + + if fullyMatchedWorkflows > 0 { + return nil + } + return err +} + +func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) { + // TODO: Validate Signature + body := &msg.Body + sender := ethCommon.HexToAddress(body.Sender) + var payload webapicapabilities.TriggerRequestPayload + err := json.Unmarshal(body.Payload, &payload) + if err != nil { + h.lggr.Errorw("error decoding payload", "err", err) + err = h.sendResponse(ctx, gatewayID, body, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s decoding payload", err.Error()).Error()}) + if err != nil { + h.lggr.Errorw("error sending response", "err", err) + } + return + } + + switch body.Method { + case webapicapabilities.MethodWebAPITrigger: + resp := h.processTrigger(ctx, gatewayID, body, sender, payload) + var response webapicapabilities.TriggerResponsePayload + if resp == nil { + response = webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"} + } else { + response = webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: resp.Error()} + h.lggr.Errorw("Error processing trigger", "gatewayID", gatewayID, "body", body, "response", resp) + } + err = h.sendResponse(ctx, gatewayID, body, response) + if err != nil { + h.lggr.Errorw("Error sending response", "body", body, "response", response, "err", err) + } + return + + default: + h.lggr.Errorw("unsupported method", "id", gatewayID, "method", body.Method) + err = h.sendResponse(ctx, gatewayID, body, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("unsupported method %s", body.Method).Error()}) + if err != nil { + h.lggr.Errorw("error sending response", "err", err) + } + } +} + +func (h *triggerConnectorHandler) RegisterTrigger(ctx context.Context, req capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { + cfg := req.Config + if cfg == nil { + return nil, errors.New("config is required to register a web api trigger") + } + + reqConfig, err := h.ValidateConfig(cfg) + if err != nil { + return nil, err + } + + if len(reqConfig.AllowedSenders) == 0 { + return nil, errors.New("allowedSenders must have at least 1 entry") + } + + h.mu.Lock() + defer h.mu.Unlock() + _, errBool := h.registeredWorkflows[req.TriggerID] + if errBool { + return nil, fmt.Errorf("triggerId %s already registered", req.TriggerID) + } + + rateLimiter, err := common.NewRateLimiter(reqConfig.RateLimiter) + if err != nil { + return nil, err + } + + allowedSendersMap := map[string]bool{} + for _, k := range reqConfig.AllowedSenders { + allowedSendersMap[k] = true + } + + allowedTopicsMap := map[string]bool{} + for _, k := range reqConfig.AllowedTopics { + allowedTopicsMap[k] = true + } + + ch := make(chan capabilities.TriggerResponse, defaultSendChannelBufferSize) + + h.registeredWorkflows[req.TriggerID] = webapiTrigger{ + allowedTopics: allowedTopicsMap, + allowedSenders: allowedSendersMap, + ch: ch, + config: *reqConfig, + rateLimiter: rateLimiter, + } + + return ch, nil +} + +func (h *triggerConnectorHandler) UnregisterTrigger(ctx context.Context, req capabilities.TriggerRegistrationRequest) error { + h.mu.Lock() + defer h.mu.Unlock() + workflow, ok := h.registeredWorkflows[req.TriggerID] + if !ok { + return fmt.Errorf("triggerId %s not registered", req.TriggerID) + } + + close(workflow.ch) + delete(h.registeredWorkflows, req.TriggerID) + return nil +} + +func (h *triggerConnectorHandler) Start(ctx context.Context) error { + return h.StartOnce("GatewayConnectorServiceWrapper", func() error { + return h.connector.AddHandler([]string{"web_trigger"}, h) + }) +} +func (h *triggerConnectorHandler) Close() error { + return h.StopOnce("GatewayConnectorServiceWrapper", func() error { + return nil + }) +} + +func (h *triggerConnectorHandler) HealthReport() map[string]error { + return map[string]error{h.Name(): h.Healthy()} +} + +func (h *triggerConnectorHandler) Name() string { + return "WebAPITrigger" +} + +func (h *triggerConnectorHandler) sendResponse(ctx context.Context, gatewayID string, requestBody *api.MessageBody, payload any) error { + payloadJSON, err := json.Marshal(payload) + if err != nil { + h.lggr.Errorw("error marshalling payload", "err", err) + payloadJSON, _ = json.Marshal(webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s marshalling payload", err.Error()).Error()}) + } + + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: requestBody.MessageId, + DonId: requestBody.DonId, + Method: requestBody.Method, + Receiver: requestBody.Sender, + Payload: payloadJSON, + }, + } + + return h.connector.SendToGateway(ctx, gatewayID, msg) } diff --git a/core/capabilities/webapi/trigger_test.go b/core/capabilities/webapi/trigger_test.go new file mode 100644 index 00000000000..d370b1ec7ac --- /dev/null +++ b/core/capabilities/webapi/trigger_test.go @@ -0,0 +1,383 @@ +package trigger + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + registrymock "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + corelogger "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" +) + +const ( + privateKey1 = "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023f" + privateKey2 = "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023e" + triggerID1 = "5" + triggerID2 = "6" + workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0" + workflowExecutionID1 = "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed" + owner1 = "0x00000000000000000000000000000000000000aa" + address1 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc5" + address2 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc6" +) + +type testHarness struct { + registry *registrymock.CapabilitiesRegistry + connector *gcmocks.GatewayConnector + lggr logger.Logger + config string + trigger *triggerConnectorHandler +} + +func workflowTriggerConfig(_ testHarness, addresses []string, topics []string) (*values.Map, error) { + var rateLimitConfig, err = values.NewMap(map[string]any{ + "GlobalRPS": 100.0, + "GlobalBurst": 101, + "PerSenderRPS": 102.0, + "PerSenderBurst": 103, + }) + if err != nil { + return nil, err + } + + triggerRegistrationConfig, err := values.NewMap(map[string]interface{}{ + "RateLimiter": rateLimitConfig, + "AllowedSenders": addresses, + "AllowedTopics": topics, + "RequiredParams": []string{"bid", "ask"}, + }) + return triggerRegistrationConfig, err +} + +func setup(t *testing.T) testHarness { + registry := registrymock.NewCapabilitiesRegistry(t) + connector := gcmocks.NewGatewayConnector(t) + lggr := corelogger.TestLogger(t) + config := "" + + trigger, err := NewTrigger(config, registry, connector, lggr) + require.NoError(t, err) + + return testHarness{ + registry: registry, + connector: connector, + lggr: lggr, + config: config, + trigger: trigger, + } +} + +func gatewayRequest(t *testing.T, privateKey string, topics string, methodName string) *api.Message { + messageID := "12345" + if methodName == "" { + methodName = webapicapabilities.MethodWebAPITrigger + } + donID := "workflow_don_1" + + key, err := crypto.HexToECDSA(privateKey) + require.NoError(t, err) + + payload := `{ + "trigger_id": "` + TriggerType + `", + "trigger_event_id": "action_1234567890", + "timestamp": 1234567890, + "topics": ` + topics + `, + "params": { + "bid": "101", + "ask": "102" + } + } +` + payloadJSON := []byte(payload) + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: messageID, + Method: methodName, + DonId: donID, + Payload: json.RawMessage(payloadJSON), + }, + } + err = msg.Sign(key) + require.NoError(t, err) + return msg +} + +func getResponseFromArg(arg interface{}) (webapicapabilities.TriggerResponsePayload, error) { + var response webapicapabilities.TriggerResponsePayload + err := json.Unmarshal((&(arg.(*api.Message)).Body).Payload, &response) + return response, err +} + +func requireNoChanMsg[T any](t *testing.T, ch <-chan T) { + timedOut := false + select { + case <-ch: + case <-time.After(100 * time.Millisecond): + timedOut = true + } + require.True(t, timedOut) +} + +func requireChanMsg[T capabilities.TriggerResponse](t *testing.T, ch <-chan capabilities.TriggerResponse) (capabilities.TriggerResponse, error) { + timedOut := false + select { + case resp := <-ch: + return resp, nil + case <-time.After(100 * time.Millisecond): + timedOut = true + } + require.False(t, timedOut) + return capabilities.TriggerResponse{}, errors.New("channel timeout") +} + +func TestTriggerExecute(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + ctx, cancelContext := context.WithDeadline(ctx, time.Now().Add(10*time.Second)) + Config, _ := workflowTriggerConfig(th, []string{address1}, []string{"daily_price_update", "ad_hoc_price_update"}) + triggerReq := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID1, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config, + } + channel, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.NoError(t, err) + + Config2, err := workflowTriggerConfig(th, []string{address1}, []string{"daily_price_update2", "ad_hoc_price_update"}) + require.NoError(t, err) + + triggerReq2 := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID2, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config2, + } + channel2, err := th.trigger.RegisterTrigger(ctx, triggerReq2) + require.NoError(t, err) + + t.Run("happy case single topic to single workflow", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey1, `["daily_price_update"]`, "") + + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + + received, chanErr := requireChanMsg(t, channel) + require.Equal(t, received.Event.TriggerType, TriggerType) + require.NoError(t, chanErr) + + requireNoChanMsg(t, channel2) + data := received.Event.Outputs + var payload webapicapabilities.TriggerRequestPayload + unwrapErr := data.UnwrapTo(&payload) + require.NoError(t, unwrapErr) + require.Equal(t, payload.Topics, []string{"daily_price_update"}) + }) + + t.Run("happy case single different topic 2 workflows.", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey1, `["ad_hoc_price_update"]`, "") + + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + + sent := <-channel + require.Equal(t, sent.Event.TriggerType, TriggerType) + data := sent.Event.Outputs + var payload webapicapabilities.TriggerRequestPayload + unwrapErr := data.UnwrapTo(&payload) + require.NoError(t, unwrapErr) + require.Equal(t, payload.Topics, []string{"ad_hoc_price_update"}) + + sent2 := <-channel2 + require.Equal(t, sent2.Event.TriggerType, TriggerType) + data2 := sent2.Event.Outputs + var payload2 webapicapabilities.TriggerRequestPayload + err2 := data2.UnwrapTo(&payload2) + require.NoError(t, err2) + require.Equal(t, payload2.Topics, []string{"ad_hoc_price_update"}) + }) + + t.Run("sad case empty topic 2 workflows", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey1, `[]`, "") + + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "empty Workflow Topics"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + + requireNoChanMsg(t, channel) + requireNoChanMsg(t, channel2) + }) + + t.Run("sad case topic with no workflows", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey1, `["foo"]`, "") + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "no Matching Workflow Topics"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + requireNoChanMsg(t, channel) + requireNoChanMsg(t, channel2) + }) + + t.Run("sad case Not Allowed Sender", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey2, `["ad_hoc_price_update"]`, "") + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unauthorized Sender 0x2dAC9f74Ee66e2D55ea1B8BE284caFedE048dB3A, messageID 12345"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + requireNoChanMsg(t, channel) + requireNoChanMsg(t, channel2) + }) + + t.Run("sad case Invalid Method", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey2, `["ad_hoc_price_update"]`, "boo") + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unsupported method boo"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + requireNoChanMsg(t, channel) + requireNoChanMsg(t, channel2) + }) + + err = th.trigger.UnregisterTrigger(ctx, triggerReq) + require.NoError(t, err) + err = th.trigger.UnregisterTrigger(ctx, triggerReq2) + require.NoError(t, err) + cancelContext() +} + +func TestRegisterNoAllowedSenders(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + Config, _ := workflowTriggerConfig(th, []string{}, []string{"daily_price_update"}) + + triggerReq := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID1, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config, + } + _, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.Error(t, err) + + gatewayRequest(t, privateKey1, `["daily_price_update"]`, "") +} + +func TestTriggerExecute2WorkflowsSameTopicDifferentAllowLists(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + ctx, cancelContext := context.WithDeadline(ctx, time.Now().Add(10*time.Second)) + Config, _ := workflowTriggerConfig(th, []string{address2}, []string{"daily_price_update"}) + triggerReq := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID1, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config, + } + channel, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.NoError(t, err) + + Config2, err := workflowTriggerConfig(th, []string{address1}, []string{"daily_price_update"}) + require.NoError(t, err) + + triggerReq2 := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID2, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config2, + } + channel2, err := th.trigger.RegisterTrigger(ctx, triggerReq2) + require.NoError(t, err) + + t.Run("happy case single topic to single workflow", func(t *testing.T) { + gatewayRequest := gatewayRequest(t, privateKey1, `["daily_price_update"]`, "") + + th.connector.On("SendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + resp, _ := getResponseFromArg(args.Get(2)) + require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + }).Return(nil).Once() + + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + + requireNoChanMsg(t, channel) + received, chanErr := requireChanMsg(t, channel2) + require.Equal(t, received.Event.TriggerType, TriggerType) + require.NoError(t, chanErr) + data := received.Event.Outputs + var payload webapicapabilities.TriggerRequestPayload + unwrapErr := data.UnwrapTo(&payload) + require.NoError(t, unwrapErr) + require.Equal(t, payload.Topics, []string{"daily_price_update"}) + }) + err = th.trigger.UnregisterTrigger(ctx, triggerReq) + require.NoError(t, err) + err = th.trigger.UnregisterTrigger(ctx, triggerReq2) + require.NoError(t, err) + cancelContext() +} + +func TestRegisterUnregister(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + Config, err := workflowTriggerConfig(th, []string{address1}, []string{"daily_price_update"}) + require.NoError(t, err) + + triggerReq := capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID1, + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + Config: Config, + } + + channel, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.NoError(t, err) + require.NotEmpty(t, th.trigger.registeredWorkflows[triggerID1]) + + err = th.trigger.UnregisterTrigger(ctx, triggerReq) + require.NoError(t, err) + _, open := <-channel + require.Equal(t, open, false) +} diff --git a/core/scripts/gateway/web_api_trigger/invoke_trigger.go b/core/scripts/gateway/web_api_trigger/invoke_trigger.go new file mode 100644 index 00000000000..00bc08b3489 --- /dev/null +++ b/core/scripts/gateway/web_api_trigger/invoke_trigger.go @@ -0,0 +1,147 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/joho/godotenv" + + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" +) + +// https://gateway-us-1.chain.link/web-trigger +// { +// jsonrpc: "2.0", +// id: "...", +// method: "web-trigger", +// params: { +// signature: "...", +// body: { +// don_id: "workflow_123", +// payload: { +// trigger_id: "web-trigger@1.0.0", +// trigger_event_id: "action_1234567890", +// timestamp: 1234567890, +// sub-events: [ +// { +// topics: ["daily_price_update"], +// params: { +// bid: "101", +// ask: "102" +// } +// }, +// { +// topics: ["daily_message", "summary"], +// params: { +// message: "all good!", +// } +// }, +// ] +// } +// } +// } +// } + +func main() { + gatewayURL := flag.String("gateway_url", "http://localhost:5002", "Gateway URL") + privateKey := flag.String("private_key", "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023f", "Private key to sign the message with") + messageID := flag.String("id", "12345", "Request ID") + methodName := flag.String("method", "web_trigger", "Method name") + donID := flag.String("don_id", "workflow_don_1", "DON ID") + + flag.Parse() + + if privateKey == nil || *privateKey == "" { + if err := godotenv.Load(); err != nil { + panic(err) + } + + privateKeyEnvVar := os.Getenv("PRIVATE_KEY") + privateKey = &privateKeyEnvVar + fmt.Println("Loaded private key from .env") + } + + // validate key and extract address + key, err := crypto.HexToECDSA(*privateKey) + if err != nil { + fmt.Println("error parsing private key", err) + return + } + + payload := `{ + "trigger_id": "web-trigger@1.0.0", + "trigger_event_id": "action_1234567890", + "timestamp": 1234567890, + "topics": ["daily_price_update"], + "params": { + "bid": "101", + "ask": "102" + } + } +` + payloadJSON := []byte(payload) + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: *messageID, + Method: *methodName, + DonId: *donID, + Payload: json.RawMessage(payloadJSON), + }, + } + if err = msg.Sign(key); err != nil { + fmt.Println("error signing message", err) + return + } + codec := api.JsonRPCCodec{} + rawMsg, err := codec.EncodeRequest(msg) + if err != nil { + fmt.Println("error JSON-RPC encoding", err) + return + } + + createRequest := func() (req *http.Request, err error) { + req, err = http.NewRequestWithContext(context.Background(), "POST", *gatewayURL, bytes.NewBuffer(rawMsg)) + if err == nil { + req.Header.Set("Content-Type", "application/json") + } + return + } + + client := &http.Client{} + + sendRequest := func() { + req, err2 := createRequest() + if err2 != nil { + fmt.Println("error creating a request", err2) + return + } + + resp, err2 := client.Do(req) + if err2 != nil { + fmt.Println("error sending a request", err2) + return + } + defer resp.Body.Close() + + body, err2 := io.ReadAll(resp.Body) + if err2 != nil { + fmt.Println("error sending a request", err2) + return + } + + var prettyJSON bytes.Buffer + if err2 = json.Indent(&prettyJSON, body, "", " "); err2 != nil { + fmt.Println(string(body)) + } else { + fmt.Println(prettyJSON.String()) + } + } + sendRequest() +} diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 6793350f317..92ad48b5395 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -10,6 +10,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" ) const ( @@ -38,6 +39,8 @@ func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json switch handlerType { case FunctionsHandlerType: return functions.NewFunctionsHandlerFromConfig(handlerConfig, donConfig, don, hf.legacyChains, hf.ds, hf.lggr) + case WebAPICapabilitiesType: + return webapicapabilities.NewWorkflowHandler(handlerConfig, donConfig, don, hf.lggr) case DummyHandlerType: return handlers.NewDummyHandler(donConfig, don, hf.lggr) default: diff --git a/core/services/gateway/handlers/webapicapabilities/handler.go b/core/services/gateway/handlers/webapicapabilities/handler.go index a38651d40fc..d6caf067dd0 100644 --- a/core/services/gateway/handlers/webapicapabilities/handler.go +++ b/core/services/gateway/handlers/webapicapabilities/handler.go @@ -1,6 +1,123 @@ package webapicapabilities +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "go.uber.org/multierr" + + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" +) + const ( // NOTE: more methods will go here. HTTP trigger/action/target; etc. - MethodWebAPITarget = "web_api_target" + MethodWebAPITarget = "web_api_target" + MethodWebAPITrigger = "web_api_trigger" ) + +type handler struct { + config HandlerConfig + donConfig *config.DONConfig + don handlers.DON + savedCallbacks map[string]*savedCallback + mu sync.Mutex + lggr logger.Logger +} + +type HandlerConfig struct { + MaxAllowedMessageAgeSec uint +} +type savedCallback struct { + id string + callbackCh chan<- handlers.UserCallbackPayload +} + +var _ handlers.Handler = (*handler)(nil) + +func NewWorkflowHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, lggr logger.Logger) (*handler, error) { + var cfg HandlerConfig + err := json.Unmarshal(handlerConfig, &cfg) + if err != nil { + return nil, err + } + + return &handler{ + config: cfg, + donConfig: donConfig, + don: don, + savedCallbacks: make(map[string]*savedCallback), + lggr: lggr.Named("WorkflowHandler." + donConfig.DonId), + }, nil +} + +func (d *handler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { + d.mu.Lock() + d.savedCallbacks[msg.Body.MessageId] = &savedCallback{msg.Body.MessageId, callbackCh} + don := d.don + d.mu.Unlock() + body := msg.Body + var payload TriggerRequestPayload + err := json.Unmarshal(body.Payload, &payload) + if err != nil { + d.lggr.Errorw("error decoding payload", "err", err) + callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.UserMessageParseError, ErrMsg: fmt.Sprintf("error decoding payload %s", err.Error())} + close(callbackCh) + return nil + } + + if payload.Timestamp == 0 { + d.lggr.Errorw("error decoding payload") + callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.UserMessageParseError, ErrMsg: "error decoding payload"} + close(callbackCh) + return nil + } + + if uint(time.Now().Unix())-d.config.MaxAllowedMessageAgeSec > uint(payload.Timestamp) { + callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.HandlerError, ErrMsg: "stale message"} + close(callbackCh) + return nil + } + // TODO: apply allowlist and rate-limiting here + if msg.Body.Method != MethodWebAPITrigger { + d.lggr.Errorw("unsupported method", "method", body.Method) + callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.HandlerError, ErrMsg: fmt.Sprintf("invalid method %s", msg.Body.Method)} + close(callbackCh) + return nil + } + + // Send to all nodes. + for _, member := range d.donConfig.Members { + err = multierr.Combine(err, don.SendToNode(ctx, member.Address, msg)) + } + return err +} + +func (d *handler) HandleNodeMessage(ctx context.Context, msg *api.Message, _ string) error { + d.mu.Lock() + savedCb, found := d.savedCallbacks[msg.Body.MessageId] + delete(d.savedCallbacks, msg.Body.MessageId) + d.mu.Unlock() + + if found { + // Send first response from a node back to the user, ignore any other ones. + // TODO: in practice, we should wait for at least 2F+1 nodes to respond and then return an aggregated response + // back to the user. + savedCb.callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} + close(savedCb.callbackCh) + } + return nil +} + +func (d *handler) Start(context.Context) error { + return nil +} + +func (d *handler) Close() error { + return nil +} diff --git a/core/services/gateway/handlers/webapicapabilities/handler_test.go b/core/services/gateway/handlers/webapicapabilities/handler_test.go new file mode 100644 index 00000000000..ef278e40ffd --- /dev/null +++ b/core/services/gateway/handlers/webapicapabilities/handler_test.go @@ -0,0 +1,185 @@ +package webapicapabilities + +import ( + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/mock" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + gwcommon "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" + + handlermocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" +) + +const ( + defaultSendChannelBufferSize = 1000 + privateKey1 = "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023f" + privateKey2 = "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023e" + triggerID1 = "5" + triggerID2 = "6" + workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0" + workflowExecutionID1 = "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed" + owner1 = "0x00000000000000000000000000000000000000aa" + address1 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc5" +) + +func setupHandler(t *testing.T) (*handler, *handlermocks.DON, []gwcommon.TestNode) { + lggr := logger.TestLogger(t) + don := handlermocks.NewDON(t) + + handlerConfig := HandlerConfig{ + MaxAllowedMessageAgeSec: 30, + } + cfgBytes, err := json.Marshal(handlerConfig) + require.NoError(t, err) + donConfig := &config.DONConfig{ + Members: []config.NodeConfig{}, + F: 1, + } + nodes := gwcommon.NewTestNodes(t, 2) + for id, n := range nodes { + donConfig.Members = append(donConfig.Members, config.NodeConfig{ + Name: fmt.Sprintf("node_%d", id), + Address: n.Address, + }) + } + + handler, err := NewWorkflowHandler(json.RawMessage(cfgBytes), donConfig, don, lggr) + require.NoError(t, err) + return handler, don, nodes +} + +func triggerRequest(t *testing.T, privateKey string, topics string, methodName string, timestamp string, payload string) *api.Message { + messageID := "12345" + if methodName == "" { + methodName = MethodWebAPITrigger + } + if timestamp == "" { + timestamp = strconv.FormatInt(time.Now().Unix(), 10) + } + donID := "workflow_don_1" + + key, err := crypto.HexToECDSA(privateKey) + require.NoError(t, err) + if payload == "" { + payload = `{ + "trigger_id": "web-trigger@1.0.0", + "trigger_event_id": "action_1234567890", + "timestamp": ` + timestamp + `, + "topics": ` + topics + `, + "params": { + "bid": "101", + "ask": "102" + } + } + ` + } + payloadJSON := []byte(payload) + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: messageID, + Method: methodName, + DonId: donID, + Payload: json.RawMessage(payloadJSON), + }, + } + err = msg.Sign(key) + require.NoError(t, err) + return msg +} + +func requireNoChanMsg[T any](t *testing.T, ch <-chan T) { + timedOut := false + select { + case <-ch: + case <-time.After(100 * time.Millisecond): + timedOut = true + } + require.True(t, timedOut) +} + +func TestHandlerReceiveHTTPMessageFromClient(t *testing.T) { + handler, don, _ := setupHandler(t) + ctx := testutils.Context(t) + msg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "", "", "") + + t.Run("happy case", func(t *testing.T) { + ch := make(chan handlers.UserCallbackPayload, defaultSendChannelBufferSize) + + // sends to 2 dons + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + require.Equal(t, msg, args.Get(2)) + }).Return(nil).Once() + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + require.Equal(t, msg, args.Get(2)) + }).Return(nil).Once() + + err := handler.HandleUserMessage(ctx, msg, ch) + require.NoError(t, err) + requireNoChanMsg(t, ch) + + err = handler.HandleNodeMessage(ctx, msg, "") + require.NoError(t, err) + + resp := <-ch + require.Equal(t, handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""}, resp) + _, open := <-ch + require.Equal(t, open, false) + }) + + t.Run("sad case invalid method", func(t *testing.T) { + invalidMsg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "foo", "", "") + ch := make(chan handlers.UserCallbackPayload, defaultSendChannelBufferSize) + err := handler.HandleUserMessage(ctx, invalidMsg, ch) + require.NoError(t, err) + resp := <-ch + require.Equal(t, handlers.UserCallbackPayload{Msg: invalidMsg, ErrCode: api.HandlerError, ErrMsg: "invalid method foo"}, resp) + _, open := <-ch + require.Equal(t, open, false) + }) + + t.Run("sad case stale message", func(t *testing.T) { + invalidMsg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "", "123456", "") + ch := make(chan handlers.UserCallbackPayload, defaultSendChannelBufferSize) + err := handler.HandleUserMessage(ctx, invalidMsg, ch) + require.NoError(t, err) + resp := <-ch + require.Equal(t, handlers.UserCallbackPayload{Msg: invalidMsg, ErrCode: api.HandlerError, ErrMsg: "stale message"}, resp) + _, open := <-ch + require.Equal(t, open, false) + }) + + t.Run("sad case empty payload", func(t *testing.T) { + invalidMsg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "", "123456", "{}") + ch := make(chan handlers.UserCallbackPayload, defaultSendChannelBufferSize) + err := handler.HandleUserMessage(ctx, invalidMsg, ch) + require.NoError(t, err) + resp := <-ch + require.Equal(t, handlers.UserCallbackPayload{Msg: invalidMsg, ErrCode: api.UserMessageParseError, ErrMsg: "error decoding payload"}, resp) + _, open := <-ch + require.Equal(t, open, false) + }) + + t.Run("sad case invalid payload", func(t *testing.T) { + invalidMsg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "", "123456", `{"foo":"bar"}`) + ch := make(chan handlers.UserCallbackPayload, defaultSendChannelBufferSize) + err := handler.HandleUserMessage(ctx, invalidMsg, ch) + require.NoError(t, err) + resp := <-ch + require.Equal(t, handlers.UserCallbackPayload{Msg: invalidMsg, ErrCode: api.UserMessageParseError, ErrMsg: "error decoding payload"}, resp) + _, open := <-ch + require.Equal(t, open, false) + }) + // TODO: Validate Senders and rate limit chck, pending question in trigger about where senders and rate limits are validated +} diff --git a/core/services/gateway/handlers/webapicapabilities/webapi.go b/core/services/gateway/handlers/webapicapabilities/webapi.go index e300b61d85b..97ba401881b 100644 --- a/core/services/gateway/handlers/webapicapabilities/webapi.go +++ b/core/services/gateway/handlers/webapicapabilities/webapi.go @@ -1,5 +1,9 @@ package webapicapabilities +import ( + "github.com/smartcontractkit/chainlink-common/pkg/values" +) + type TargetRequestPayload struct { URL string `json:"url"` // URL to query, only http and https protocols are supported. Method string `json:"method,omitempty"` // HTTP verb, defaults to GET. @@ -15,3 +19,47 @@ type TargetResponsePayload struct { Headers map[string]string `json:"headers,omitempty"` // HTTP headers Body []byte `json:"body,omitempty"` // HTTP response body } + +// https://gateway-us-1.chain.link/web-trigger +// +// { +// jsonrpc: "2.0", +// id: "...", +// method: "web-trigger", +// params: { +// signature: "...", +// body: { +// don_id: "workflow_123", +// payload: { +// trigger_id: "web-trigger@1.0.0", +// trigger_event_id: "action_1234567890", +// timestamp: 1234567890, +// topics: ["daily_price_update"], +// params: { +// bid: "101", +// ask: "102" +// } +// } +// } +// } +// } +// +// from Web API Trigger Doc, with modifications. +// trigger_id - ID of the trigger corresponding to the capability ID +// trigger_event_id - uniquely identifies generated event (scoped to trigger_id and sender) +// timestamp - timestamp of the event (unix time), needs to be within certain freshness to be processed +// topics - an array of a single topic (string) to be started by this event +// params - key-value pairs for the workflow engine, untranslated. +type TriggerRequestPayload struct { + TriggerID string `json:"trigger_id"` + TriggerEventID string `json:"trigger_event_id"` + Timestamp int64 `json:"timestamp"` + Topics []string `json:"topics"` + Params values.Map `json:"params"` +} + +type TriggerResponsePayload struct { + ErrorMessage string `json:"error_message,omitempty"` + // ERROR, ACCEPTED, PENDING, COMPLETED + Status string `json:"status"` +} diff --git a/core/services/standardcapabilities/delegate.go b/core/services/standardcapabilities/delegate.go index 15c829fbf84..1e27d2ffb33 100644 --- a/core/services/standardcapabilities/delegate.go +++ b/core/services/standardcapabilities/delegate.go @@ -13,7 +13,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/core" gatewayconnector "github.com/smartcontractkit/chainlink/v2/core/capabilities/gateway_connector" - "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" + trigger "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" webapitarget "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi/target" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" @@ -82,7 +82,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.Ser return nil, errors.New("gateway connector is required for web API Trigger capability") } connector := d.gatewayConnectorWrapper.GetGatewayConnector() - triggerSrvc, err := webapi.NewTrigger(spec.StandardCapabilitiesSpec.Config, d.registry, connector, log) + triggerSrvc, err := trigger.NewTrigger(spec.StandardCapabilitiesSpec.Config, d.registry, connector, log) if err != nil { return nil, fmt.Errorf("failed to create a Web API Trigger service: %w", err) }