From f49fe8429ce667d20600a5f1d78f7bb8257a7a1c Mon Sep 17 00:00:00 2001 From: dimitris Date: Thu, 7 Nov 2024 17:45:00 +0200 Subject: [PATCH] Commit and Execute - Custom transmission schedule (#306) --- commit/report.go | 29 +++++--- commit/report_test.go | 11 ++- execute/plugin.go | 48 ++++++++++--- internal/plugincommon/transmitters.go | 54 ++++++++++++++ internal/plugincommon/transmitters_test.go | 83 ++++++++++++++++++++++ 5 files changed, 205 insertions(+), 20 deletions(-) create mode 100644 internal/plugincommon/transmitters.go create mode 100644 internal/plugincommon/transmitters_test.go diff --git a/commit/report.go b/commit/report.go index fc363c70c..9a982699a 100644 --- a/commit/report.go +++ b/commit/report.go @@ -5,15 +5,23 @@ import ( "encoding/hex" "encoding/json" "fmt" + "time" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" + "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon/consensus" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3" ) +const ( + // transmissionDelayMultiplier is used to calculate the transmission delay for each oracle. + transmissionDelayMultiplier = 3 * time.Second +) + // ReportInfo is the info data that will be sent with the along with the report // It will be used to determine if the report should be accepted or not type ReportInfo struct { @@ -84,12 +92,24 @@ func (p *Plugin) Reports( return nil, fmt.Errorf("encode report info: %w", err) } + transmissionSchedule, err := plugincommon.GetTransmissionSchedule( + p.chainSupport, + maps.Keys(p.oracleIDToP2PID), + transmissionDelayMultiplier, + ) + if err != nil { + return nil, fmt.Errorf("get transmission schedule: %w", err) + } + p.lggr.Debugw("transmission schedule override", + "transmissionSchedule", transmissionSchedule, "oracleIDToP2PID", p.oracleIDToP2PID) + return []ocr3types.ReportPlus[[]byte]{ { ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{ Report: encodedReport, Info: infoBytes, }, + TransmissionScheduleOverride: transmissionSchedule, }, }, nil } @@ -127,15 +147,6 @@ func (p *Plugin) ShouldAcceptAttestedReport( func (p *Plugin) ShouldTransmitAcceptedReport( ctx context.Context, u uint64, r ocr3types.ReportWithInfo[[]byte], ) (bool, error) { - isWriter, err := p.chainSupport.SupportsDestChain(p.oracleID) - if err != nil { - return false, fmt.Errorf("can't know if it's a writer: %w", err) - } - if !isWriter { - p.lggr.Infow("not a writer, skipping report transmission") - return false, nil - } - // we only transmit reports if we are the "active" instance. // we can check this by reading the OCR configs from the home chain. isCandidate, err := p.isCandidateInstance(ctx) diff --git a/commit/report_test.go b/commit/report_test.go index ce8614cde..3c35f79f7 100644 --- a/commit/report_test.go +++ b/commit/report_test.go @@ -4,11 +4,14 @@ import ( "fmt" "testing" + "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/smartcontractkit/chainlink-ccip/internal/libs/testhelpers/rand" + "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/plugincommon" reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" @@ -130,10 +133,14 @@ func TestPluginReports(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + cs := plugincommon.NewMockChainSupport(t) p := Plugin{ - lggr: lggr, - reportCodec: reportCodec, + lggr: lggr, + reportCodec: reportCodec, + oracleIDToP2PID: map[commontypes.OracleID]libocrtypes.PeerID{1: {1}}, + chainSupport: cs, } + cs.EXPECT().SupportsDestChain(commontypes.OracleID(1)).Return(true, nil).Maybe() outcomeBytes, err := tc.outc.Encode() require.NoError(t, err) diff --git a/execute/plugin.go b/execute/plugin.go index e6c8adc98..6ce2c1838 100644 --- a/execute/plugin.go +++ b/execute/plugin.go @@ -6,6 +6,7 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" + "golang.org/x/exp/maps" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" @@ -15,6 +16,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" + "github.com/smartcontractkit/chainlink-ccip/execute/exectypes" "github.com/smartcontractkit/chainlink-ccip/execute/internal/gas" "github.com/smartcontractkit/chainlink-ccip/execute/report" @@ -31,6 +34,11 @@ import ( // maxReportSizeBytes that should be returned as an execution report payload. const maxReportSizeBytes = 250_000 +const ( + // transmissionDelayMultiplier is used to calculate the transmission delay for each oracle. + transmissionDelayMultiplier = 3 * time.Second +) + // Plugin implements the main ocr3 plugin logic. type Plugin struct { donID plugintypes.DonID @@ -39,11 +47,12 @@ type Plugin struct { destChain cciptypes.ChainSelector // providers - ccipReader readerpkg.CCIPReader - reportCodec cciptypes.ExecutePluginCodec - msgHasher cciptypes.MessageHasher - homeChain reader.HomeChain - discovery *discovery.ContractDiscoveryProcessor + ccipReader readerpkg.CCIPReader + reportCodec cciptypes.ExecutePluginCodec + msgHasher cciptypes.MessageHasher + homeChain reader.HomeChain + discovery *discovery.ContractDiscoveryProcessor + chainSupport plugincommon.ChainSupport oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID tokenDataObserver tokendata.TokenDataObserver @@ -96,6 +105,13 @@ func NewPlugin( reportingCfg.F, oracleIDToP2pID, ), + chainSupport: plugincommon.NewCCIPChainSupport( + lggr, + homeChain, + oracleIDToP2pID, + reportingCfg.OracleID, + destChain, + ), } } @@ -254,11 +270,25 @@ func (p *Plugin) Reports( return nil, fmt.Errorf("unable to encode report: %w", err) } + transmissionSchedule, err := plugincommon.GetTransmissionSchedule( + p.chainSupport, + maps.Keys(p.oracleIDToP2pID), + transmissionDelayMultiplier, + ) + if err != nil { + return nil, fmt.Errorf("get transmission schedule: %w", err) + } + p.lggr.Debugw("transmission schedule override", + "transmissionSchedule", transmissionSchedule, "oracleIDToP2PID", p.oracleIDToP2pID) + report := []ocr3types.ReportPlus[[]byte]{ - {ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{ - Report: encoded, - Info: nil, - }}, + { + ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{ + Report: encoded, + Info: nil, + }, + TransmissionScheduleOverride: transmissionSchedule, + }, } return report, nil diff --git a/internal/plugincommon/transmitters.go b/internal/plugincommon/transmitters.go new file mode 100644 index 000000000..488fdd2cb --- /dev/null +++ b/internal/plugincommon/transmitters.go @@ -0,0 +1,54 @@ +package plugincommon + +import ( + "fmt" + "time" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" +) + +// GetTransmissionSchedule returns a TransmissionSchedule for the provided oracles. +// It uses the ChainSupport service to query which oracles support the destination chain. +// It returns an error if no oracles support the destination chain. +// Read more about TransmissionDelay at ocr3types.TransmissionSchedule +// The transmissionDelayMultiplier is used in the following way: +// +// Assume that we have transmitters: [1, 3, 5] +// And transmissionDelayMultiplier = 5s +// Then the transmission delays will be: [5s, 10s, 15s] +func GetTransmissionSchedule( + chainSupport ChainSupport, + allTheOracles []commontypes.OracleID, + transmissionDelayMultiplier time.Duration, +) (*ocr3types.TransmissionSchedule, error) { + transmitters := make([]commontypes.OracleID, 0, len(allTheOracles)) + for _, oracleID := range allTheOracles { + supportsDestChain, err := chainSupport.SupportsDestChain(oracleID) + if err != nil { + return nil, fmt.Errorf("supports dest chain %d: %w", oracleID, err) + } + if supportsDestChain { + transmitters = append(transmitters, oracleID) + } + } + + transmissionDelays := make([]time.Duration, len(transmitters)) + + for i := range transmissionDelays { + transmissionDelays[i] = (transmissionDelayMultiplier) * time.Duration(i+1) + } + + if len(transmitters) == 0 { + return nil, fmt.Errorf("no transmitters") + } + + if len(transmitters) != len(transmissionDelays) { + return nil, fmt.Errorf("critical issue mismatched transmitters and transmission delays") + } + + return &ocr3types.TransmissionSchedule{ + Transmitters: transmitters, + TransmissionDelays: transmissionDelays, + }, nil +} diff --git a/internal/plugincommon/transmitters_test.go b/internal/plugincommon/transmitters_test.go new file mode 100644 index 000000000..01388304b --- /dev/null +++ b/internal/plugincommon/transmitters_test.go @@ -0,0 +1,83 @@ +package plugincommon_test + +import ( + "errors" + "testing" + "time" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/stretchr/testify/require" + + plugincommon2 "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" + "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/plugincommon" +) + +func TestGetTransmissionSchedule(t *testing.T) { + testCases := []struct { + name string + allTheOracles []commontypes.OracleID + oraclesSupportingDest []commontypes.OracleID + transmissionDelayMultiplier time.Duration + expectedTransmitters []commontypes.OracleID + expectedTransmissionDelays []time.Duration + expectedError bool + chainSupportReturnsError bool + }{ + { + name: "no oracles supporting dest leads to error", + allTheOracles: []commontypes.OracleID{1, 2, 3}, + oraclesSupportingDest: []commontypes.OracleID{}, + transmissionDelayMultiplier: 5 * time.Second, + expectedTransmitters: []commontypes.OracleID{}, + expectedTransmissionDelays: []time.Duration{}, + expectedError: true, + }, + { + name: "some transmitters supporting dest", + allTheOracles: []commontypes.OracleID{1, 2, 3}, + oraclesSupportingDest: []commontypes.OracleID{1, 3}, + transmissionDelayMultiplier: 5 * time.Second, + expectedTransmitters: []commontypes.OracleID{1, 3}, + expectedTransmissionDelays: []time.Duration{5 * time.Second, 10 * time.Second}, + expectedError: false, + }, + { + name: "chainsupport returns error", + allTheOracles: []commontypes.OracleID{1, 2, 3}, + oraclesSupportingDest: []commontypes.OracleID{1, 3}, + transmissionDelayMultiplier: 5 * time.Second, + expectedError: true, + chainSupportReturnsError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := plugincommon.NewMockChainSupport(t) + destSupportingOraclesSet := mapset.NewSet(tc.oraclesSupportingDest...) + for _, oracleID := range tc.allTheOracles { + var err error + if tc.chainSupportReturnsError { + err = errors.New("some error") + } + cs.On("SupportsDestChain", oracleID). + Return(destSupportingOraclesSet.Contains(oracleID), err).Maybe() + } + + transmissionSchedule, err := plugincommon2.GetTransmissionSchedule( + cs, + tc.allTheOracles, + tc.transmissionDelayMultiplier, + ) + if tc.expectedError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectedTransmitters, transmissionSchedule.Transmitters) + require.Equal(t, tc.expectedTransmissionDelays, transmissionSchedule.TransmissionDelays) + }) + } +}