From efec077036730c837cee44d8346dc75cdbf75466 Mon Sep 17 00:00:00 2001 From: asoliman Date: Tue, 3 Sep 2024 06:57:55 +0400 Subject: [PATCH] Add PluginProcessor interface to create multiple Processors under OCR plugin. The idea is to separate the logic of different types of observations and outcomes into separate processors. This makes it easier to manage and test the logic of each type of observation/outcome without affecting each others. Some of them will implement state machines (e.g. merkleroot), others might implement simpler logic. (e.g. token) Also makes running them in parallel more streamlined. The OCR plugin becomes a coordinator/collector of these SubPlugins. Example Pseudo code: ``` OCRPlugin { nodeID merkleProcessor tokenProcessor feeProcessor ... } OCRPlugin.Observer { mObs := merkleProcessor.Observer tObs := tokenProcessor.Observer fObs := feeProcessor.Observer return Observation{mObs, tObs, fObs} } OCRPlugin.Validate { mObs := merkleProcessor.Validate tObs := tokenProcessor.Validate fObs := feeProcessor.Validate check errors for each return nil } OCRPlugin.Outcome { mOut := merkleProcessor.Outcome tOut := tokenProcessor.Outcome fOut := feeProcessor.Outcome return Outcome{mOut, tOut, fOut} } OCRPlugin.Report { return Report{mOut.X, tOut.Y, fOut.Z} } ``` Notice all PluginProcessor interface functions are using `prevOutcome` instead of `outCtx`. We're interested in the prevOutcome, and it makes it easier to have all decoding on the top level (OCR plugin), otherwise there might be cyclic dependencies or just complicating the code more. Signed-off-by: asoliman Plugin observations Still outcome is not working Signed-off-by: asoliman Plugin Outcome and Report Signed-off-by: asoliman Add tokenworker and some cleanups Signed-off-by: asoliman Add gasworker Signed-off-by: asoliman linting Signed-off-by: asoliman Add subplugin interface Signed-off-by: asoliman Cleaning up Signed-off-by: asoliman linting Signed-off-by: asoliman Add SubPlugin to Mockery Signed-off-by: asoliman Rename Worker to Processor Signed-off-by: asoliman Cleaning round Signed-off-by: asoliman Prallelize sub plugins Signed-off-by: asoliman Fix mockery Signed-off-by: asoliman Use SubPlugin interface Signed-off-by: asoliman linting Signed-off-by: asoliman run mockery Signed-off-by: asoliman Add docs Signed-off-by: asoliman Return empty report when merkleroots not generated Signed-off-by: asoliman Rename packages Signed-off-by: asoliman Rename variables Signed-off-by: asoliman docs and lint Signed-off-by: asoliman Cleaning Signed-off-by: asoliman Refactoring based on reviews Signed-off-by: asoliman Remove encode observation from merkleroot subprocessor Signed-off-by: asoliman Remove go routines for now to simplify logic Signed-off-by: asoliman Update commit/plugin.go Co-authored-by: Will Winder review comments Signed-off-by: asoliman refactor commit observation to use better names Signed-off-by: asoliman --- .mockery.yaml | 5 +- commit/chainfee/processor.go | 65 ++++ commit/chainfee/types.go | 16 + commit/chainfee/validate_test.go | 62 +++ commit/{ => merkleroot}/observation.go | 70 ++-- commit/{ => merkleroot}/observation_test.go | 80 ++-- commit/{ => merkleroot}/outcome.go | 55 ++- commit/{ => merkleroot}/outcome_test.go | 16 +- commit/merkleroot/processor.go | 59 +++ commit/merkleroot/types.go | 156 ++++++++ commit/merkleroot/validate_observation.go | 180 +++++++++ .../merkleroot/validate_observation_test.go | 293 +++++++++++++++ commit/plugin.go | 221 ++++++++--- commit/plugin_e2e_test.go | 80 ++-- commit/query.go | 12 - commit/report.go | 14 +- commit/tokenprice/processor.go | 65 ++++ commit/tokenprice/types.go | 23 ++ commit/tokenprice/validate_test.go | 66 ++++ commit/types.go | 215 +---------- commit/validate_observation.go | 204 ++-------- commit/validate_observation_test.go | 355 ------------------ execute/factory.go | 5 +- execute/report/report_test.go | 3 +- mocks/commit/{ => merkleroot}/observer.go | 102 +---- mocks/{commit => shared}/chain_support.go | 3 +- mocks/shared/plugin_processor.go | 258 +++++++++++++ pluginconfig/commit_test.go | 5 +- pluginconfig/execute_test.go | 3 +- plugintypes/commit_test.go | 3 +- {commit => shared}/chain_support.go | 28 +- {commit => shared}/chain_support_test.go | 11 +- shared/plugin_processor.go | 44 +++ 33 files changed, 1694 insertions(+), 1083 deletions(-) create mode 100644 commit/chainfee/processor.go create mode 100644 commit/chainfee/types.go create mode 100644 commit/chainfee/validate_test.go rename commit/{ => merkleroot}/observation.go (81%) rename commit/{ => merkleroot}/observation_test.go (87%) rename commit/{ => merkleroot}/outcome.go (87%) rename commit/{ => merkleroot}/outcome_test.go (85%) create mode 100644 commit/merkleroot/processor.go create mode 100644 commit/merkleroot/types.go create mode 100644 commit/merkleroot/validate_observation.go create mode 100644 commit/merkleroot/validate_observation_test.go delete mode 100644 commit/query.go create mode 100644 commit/tokenprice/processor.go create mode 100644 commit/tokenprice/types.go create mode 100644 commit/tokenprice/validate_test.go rename mocks/commit/{ => merkleroot}/observer.go (65%) rename mocks/{commit => shared}/chain_support.go (99%) create mode 100644 mocks/shared/plugin_processor.go rename {commit => shared}/chain_support.go (83%) rename {commit => shared}/chain_support_test.go (96%) create mode 100644 shared/plugin_processor.go diff --git a/.mockery.yaml b/.mockery.yaml index 2eb60d839..594a22ca5 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -8,9 +8,12 @@ packages: github.com/smartcontractkit/chainlink-ccip/execute/internal/gen: interfaces: ExecutePluginCodec: - github.com/smartcontractkit/chainlink-ccip/commit: + github.com/smartcontractkit/chainlink-ccip/commit/merkleroot: interfaces: Observer: + github.com/smartcontractkit/chainlink-ccip/shared: + interfaces: + PluginProcessor: ChainSupport: github.com/smartcontractkit/chainlink-ccip/internal/reader: interfaces: diff --git a/commit/chainfee/processor.go b/commit/chainfee/processor.go new file mode 100644 index 000000000..3b981cfc7 --- /dev/null +++ b/commit/chainfee/processor.go @@ -0,0 +1,65 @@ +package chainfee + +import ( + "context" + "fmt" + + mapset "github.com/deckarep/golang-set/v2" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/shared" +) + +type Processor struct { +} + +func NewProcessor() *Processor { + return &Processor{} +} + +func (w *Processor) Query(ctx context.Context, prevOutcome Outcome) (Query, error) { + return Query{}, nil +} + +func (w *Processor) Observation( + ctx context.Context, + prevOutcome Outcome, + query Query, +) (Observation, error) { + return Observation{}, nil +} + +func (w *Processor) Outcome( + prevOutcome Outcome, + query Query, + aos []shared.AttributedObservation[Observation], +) (Outcome, error) { + return Outcome{}, nil +} + +func (w *Processor) ValidateObservation( + prevOutcome Outcome, + query Query, + ao shared.AttributedObservation[Observation], +) error { + //TODO: Validate token prices + return nil +} + +func validateObservedGasPrices(gasPrices []cciptypes.GasPriceChain) error { + // Duplicate gas prices must not appear for the same chain and must not be empty. + gasPriceChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, g := range gasPrices { + if gasPriceChains.Contains(g.ChainSel) { + return fmt.Errorf("duplicate gas price for chain %d", g.ChainSel) + } + gasPriceChains.Add(g.ChainSel) + if g.GasPrice.IsEmpty() { + return fmt.Errorf("gas price must not be empty") + } + } + + return nil +} + +var _ shared.PluginProcessor[Query, Observation, Outcome] = &Processor{} diff --git a/commit/chainfee/types.go b/commit/chainfee/types.go new file mode 100644 index 000000000..a7c5b0151 --- /dev/null +++ b/commit/chainfee/types.go @@ -0,0 +1,16 @@ +package chainfee + +import ( + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +type Query struct { +} + +type Outcome struct { + GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` +} + +type Observation struct { + GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` +} diff --git a/commit/chainfee/validate_test.go b/commit/chainfee/validate_test.go new file mode 100644 index 000000000..d37b752a1 --- /dev/null +++ b/commit/chainfee/validate_test.go @@ -0,0 +1,62 @@ +package chainfee + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +func Test_validateObservedGasPrices(t *testing.T) { + testCases := []struct { + name string + gasPrices []cciptypes.GasPriceChain + expErr bool + }{ + { + name: "empty is valid", + gasPrices: []cciptypes.GasPriceChain{}, + expErr: false, + }, + { + name: "all valid", + gasPrices: []cciptypes.GasPriceChain{ + cciptypes.NewGasPriceChain(big.NewInt(10), 1), + cciptypes.NewGasPriceChain(big.NewInt(20), 2), + cciptypes.NewGasPriceChain(big.NewInt(1312), 3), + }, + expErr: false, + }, + { + name: "duplicate gas price", + gasPrices: []cciptypes.GasPriceChain{ + cciptypes.NewGasPriceChain(big.NewInt(10), 1), + cciptypes.NewGasPriceChain(big.NewInt(20), 2), + cciptypes.NewGasPriceChain(big.NewInt(1312), 1), // notice we already have a gas price for chain 1 + }, + expErr: true, + }, + { + name: "empty gas price", + gasPrices: []cciptypes.GasPriceChain{ + cciptypes.NewGasPriceChain(big.NewInt(10), 1), + cciptypes.NewGasPriceChain(big.NewInt(20), 2), + cciptypes.NewGasPriceChain(nil, 3), // nil + }, + expErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedGasPrices(tc.gasPrices) + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/commit/observation.go b/commit/merkleroot/observation.go similarity index 81% rename from commit/observation.go rename to commit/merkleroot/observation.go index c445e41ec..a20e44eb8 100644 --- a/commit/observation.go +++ b/commit/merkleroot/observation.go @@ -1,4 +1,4 @@ -package commit +package merkleroot import ( "context" @@ -8,38 +8,47 @@ import ( "sync" "time" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/hashutil" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/libocr/commontypes" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - "github.com/smartcontractkit/chainlink-ccip/internal/reader" "github.com/smartcontractkit/chainlink-ccip/plugintypes" + "github.com/smartcontractkit/chainlink-ccip/shared" + + "github.com/smartcontractkit/chainlink-ccip/internal/reader" ) -func (p *Plugin) ObservationQuorum(_ ocr3types.OutcomeContext, _ types.Query) (ocr3types.Quorum, error) { +func (w *Processor) ObservationQuorum(_ ocr3types.OutcomeContext, _ types.Query) (ocr3types.Quorum, error) { // Across all chains we require at least 2F+1 observations. return ocr3types.QuorumTwoFPlusOne, nil } -func (p *Plugin) Observation( - ctx context.Context, outCtx ocr3types.OutcomeContext, _ types.Query, -) (types.Observation, error) { +func (w *Processor) Query(ctx context.Context, prevOutcome Outcome) (Query, error) { + return Query{}, nil +} + +func (w *Processor) Observation( + ctx context.Context, + prevOutcome Outcome, + _ Query, +) (Observation, error) { tStart := time.Now() - observation, nextState := p.getObservation(ctx, outCtx) - p.lggr.Infow("Sending Observation", + observation, nextState := w.getObservation(ctx, prevOutcome) + w.lggr.Infow("Sending MerkleRootObs", "observation", observation, "nextState", nextState, "observationDuration", time.Since(tStart)) - return observation.Encode() + return observation, nil } -func (p *Plugin) getObservation(ctx context.Context, outCtx ocr3types.OutcomeContext) (Observation, State) { - previousOutcome, nextState := p.decodeOutcome(outCtx.PreviousOutcome) +func (w *Processor) getObservation(ctx context.Context, previousOutcome Outcome) (Observation, State) { + nextState := previousOutcome.NextState() switch nextState { case SelectingRangesForReport: - offRampNextSeqNums := p.observer.ObserveOffRampNextSeqNums(ctx) + offRampNextSeqNums := w.observer.ObserveOffRampNextSeqNums(ctx) return Observation{ // TODO: observe OnRamp max seq nums. The use of offRampNextSeqNums here effectively disables batching, // e.g. the ranges selected for each chain will be [x, x] (e.g. [46, 46]), which means reports will only @@ -47,22 +56,20 @@ func (p *Plugin) getObservation(ctx context.Context, outCtx ocr3types.OutcomeCon // need to be done in a future change. OnRampMaxSeqNums: offRampNextSeqNums, OffRampNextSeqNums: offRampNextSeqNums, - FChain: p.observer.ObserveFChain(), + FChain: w.observer.ObserveFChain(), }, nextState case BuildingReport: return Observation{ - MerkleRoots: p.observer.ObserveMerkleRoots(ctx, previousOutcome.RangesSelectedForReport), - GasPrices: p.observer.ObserveGasPrices(ctx), - TokenPrices: p.observer.ObserveTokenPrices(ctx), - FChain: p.observer.ObserveFChain(), + MerkleRoots: w.observer.ObserveMerkleRoots(ctx, previousOutcome.RangesSelectedForReport), + FChain: w.observer.ObserveFChain(), }, nextState case WaitingForReportTransmission: return Observation{ - OffRampNextSeqNums: p.observer.ObserveOffRampNextSeqNums(ctx), - FChain: p.observer.ObserveFChain(), + OffRampNextSeqNums: w.observer.ObserveOffRampNextSeqNums(ctx), + FChain: w.observer.ObserveFChain(), }, nextState default: - p.lggr.Errorw("Unexpected state", "state", nextState) + w.lggr.Errorw("Unexpected state", "state", nextState) return Observation{}, nextState } } @@ -74,10 +81,6 @@ type Observer interface { // ObserveMerkleRoots computes the merkle roots for the given sequence number ranges ObserveMerkleRoots(ctx context.Context, ranges []plugintypes.ChainRange) []cciptypes.MerkleRootChain - ObserveTokenPrices(ctx context.Context) []cciptypes.TokenPrice - - ObserveGasPrices(ctx context.Context) []cciptypes.GasPriceChain - ObserveFChain() map[cciptypes.ChainSelector]int } @@ -85,7 +88,7 @@ type ObserverImpl struct { lggr logger.Logger homeChain reader.HomeChain nodeID commontypes.OracleID - chainSupport ChainSupport + chainSupport shared.ChainSupport ccipReader reader.CCIP msgHasher cciptypes.MessageHasher } @@ -215,14 +218,6 @@ func (o ObserverImpl) computeMerkleRoot(ctx context.Context, msgs []cciptypes.Me return root, nil } -func (o ObserverImpl) ObserveTokenPrices(ctx context.Context) []cciptypes.TokenPrice { - return []cciptypes.TokenPrice{} -} - -func (o ObserverImpl) ObserveGasPrices(ctx context.Context) []cciptypes.GasPriceChain { - return []cciptypes.GasPriceChain{} -} - func (o ObserverImpl) ObserveFChain() map[cciptypes.ChainSelector]int { fChain, err := o.homeChain.GetFChain() if err != nil { @@ -232,6 +227,3 @@ func (o ObserverImpl) ObserveFChain() map[cciptypes.ChainSelector]int { } return fChain } - -// Interface compliance check -var _ Observer = (*ObserverImpl)(nil) diff --git a/commit/observation_test.go b/commit/merkleroot/observation_test.go similarity index 87% rename from commit/observation_test.go rename to commit/merkleroot/observation_test.go index dd4d83874..136a7569c 100644 --- a/commit/observation_test.go +++ b/commit/merkleroot/observation_test.go @@ -1,4 +1,4 @@ -package commit +package merkleroot import ( "context" @@ -10,17 +10,18 @@ import ( mapset "github.com/deckarep/golang-set/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-ccip/mocks/shared" "github.com/smartcontractkit/libocr/commontypes" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-ccip/internal/mocks" - commitmocks "github.com/smartcontractkit/chainlink-ccip/mocks/commit" + "github.com/smartcontractkit/chainlink-ccip/mocks/commit/merkleroot" reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" "github.com/smartcontractkit/chainlink-ccip/plugintypes" ) @@ -33,18 +34,6 @@ func Test_Observation(t *testing.T) { MerkleRoot: [32]byte{1}, }, } - gasPrices := []cciptypes.GasPriceChain{ - { - GasPrice: cciptypes.NewBigIntFromInt64(99), - ChainSel: 8, - }, - } - tokenPrices := []cciptypes.TokenPrice{ - { - TokenID: "token23", - Price: cciptypes.NewBigIntFromInt64(80761), - }, - } offRampNextSeqNums := []plugintypes.SeqNumChain{ { ChainSel: 456, @@ -58,7 +47,7 @@ func Test_Observation(t *testing.T) { testCases := []struct { name string previousOutcome Outcome - getObserver func(t *testing.T) *commitmocks.MockObserver + getObserver func(t *testing.T) *merkleroot.MockObserver expObs Observation }{ { @@ -66,8 +55,8 @@ func Test_Observation(t *testing.T) { previousOutcome: Outcome{ OutcomeType: ReportTransmitted, }, - getObserver: func(t *testing.T) *commitmocks.MockObserver { - observer := commitmocks.NewMockObserver(t) + getObserver: func(t *testing.T) *merkleroot.MockObserver { + observer := merkleroot.NewMockObserver(t) observer.EXPECT().ObserveOffRampNextSeqNums(mock.Anything).Once().Return(offRampNextSeqNums) observer.EXPECT().ObserveFChain().Once().Return(fChain) return observer @@ -89,23 +78,19 @@ func Test_Observation(t *testing.T) { }, }, }, - getObserver: func(t *testing.T) *commitmocks.MockObserver { - observer := commitmocks.NewMockObserver(t) + getObserver: func(t *testing.T) *merkleroot.MockObserver { + observer := merkleroot.NewMockObserver(t) observer.EXPECT().ObserveMerkleRoots(mock.Anything, []plugintypes.ChainRange{ { ChainSel: 1, SeqNumRange: cciptypes.SeqNumRange{5, 78}, }, }).Once().Return(merkleRoots) - observer.EXPECT().ObserveGasPrices(mock.Anything).Once().Return(gasPrices) - observer.EXPECT().ObserveTokenPrices(mock.Anything).Once().Return(tokenPrices) observer.EXPECT().ObserveFChain().Once().Return(fChain) return observer }, expObs: Observation{ MerkleRoots: merkleRoots, - GasPrices: gasPrices, - TokenPrices: tokenPrices, FChain: fChain, }, }, @@ -114,8 +99,8 @@ func Test_Observation(t *testing.T) { previousOutcome: Outcome{ OutcomeType: ReportInFlight, }, - getObserver: func(t *testing.T) *commitmocks.MockObserver { - observer := commitmocks.NewMockObserver(t) + getObserver: func(t *testing.T) *merkleroot.MockObserver { + observer := merkleroot.NewMockObserver(t) observer.EXPECT().ObserveOffRampNextSeqNums(mock.Anything).Once().Return(offRampNextSeqNums) observer.EXPECT().ObserveFChain().Once().Return(fChain) return observer @@ -133,24 +118,17 @@ func Test_Observation(t *testing.T) { observer := tc.getObserver(t) defer observer.AssertExpectations(t) - p := Plugin{ + p := Processor{ lggr: logger.Test(t), observer: observer, } - previousOutcomeEncoded, err := tc.previousOutcome.Encode() - assert.NoError(t, err) - - result, err := p.Observation( + actualObs, err := p.Observation( ctx, - ocr3types.OutcomeContext{PreviousOutcome: previousOutcomeEncoded}, - types.Query{}, + tc.previousOutcome, + Query{}, ) - assert.NoError(t, err) - - actualObs, err := DecodeCommitPluginObservation(result) - assert.NoError(t, err) - + require.NoError(t, err) assert.Equal(t, tc.expObs, actualObs) }) } @@ -164,12 +142,12 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) { testCases := []struct { name string expResult []plugintypes.SeqNumChain - getDeps func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) + getDeps func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) }{ { name: "Happy path", - getDeps: func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) { - chainSupport := commitmocks.NewMockChainSupport(t) + getDeps: func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) { + chainSupport := shared.NewMockChainSupport(t) chainSupport.EXPECT().SupportsDestChain(nodeID).Return(true, nil) chainSupport.EXPECT().KnownSourceChainsSlice().Return(knownSourceChains, nil) ccipReader := reader_mock.NewMockCCIP(t) @@ -184,8 +162,8 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) { }, { name: "nil is returned when supportsDestChain is false", - getDeps: func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) { - chainSupport := commitmocks.NewMockChainSupport(t) + getDeps: func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) { + chainSupport := shared.NewMockChainSupport(t) chainSupport.EXPECT().SupportsDestChain(nodeID).Return(false, nil) ccipReader := reader_mock.NewMockCCIP(t) return chainSupport, ccipReader @@ -194,8 +172,8 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) { }, { name: "nil is returned when supportsDestChain errors", - getDeps: func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) { - chainSupport := commitmocks.NewMockChainSupport(t) + getDeps: func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) { + chainSupport := shared.NewMockChainSupport(t) chainSupport.EXPECT().SupportsDestChain(nodeID).Return(false, errors.New("some error")) ccipReader := reader_mock.NewMockCCIP(t) return chainSupport, ccipReader @@ -204,8 +182,8 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) { }, { name: "nil is returned when knownSourceChains errors", - getDeps: func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) { - chainSupport := commitmocks.NewMockChainSupport(t) + getDeps: func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) { + chainSupport := shared.NewMockChainSupport(t) chainSupport.EXPECT().SupportsDestChain(nodeID).Return(true, nil) chainSupport.EXPECT().KnownSourceChainsSlice().Return(nil, errors.New("some error")) ccipReader := reader_mock.NewMockCCIP(t) @@ -215,8 +193,8 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) { }, { name: "nil is returned when nextSeqNums returns incorrect number of seq nums", - getDeps: func(t *testing.T) (*commitmocks.MockChainSupport, *reader_mock.MockCCIP) { - chainSupport := commitmocks.NewMockChainSupport(t) + getDeps: func(t *testing.T) (*shared.MockChainSupport, *reader_mock.MockCCIP) { + chainSupport := shared.NewMockChainSupport(t) chainSupport.EXPECT().SupportsDestChain(nodeID).Return(true, nil) chainSupport.EXPECT().KnownSourceChainsSlice().Return(knownSourceChains, nil) ccipReader := reader_mock.NewMockCCIP(t) @@ -442,7 +420,7 @@ func Test_ObserveMerkleRoots(t *testing.T) { ).Return(tc.msgsBetweenSeqNums[r.ChainSel], err) } - chainSupport := commitmocks.NewMockChainSupport(t) + chainSupport := shared.NewMockChainSupport(t) if tc.supportedChainsFails { chainSupport.On("SupportedChains", nodeID).Return( mapset.NewSet[cciptypes.ChainSelector](), fmt.Errorf("error"), diff --git a/commit/outcome.go b/commit/merkleroot/outcome.go similarity index 87% rename from commit/outcome.go rename to commit/merkleroot/outcome.go index c9144007b..6d206e2b3 100644 --- a/commit/outcome.go +++ b/commit/merkleroot/outcome.go @@ -1,14 +1,13 @@ -package commit +package merkleroot import ( "fmt" "sort" "time" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "golang.org/x/exp/maps" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-ccip/shared" "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" @@ -20,25 +19,28 @@ import ( // - chooses the seq num ranges for the next round // - builds a report // - checks for the transmission of a previous report -func (p *Plugin) Outcome( - outCtx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, -) (ocr3types.Outcome, error) { +func (w *Processor) Outcome( + prevOutcome Outcome, + query Query, + aos []shared.AttributedObservation[Observation], +) (Outcome, error) { tStart := time.Now() - outcome, nextState := p.getOutcome(outCtx, aos) - p.lggr.Infow("Sending Outcome", - "outcome", outcome, "oid", p.nodeID, "nextState", nextState, "outcomeDuration", time.Since(tStart)) - return outcome.Encode() + outcome, nextState := w.getOutcome(prevOutcome, query, aos) + w.lggr.Infow("Sending Outcome", + "outcome", outcome, "oid", w.nodeID, "nextState", nextState, "outcomeDuration", time.Since(tStart)) + return outcome, nil } -func (p *Plugin) getOutcome( - outCtx ocr3types.OutcomeContext, aos []types.AttributedObservation, +func (w *Processor) getOutcome( + previousOutcome Outcome, + commitQuery Query, + aos []shared.AttributedObservation[Observation], ) (Outcome, State) { - previousOutcome, nextState := p.decodeOutcome(outCtx.PreviousOutcome) - commitQuery := Query{} + nextState := previousOutcome.NextState() - consensusObservation, err := getConsensusObservation(p.lggr, p.reportingCfg.F, p.cfg.DestChain, aos) + consensusObservation, err := getConsensusObservation(w.lggr, w.reportingCfg.F, w.cfg.DestChain, aos) if err != nil { - p.lggr.Warnw("Get consensus observation failed, empty outcome", "error", err) + w.lggr.Warnw("Get consensus observation failed, empty outcome", "error", err) return Outcome{}, nextState } @@ -49,9 +51,9 @@ func (p *Plugin) getOutcome( return buildReport(commitQuery, consensusObservation, previousOutcome), nextState case WaitingForReportTransmission: return checkForReportTransmission( - p.lggr, p.cfg.MaxReportTransmissionCheckAttempts, previousOutcome, consensusObservation), nextState + w.lggr, w.cfg.MaxReportTransmissionCheckAttempts, previousOutcome, consensusObservation), nextState default: - p.lggr.Warnw("Unexpected next state in Outcome", "state", nextState) + w.lggr.Warnw("Unexpected next state in Outcome", "state", nextState) return Outcome{}, nextState } } @@ -131,8 +133,6 @@ func buildReport( outcome := Outcome{ OutcomeType: outcomeType, RootsToReport: roots, - GasPrices: consensusObservation.GasPricesArray(), - TokenPrices: consensusObservation.TokenPricesArray(), OffRampNextSeqNums: prevOutcome.OffRampNextSeqNums, } @@ -188,7 +188,7 @@ func getConsensusObservation( lggr logger.Logger, F int, destChain cciptypes.ChainSelector, - aos []types.AttributedObservation, + aos []shared.AttributedObservation[Observation], ) (ConsensusObservation, error) { aggObs := aggregateObservations(aos) fChains := fChainConsensus(lggr, F, aggObs.FChain) @@ -200,11 +200,7 @@ func getConsensusObservation( } consensusObs := ConsensusObservation{ - MerkleRoots: merkleRootConsensus(lggr, aggObs.MerkleRoots, fChains), - // TODO: use consensus of observed gas prices - GasPrices: make(map[cciptypes.ChainSelector]cciptypes.BigInt), - // TODO: use consensus of observed token prices - TokenPrices: make(map[types.Account]cciptypes.BigInt), + MerkleRoots: merkleRootConsensus(lggr, aggObs.MerkleRoots, fChains), OnRampMaxSeqNums: onRampMaxSeqNumsConsensus(lggr, aggObs.OnRampMaxSeqNums, fChains), OffRampNextSeqNums: offRampMaxSeqNumsConsensus(lggr, aggObs.OffRampNextSeqNums, fDestChain), FChain: fChains, @@ -213,9 +209,10 @@ func getConsensusObservation( return consensusObs, nil } -// Given a mapping from chains to a list of merkle roots, return a mapping from chains to a single consensus merkle -// root. The consensus merkle root for a given chain is the merkle root with the most observations that was observed at -// least fChain times. +// Given a mapping from chains to a list of merkle roots, +// return a mapping from chains to a single consensus merkle root. +// The consensus merkle root for a given chain is the merkle root with the +// most observations that was observed at least fChain times. func merkleRootConsensus( lggr logger.Logger, rootsByChain map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain, diff --git a/commit/outcome_test.go b/commit/merkleroot/outcome_test.go similarity index 85% rename from commit/outcome_test.go rename to commit/merkleroot/outcome_test.go index c810b0c56..b231bc875 100644 --- a/commit/outcome_test.go +++ b/commit/merkleroot/outcome_test.go @@ -1,13 +1,13 @@ -package commit +package merkleroot import ( "testing" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) func Test_buildReport(t *testing.T) { @@ -27,14 +27,6 @@ func Test_buildReport(t *testing.T) { MerkleRoot: cciptypes.Bytes32{2}, }, }, - GasPrices: map[cciptypes.ChainSelector]cciptypes.BigInt{ - cciptypes.ChainSelector(1): cciptypes.NewBigIntFromInt64(1000), - cciptypes.ChainSelector(2): cciptypes.NewBigIntFromInt64(2000), - }, - TokenPrices: map[types.Account]cciptypes.BigInt{ - types.Account("1"): cciptypes.NewBigIntFromInt64(1000), - types.Account("2"): cciptypes.NewBigIntFromInt64(2000), - }, } for i := 0; i < rounds; i++ { diff --git a/commit/merkleroot/processor.go b/commit/merkleroot/processor.go new file mode 100644 index 000000000..a227d26e4 --- /dev/null +++ b/commit/merkleroot/processor.go @@ -0,0 +1,59 @@ +package merkleroot + +import ( + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/pluginconfig" + "github.com/smartcontractkit/chainlink-ccip/shared" +) + +// Processor is the processor responsible for +// reading next messages and building merkle roots for them, +// It's setup to use RMN to query which messages to include in the merkle root and ensures +// the newly built merkle roots are the same as RMN roots. +type Processor struct { + nodeID commontypes.OracleID + cfg pluginconfig.CommitPluginConfig + lggr logger.Logger + observer Observer + ccipReader reader.CCIP + reportingCfg ocr3types.ReportingPluginConfig + chainSupport shared.ChainSupport +} + +// NewProcessor creates a new Processor +func NewProcessor( + oracleID commontypes.OracleID, + lggr logger.Logger, + cfg pluginconfig.CommitPluginConfig, + homeChain reader.HomeChain, + ccipReader reader.CCIP, + msgHasher cciptypes.MessageHasher, + reportingCfg ocr3types.ReportingPluginConfig, + chainSupport shared.ChainSupport, +) *Processor { + observer := ObserverImpl{ + lggr, + homeChain, + oracleID, + chainSupport, + ccipReader, + msgHasher, + } + return &Processor{ + nodeID: oracleID, + cfg: cfg, + lggr: lggr, + ccipReader: ccipReader, + observer: observer, + reportingCfg: reportingCfg, + chainSupport: chainSupport, + } +} + +var _ shared.PluginProcessor[Query, Observation, Outcome] = &Processor{} diff --git a/commit/merkleroot/types.go b/commit/merkleroot/types.go new file mode 100644 index 000000000..571ccd8ef --- /dev/null +++ b/commit/merkleroot/types.go @@ -0,0 +1,156 @@ +package merkleroot + +import ( + "sort" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/plugintypes" + "github.com/smartcontractkit/chainlink-ccip/shared" +) + +type Query struct { + RmnOnRampMaxSeqNums []plugintypes.SeqNumChain + MerkleRoots []cciptypes.MerkleRootChain +} + +func NewCommitQuery(rmnOnRampMaxSeqNums []plugintypes.SeqNumChain, merkleRoots []cciptypes.MerkleRootChain) Query { + return Query{ + RmnOnRampMaxSeqNums: rmnOnRampMaxSeqNums, + MerkleRoots: merkleRoots, + } +} + +type Observation struct { + MerkleRoots []cciptypes.MerkleRootChain `json:"merkleRoots"` + OnRampMaxSeqNums []plugintypes.SeqNumChain `json:"onRampMaxSeqNums"` + OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` + FChain map[cciptypes.ChainSelector]int `json:"fChain"` +} + +// MerkleAggregatedObservation is the aggregation of a list of observations +type MerkleAggregatedObservation struct { + // A map from chain selectors to the list of merkle roots observed for each chain + MerkleRoots map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain + + // A map from chain selectors to the list of OnRamp max sequence numbers observed for each chain + OnRampMaxSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum + + // A map from chain selectors to the list of OffRamp next sequence numbers observed for each chain + OffRampNextSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum + + // A map from chain selectors to the list of f (failure tolerance) observed for each chain + FChain map[cciptypes.ChainSelector][]int +} + +// aggregateObservations takes a list of observations and produces an MerkleAggregatedObservation +func aggregateObservations(aos []shared.AttributedObservation[Observation]) MerkleAggregatedObservation { + aggObs := MerkleAggregatedObservation{ + MerkleRoots: make(map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain), + OnRampMaxSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), + OffRampNextSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), + FChain: make(map[cciptypes.ChainSelector][]int), + } + + for _, ao := range aos { + obs := ao.Observation + // MerkleRoots + for _, merkleRoot := range obs.MerkleRoots { + aggObs.MerkleRoots[merkleRoot.ChainSel] = + append(aggObs.MerkleRoots[merkleRoot.ChainSel], merkleRoot) + } + + // OnRampMaxSeqNums + for _, seqNumChain := range obs.OnRampMaxSeqNums { + aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel] = + append(aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) + } + + // OffRampNextSeqNums + for _, seqNumChain := range obs.OffRampNextSeqNums { + aggObs.OffRampNextSeqNums[seqNumChain.ChainSel] = + append(aggObs.OffRampNextSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) + } + + // FChain + for chainSel, f := range obs.FChain { + aggObs.FChain[chainSel] = append(aggObs.FChain[chainSel], f) + } + } + + return aggObs +} + +// ConsensusObservation holds the consensus values for all chains across all observations in a round +type ConsensusObservation struct { + // A map from chain selectors to each chain's consensus merkle root + MerkleRoots map[cciptypes.ChainSelector]cciptypes.MerkleRootChain + + // A map from chain selectors to each chain's consensus OnRamp max sequence number + OnRampMaxSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum + + // A map from chain selectors to each chain's consensus OffRamp next sequence number + OffRampNextSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum + + // A map from chain selectors to each chain's consensus f (failure tolerance) + FChain map[cciptypes.ChainSelector]int +} + +type OutcomeType int + +const ( + ReportIntervalsSelected OutcomeType = iota + 1 + ReportGenerated + ReportEmpty + ReportInFlight + ReportTransmitted + ReportTransmissionFailed +) + +type Outcome struct { + OutcomeType OutcomeType `json:"outcomeType"` + RangesSelectedForReport []plugintypes.ChainRange `json:"rangesSelectedForReport"` + RootsToReport []cciptypes.MerkleRootChain `json:"rootsToReport"` + OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` + ReportTransmissionCheckAttempts uint `json:"reportTransmissionCheckAttempts"` +} + +// Sort all fields of the given Outcome +func (o *Outcome) Sort() { + sort.Slice(o.RangesSelectedForReport, func(i, j int) bool { + return o.RangesSelectedForReport[i].ChainSel < o.RangesSelectedForReport[j].ChainSel + }) + sort.Slice(o.RootsToReport, func(i, j int) bool { + return o.RootsToReport[i].ChainSel < o.RootsToReport[j].ChainSel + }) + sort.Slice(o.OffRampNextSeqNums, func(i, j int) bool { + return o.OffRampNextSeqNums[i].ChainSel < o.OffRampNextSeqNums[j].ChainSel + }) +} + +func (o *Outcome) NextState() State { + switch o.OutcomeType { + case ReportIntervalsSelected: + return BuildingReport + case ReportGenerated: + return WaitingForReportTransmission + case ReportEmpty: + return SelectingRangesForReport + case ReportInFlight: + return WaitingForReportTransmission + case ReportTransmitted: + return SelectingRangesForReport + case ReportTransmissionFailed: + return SelectingRangesForReport + default: + return SelectingRangesForReport + } +} + +type State int + +const ( + SelectingRangesForReport State = iota + 1 + BuildingReport + WaitingForReportTransmission +) diff --git a/commit/merkleroot/validate_observation.go b/commit/merkleroot/validate_observation.go new file mode 100644 index 000000000..217bffd94 --- /dev/null +++ b/commit/merkleroot/validate_observation.go @@ -0,0 +1,180 @@ +package merkleroot + +import ( + "context" + "fmt" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + + "github.com/smartcontractkit/chainlink-ccip/shared" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/libocr/commontypes" + + "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +func (w *Processor) ValidateObservation( + _ Outcome, + _ Query, + ao shared.AttributedObservation[Observation]) error { + obs := ao.Observation + if err := validateFChain(obs.FChain); err != nil { + return fmt.Errorf("failed to validate FChain: %w", err) + } + observerSupportedChains, err := w.chainSupport.SupportedChains(ao.OracleID) + if err != nil { + return fmt.Errorf("failed to get supported chains: %w", err) + } + + supportsDestChain, err := w.chainSupport.SupportsDestChain(ao.OracleID) + if err != nil { + return fmt.Errorf("call to supportsDestChain failed: %w", err) + } + + if err := validateObservedMerkleRoots(obs.MerkleRoots, ao.OracleID, observerSupportedChains); err != nil { + return fmt.Errorf("failed to validate MerkleRoots: %w", err) + } + + if err := validateObservedOnRampMaxSeqNums(obs.OnRampMaxSeqNums, ao.OracleID, observerSupportedChains); err != nil { + return fmt.Errorf("failed to validate OnRampMaxSeqNums: %w", err) + } + + if err := validateObservedOffRampMaxSeqNums(obs.OffRampNextSeqNums, ao.OracleID, supportsDestChain); err != nil { + return fmt.Errorf("failed to validate OffRampNextSeqNums: %w", err) + } + + return nil +} + +func validateObservedMerkleRoots( + merkleRoots []cciptypes.MerkleRootChain, + observer commontypes.OracleID, + observerSupportedChains mapset.Set[cciptypes.ChainSelector], +) error { + if len(merkleRoots) == 0 { + return nil + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, root := range merkleRoots { + if !observerSupportedChains.Contains(root.ChainSel) { + return fmt.Errorf("found merkle root for chain %d, but this chain is not supported by Observer %d", + root.ChainSel, observer) + } + + if seenChains.Contains(root.ChainSel) { + return fmt.Errorf("duplicate merkle root for chain %d", root.ChainSel) + } + seenChains.Add(root.ChainSel) + } + + return nil +} + +func validateObservedOnRampMaxSeqNums( + onRampMaxSeqNums []plugintypes.SeqNumChain, + observer commontypes.OracleID, + observerSupportedChains mapset.Set[cciptypes.ChainSelector], +) error { + if len(onRampMaxSeqNums) == 0 { + return nil + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, seqNumChain := range onRampMaxSeqNums { + if !observerSupportedChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("found onRampMaxSeqNum for chain %d, but this chain is not supported by Observer %d, "+ + "observerSupportedChains: %v, onRampMaxSeqNums: %v", + seqNumChain.ChainSel, observer, observerSupportedChains, onRampMaxSeqNums) + } + + if seenChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("duplicate onRampMaxSeqNum for chain %d", seqNumChain.ChainSel) + } + seenChains.Add(seqNumChain.ChainSel) + } + + return nil +} + +func validateObservedOffRampMaxSeqNums( + offRampMaxSeqNums []plugintypes.SeqNumChain, + observer commontypes.OracleID, + supportsDestChain bool, +) error { + if len(offRampMaxSeqNums) == 0 { + return nil + } + + if !supportsDestChain { + return fmt.Errorf("observer %d does not support dest chain, but has observed %d offRampMaxSeqNums", + observer, len(offRampMaxSeqNums)) + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, seqNumChain := range offRampMaxSeqNums { + if seenChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("duplicate offRampMaxSeqNum for chain %d", seqNumChain.ChainSel) + } + seenChains.Add(seqNumChain.ChainSel) + } + + return nil +} + +func validateFChain(fChain map[cciptypes.ChainSelector]int) error { + for _, f := range fChain { + if f < 0 { + return fmt.Errorf("fChain %d is negative", f) + } + } + + return nil +} + +// ValidateMerkleRootsState merkle roots seq nums validation by comparing with on-chain state. +func ValidateMerkleRootsState( + ctx context.Context, + lggr logger.Logger, + report cciptypes.CommitPluginReport, + reader reader.CCIP, +) (bool, error) { + reportChains := make([]cciptypes.ChainSelector, 0) + reportMinSeqNums := make(map[cciptypes.ChainSelector]cciptypes.SeqNum) + for _, mr := range report.MerkleRoots { + reportChains = append(reportChains, mr.ChainSel) + reportMinSeqNums[mr.ChainSel] = mr.SeqNumsRange.Start() + } + + if len(reportChains) == 0 { + return true, nil + } + + onchainNextSeqNums, err := reader.NextSeqNum(ctx, reportChains) + if err != nil { + return false, fmt.Errorf("get next sequence numbers: %w", err) + } + if len(onchainNextSeqNums) != len(reportChains) { + return false, fmt.Errorf("critical error: onchainSeqNums length mismatch") + } + + for i, nextSeqNum := range onchainNextSeqNums { + chain := reportChains[i] + reportMinSeqNum, ok := reportMinSeqNums[chain] + if !ok { + return false, fmt.Errorf("critical error: reportSeqNum not found for chain %d", chain) + } + + if reportMinSeqNum != nextSeqNum { + lggr.Warnw("report is not valid due to seq num mismatch", + "chain", chain, "reportMinSeqNum", reportMinSeqNum, "onchainNextSeqNum", nextSeqNum) + return false, nil + } + } + + return true, nil +} diff --git a/commit/merkleroot/validate_observation_test.go b/commit/merkleroot/validate_observation_test.go new file mode 100644 index 000000000..77479356b --- /dev/null +++ b/commit/merkleroot/validate_observation_test.go @@ -0,0 +1,293 @@ +package merkleroot + +import ( + "context" + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +func Test_validateObservedMerkleRoots(t *testing.T) { + testCases := []struct { + name string + merkleRoots []cciptypes.MerkleRootChain + observer commontypes.OracleID + observerSupportedChains mapset.Set[cciptypes.ChainSelector] + expErr bool + }{ + { + name: "Chain not supported", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), + expErr: true, + }, + { + name: "Duplicate chains", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{3, 7}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedMerkleRoots(tc.merkleRoots, tc.observer, tc.observerSupportedChains) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateObservedOnRampMaxSeqNums(t *testing.T) { + testCases := []struct { + name string + onRampMaxSeqNums []plugintypes.SeqNumChain + observer commontypes.OracleID + observerSupportedChains mapset.Set[cciptypes.ChainSelector] + expErr bool + }{ + { + name: "Chain not supported", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), + expErr: true, + }, + { + name: "Duplicate chains", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + {ChainSel: 2, SeqNum: 33}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedOnRampMaxSeqNums(tc.onRampMaxSeqNums, tc.observer, tc.observerSupportedChains) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateObservedOffRampMaxSeqNums(t *testing.T) { + testCases := []struct { + name string + offRampMaxSeqNums []plugintypes.SeqNumChain + observer commontypes.OracleID + supportsDestChain bool + expErr bool + }{ + { + name: "Dest chain not supported", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + supportsDestChain: false, + expErr: true, + }, + { + name: "Duplicate chains", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + {ChainSel: 2, SeqNum: 33}, + }, + observer: 10, + supportsDestChain: false, + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + supportsDestChain: true, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedOffRampMaxSeqNums(tc.offRampMaxSeqNums, tc.observer, tc.supportsDestChain) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateFChain(t *testing.T) { + testCases := []struct { + name string + fChain map[cciptypes.ChainSelector]int + expErr bool + }{ + { + name: "FChain contains negative values", + fChain: map[cciptypes.ChainSelector]int{ + 1: 11, + 2: -4, + }, + expErr: true, + }, + { + name: "FChain valid", + fChain: map[cciptypes.ChainSelector]int{ + 12: 6, + 7: 9, + }, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateFChain(tc.fChain) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateMerkleRootsState(t *testing.T) { + testCases := []struct { + name string + reportSeqNums []plugintypes.SeqNumChain + onchainNextSeqNums []cciptypes.SeqNum + expValid bool + expErr bool + }{ + { + name: "happy path", + reportSeqNums: []plugintypes.SeqNumChain{ + plugintypes.NewSeqNumChain(10, 100), + plugintypes.NewSeqNumChain(20, 200), + }, + onchainNextSeqNums: []cciptypes.SeqNum{100, 200}, + expValid: true, + expErr: false, + }, + { + name: "one root is stale", + reportSeqNums: []plugintypes.SeqNumChain{ + plugintypes.NewSeqNumChain(10, 100), + plugintypes.NewSeqNumChain(20, 200), + }, + onchainNextSeqNums: []cciptypes.SeqNum{100, 201}, // <- 200 is already on chain + expValid: false, + expErr: false, + }, + { + name: "one root has gap", + reportSeqNums: []plugintypes.SeqNumChain{ + plugintypes.NewSeqNumChain(10, 101), // <- onchain 99 but we submit 101 instead of 100 + plugintypes.NewSeqNumChain(20, 200), + }, + onchainNextSeqNums: []cciptypes.SeqNum{100, 200}, + expValid: false, + expErr: false, + }, + { + name: "reader returned wrong number of seq nums", + reportSeqNums: []plugintypes.SeqNumChain{ + plugintypes.NewSeqNumChain(10, 100), + plugintypes.NewSeqNumChain(20, 200), + }, + onchainNextSeqNums: []cciptypes.SeqNum{100, 200, 300}, + expValid: false, + expErr: true, + }, + } + + ctx := context.Background() + lggr := logger.Test(t) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reader := reader_mock.NewMockCCIP(t) + rep := cciptypes.CommitPluginReport{} + chains := make([]cciptypes.ChainSelector, 0, len(tc.reportSeqNums)) + for _, snc := range tc.reportSeqNums { + rep.MerkleRoots = append(rep.MerkleRoots, cciptypes.MerkleRootChain{ + ChainSel: snc.ChainSel, + SeqNumsRange: cciptypes.NewSeqNumRange(snc.SeqNum, snc.SeqNum+10), + }) + chains = append(chains, snc.ChainSel) + } + reader.On("NextSeqNum", ctx, chains).Return(tc.onchainNextSeqNums, nil) + valid, err := ValidateMerkleRootsState(ctx, lggr, rep, reader) + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tc.expValid, valid) + }) + } +} diff --git a/commit/plugin.go b/commit/plugin.go index c60cb9657..e926f69ed 100644 --- a/commit/plugin.go +++ b/commit/plugin.go @@ -7,8 +7,15 @@ import ( "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" + "github.com/smartcontractkit/chainlink-ccip/commit/tokenprice" + + "github.com/smartcontractkit/chainlink-ccip/commit/chainfee" + "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" + "github.com/smartcontractkit/chainlink-ccip/shared" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" @@ -18,19 +25,25 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" ) +type MerkleRootObservation = shared.AttributedObservation[merkleroot.Observation] +type TokenPricesObservation = shared.AttributedObservation[tokenprice.Observation] +type ChainFeeObservation = shared.AttributedObservation[chainfee.Observation] + type Plugin struct { - nodeID commontypes.OracleID - oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID - cfg pluginconfig.CommitPluginConfig - ccipReader reader.CCIP - readerSyncer *plugincommon.BackgroundReaderSyncer - tokenPricesReader reader.TokenPrices - reportCodec cciptypes.CommitPluginCodec - lggr logger.Logger - homeChain reader.HomeChain - reportingCfg ocr3types.ReportingPluginConfig - chainSupport ChainSupport - observer Observer + nodeID commontypes.OracleID + oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID + cfg pluginconfig.CommitPluginConfig + ccipReader reader.CCIP + readerSyncer *plugincommon.BackgroundReaderSyncer + tokenPricesReader reader.TokenPrices + reportCodec cciptypes.CommitPluginCodec + lggr logger.Logger + homeChain reader.HomeChain + reportingCfg ocr3types.ReportingPluginConfig + chainSupport shared.ChainSupport + merkleRootProcessor shared.PluginProcessor[merkleroot.Query, merkleroot.Observation, merkleroot.Outcome] + tokenPriceProcessor shared.PluginProcessor[tokenprice.Query, tokenprice.Observation, tokenprice.Outcome] + chainFeeProcessor shared.PluginProcessor[chainfee.Query, chainfee.Observation, chainfee.Outcome] } func NewPlugin( @@ -56,37 +69,161 @@ func NewPlugin( lggr.Errorw("error starting background reader syncer", "err", err) } - chainSupport := CCIPChainSupport{ - lggr: lggr, - homeChain: homeChain, - oracleIDToP2pID: oracleIDToP2pID, - nodeID: nodeID, - destChain: cfg.DestChain, + chainSupport := shared.NewCCIPChainSupport( + lggr, + homeChain, + oracleIDToP2pID, + nodeID, + cfg.DestChain, + ) + + merkleRootProcessor := merkleroot.NewProcessor( + nodeID, + lggr, + cfg, + homeChain, + ccipReader, + msgHasher, + reportingCfg, + chainSupport, + ) + + return &Plugin{ + nodeID: nodeID, + oracleIDToP2pID: oracleIDToP2pID, + lggr: lggr, + cfg: cfg, + tokenPricesReader: tokenPricesReader, + ccipReader: ccipReader, + homeChain: homeChain, + readerSyncer: readerSyncer, + reportCodec: reportCodec, + reportingCfg: reportingCfg, + chainSupport: chainSupport, + merkleRootProcessor: merkleRootProcessor, + tokenPriceProcessor: tokenprice.NewProcessor(), + chainFeeProcessor: chainfee.NewProcessor(), } +} - observer := ObserverImpl{ - lggr: lggr, - homeChain: homeChain, - nodeID: nodeID, - chainSupport: chainSupport, - ccipReader: ccipReader, - msgHasher: msgHasher, +func (p *Plugin) Query(_ context.Context, outCtx ocr3types.OutcomeContext) (types.Query, error) { + return types.Query{}, nil +} + +func (p *Plugin) ObservationQuorum(_ ocr3types.OutcomeContext, _ types.Query) (ocr3types.Quorum, error) { + // Across all chains we require at least 2F+1 observations. + return ocr3types.QuorumTwoFPlusOne, nil +} + +func (p *Plugin) Observation( + ctx context.Context, outCtx ocr3types.OutcomeContext, _ types.Query, +) (types.Observation, error) { + prevOutcome := p.decodeOutcome(outCtx.PreviousOutcome) + fChain := p.ObserveFChain() + //TODO: Move fchain to a new processor instead of computing it inside MerkleProcessor + merkleRootObs, err := p.merkleRootProcessor.Observation(ctx, prevOutcome.MerkleRootOutcome, merkleroot.Query{}) + if err != nil { + p.lggr.Errorw("failed to get merkle observation", "err", err) + } + tokenPriceObs, err := p.tokenPriceProcessor.Observation(ctx, prevOutcome.TokenPriceOutcome, tokenprice.Query{}) + if err != nil { + //log error + p.lggr.Errorw("failed to get token prices", "err", err) + } + chainFeeObs, err := p.chainFeeProcessor.Observation(ctx, prevOutcome.ChainFeeOutcome, chainfee.Query{}) + if err != nil { + p.lggr.Errorw("failed to get gas prices", "err", err) } - return &Plugin{ - nodeID: nodeID, - oracleIDToP2pID: oracleIDToP2pID, - lggr: lggr, - cfg: cfg, - tokenPricesReader: tokenPricesReader, - ccipReader: ccipReader, - homeChain: homeChain, - readerSyncer: readerSyncer, - reportCodec: reportCodec, - reportingCfg: reportingCfg, - chainSupport: chainSupport, - observer: observer, + obs := Observation{ + MerkleRootObs: merkleRootObs, + TokenPriceObs: tokenPriceObs, + ChainFeeObs: chainFeeObs, + FChain: fChain, + } + return obs.Encode() +} + +func (p *Plugin) ObserveFChain() map[cciptypes.ChainSelector]int { + fChain, err := p.homeChain.GetFChain() + if err != nil { + // TODO: metrics + p.lggr.Warnw("call to GetFChain failed", "err", err) + return map[cciptypes.ChainSelector]int{} + } + return fChain +} + +// Outcome depending on the current state, either: +// - chooses the seq num ranges for the next round +// - builds a report +// - checks for the transmission of a previous report +func (p *Plugin) Outcome( + outCtx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, +) (ocr3types.Outcome, error) { + prevOutcome := p.decodeOutcome(outCtx.PreviousOutcome) + + var merkleObservations []MerkleRootObservation + var tokensObservations []TokenPricesObservation + var feeObservations []ChainFeeObservation + + for _, ao := range aos { + obs, err := DecodeCommitPluginObservation(ao.Observation) + if err != nil { + p.lggr.Errorw("failed to decode observation", "err", err) + continue + } + merkleObservations = append(merkleObservations, + MerkleRootObservation{ + OracleID: ao.Observer, + Observation: obs.MerkleRootObs, + }, + ) + + tokensObservations = append(tokensObservations, + TokenPricesObservation{ + OracleID: ao.Observer, + Observation: obs.TokenPriceObs, + }, + ) + + feeObservations = append(feeObservations, + ChainFeeObservation{ + OracleID: ao.Observer, + Observation: obs.ChainFeeObs, + }, + ) } + + merkleRootOutcome, err := p.merkleRootProcessor.Outcome( + prevOutcome.MerkleRootOutcome, + merkleroot.Query{}, + merkleObservations, + ) + if err != nil { + p.lggr.Errorw("failed to get merkle outcome", "err", err) + } + + tokenPriceOutcome, err := p.tokenPriceProcessor.Outcome( + prevOutcome.TokenPriceOutcome, + tokenprice.Query{}, + tokensObservations, + ) + + if err != nil { + p.lggr.Errorw("failed to get token prices outcome", "err", err) + } + + chainFeeOutcome, err := p.chainFeeProcessor.Outcome(prevOutcome.ChainFeeOutcome, chainfee.Query{}, feeObservations) + if err != nil { + p.lggr.Errorw("failed to get gas prices outcome", "err", err) + } + + return Outcome{ + MerkleRootOutcome: merkleRootOutcome, + TokenPriceOutcome: tokenPriceOutcome, + ChainFeeOutcome: chainFeeOutcome, + }.Encode() } func (p *Plugin) Close() error { @@ -105,18 +242,18 @@ func (p *Plugin) Close() error { return nil } -func (p *Plugin) decodeOutcome(outcome ocr3types.Outcome) (Outcome, State) { +func (p *Plugin) decodeOutcome(outcome ocr3types.Outcome) Outcome { if len(outcome) == 0 { - return Outcome{}, SelectingRangesForReport + return Outcome{} } decodedOutcome, err := DecodeOutcome(outcome) if err != nil { p.lggr.Errorw("Failed to decode Outcome", "outcome", outcome, "err", err) - return Outcome{}, SelectingRangesForReport + return Outcome{} } - return decodedOutcome, decodedOutcome.NextState() + return decodedOutcome } func syncFrequency(configuredValue time.Duration) time.Duration { diff --git a/commit/plugin_e2e_test.go b/commit/plugin_e2e_test.go index fc88383cd..247755f33 100644 --- a/commit/plugin_e2e_test.go +++ b/commit/plugin_e2e_test.go @@ -9,15 +9,18 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" + "github.com/smartcontractkit/chainlink-ccip/chainconfig" "github.com/smartcontractkit/chainlink-ccip/internal/libs/testhelpers" "github.com/smartcontractkit/chainlink-ccip/internal/mocks" @@ -76,36 +79,38 @@ func TestPlugin_E2E_AllNodesAgree(t *testing.T) { reportingCfg := ocr3types.ReportingPluginConfig{F: 1} outcomeIntervalsSelected := Outcome{ - OutcomeType: ReportIntervalsSelected, - RangesSelectedForReport: []plugintypes.ChainRange{ - {ChainSel: sourceChain1, SeqNumRange: ccipocr3.SeqNumRange{10, 10}}, - {ChainSel: sourceChain2, SeqNumRange: ccipocr3.SeqNumRange{20, 20}}, - }, - OffRampNextSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: sourceChain1, SeqNum: 10}, - {ChainSel: sourceChain2, SeqNum: 20}, + MerkleRootOutcome: merkleroot.Outcome{ + OutcomeType: merkleroot.ReportIntervalsSelected, + RangesSelectedForReport: []plugintypes.ChainRange{ + {ChainSel: sourceChain1, SeqNumRange: ccipocr3.SeqNumRange{10, 10}}, + {ChainSel: sourceChain2, SeqNumRange: ccipocr3.SeqNumRange{20, 20}}, + }, + OffRampNextSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: sourceChain1, SeqNum: 10}, + {ChainSel: sourceChain2, SeqNum: 20}, + }, }, } outcomeReportGenerated := Outcome{ - OutcomeType: ReportGenerated, - RootsToReport: []ccipocr3.MerkleRootChain{ - { - ChainSel: sourceChain1, - SeqNumsRange: ccipocr3.SeqNumRange{0xa, 0xa}, - MerkleRoot: merkleRoot1, + MerkleRootOutcome: merkleroot.Outcome{ + OutcomeType: merkleroot.ReportGenerated, + RootsToReport: []ccipocr3.MerkleRootChain{ + { + ChainSel: sourceChain1, + SeqNumsRange: ccipocr3.SeqNumRange{0xa, 0xa}, + MerkleRoot: merkleRoot1, + }, + }, + OffRampNextSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: sourceChain1, SeqNum: 10}, + {ChainSel: sourceChain2, SeqNum: 20}, }, }, - OffRampNextSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: sourceChain1, SeqNum: 10}, - {ChainSel: sourceChain2, SeqNum: 20}, - }, - TokenPrices: make([]ccipocr3.TokenPrice, 0), - GasPrices: make([]ccipocr3.GasPriceChain, 0), } outcomeReportGeneratedOneInflightCheck := outcomeReportGenerated - outcomeReportGeneratedOneInflightCheck.ReportTransmissionCheckAttempts = 1 + outcomeReportGeneratedOneInflightCheck.MerkleRootOutcome.ReportTransmissionCheckAttempts = 1 testCases := []struct { name string @@ -134,10 +139,7 @@ func TestPlugin_E2E_AllNodesAgree(t *testing.T) { MerkleRoot: merkleRoot1, }, }, - PriceUpdates: ccipocr3.PriceUpdates{ - TokenPriceUpdates: []ccipocr3.TokenPrice{}, - GasPriceUpdates: []ccipocr3.GasPriceChain{}, - }, + PriceUpdates: ccipocr3.PriceUpdates{}, }, }, }, @@ -145,11 +147,13 @@ func TestPlugin_E2E_AllNodesAgree(t *testing.T) { name: "report generated in previous outcome, still inflight", prevOutcome: outcomeReportGenerated, expOutcome: Outcome{ - OutcomeType: ReportInFlight, - ReportTransmissionCheckAttempts: 1, - OffRampNextSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: sourceChain1, SeqNum: 10}, - {ChainSel: sourceChain2, SeqNum: 20}, + MerkleRootOutcome: merkleroot.Outcome{ + OutcomeType: merkleroot.ReportInFlight, + ReportTransmissionCheckAttempts: 1, + OffRampNextSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: sourceChain1, SeqNum: 10}, + {ChainSel: sourceChain2, SeqNum: 20}, + }, }, }, }, @@ -157,8 +161,9 @@ func TestPlugin_E2E_AllNodesAgree(t *testing.T) { name: "report generated in previous outcome, still inflight, reached all inflight check attempts", prevOutcome: outcomeReportGeneratedOneInflightCheck, expOutcome: Outcome{ - OutcomeType: ReportTransmissionFailed, - ReportTransmissionCheckAttempts: 0, + MerkleRootOutcome: merkleroot.Outcome{ + OutcomeType: merkleroot.ReportTransmissionFailed, + }, }, }, { @@ -167,8 +172,9 @@ func TestPlugin_E2E_AllNodesAgree(t *testing.T) { offRampNextSeqNumDefaultOverrideKeys: []ccipocr3.ChainSelector{sourceChain1, sourceChain2}, offRampNextSeqNumDefaultOverrideValues: []ccipocr3.SeqNum{11, 20}, expOutcome: Outcome{ - OutcomeType: ReportTransmitted, - ReportTransmissionCheckAttempts: 0, + MerkleRootOutcome: merkleroot.Outcome{ + OutcomeType: merkleroot.ReportTransmitted, + }, }, }, } diff --git a/commit/query.go b/commit/query.go deleted file mode 100644 index 2c6cf59fe..000000000 --- a/commit/query.go +++ /dev/null @@ -1,12 +0,0 @@ -package commit - -import ( - "context" - - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" -) - -func (p *Plugin) Query(_ context.Context, outCtx ocr3types.OutcomeContext) (types.Query, error) { - return types.Query{}, nil -} diff --git a/commit/report.go b/commit/report.go index ad8701506..61b7ec0d8 100644 --- a/commit/report.go +++ b/commit/report.go @@ -7,6 +7,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) @@ -18,12 +20,16 @@ func (p *Plugin) Reports(seqNr uint64, outcomeBytes ocr3types.Outcome) ([]ocr3ty return nil, fmt.Errorf("failed to decode Outcome (%s): %w", hex.EncodeToString(outcomeBytes), err) } - // Reports are only generated from "ReportGenerated" outcomes - if outcome.OutcomeType != ReportGenerated { + // Until we start adding tokens and gas to the report, we don't need to report anything + if outcome.MerkleRootOutcome.OutcomeType != merkleroot.ReportGenerated { return []ocr3types.ReportWithInfo[[]byte]{}, nil } - rep := cciptypes.NewCommitPluginReport(outcome.RootsToReport, outcome.TokenPrices, outcome.GasPrices) + rep := cciptypes.NewCommitPluginReport( + outcome.MerkleRootOutcome.RootsToReport, + outcome.TokenPriceOutcome.TokenPrices, + outcome.ChainFeeOutcome.GasPrices, + ) encodedReport, err := p.reportCodec.Encode(context.Background(), rep) if err != nil { @@ -67,7 +73,7 @@ func (p *Plugin) ShouldTransmitAcceptedReport( return false, fmt.Errorf("decode commit plugin report: %w", err) } - isValid, err := validateMerkleRootsState(ctx, p.lggr, decodedReport, p.ccipReader) + isValid, err := merkleroot.ValidateMerkleRootsState(ctx, p.lggr, decodedReport, p.ccipReader) if !isValid { return false, nil } diff --git a/commit/tokenprice/processor.go b/commit/tokenprice/processor.go new file mode 100644 index 000000000..c3c1bf1c3 --- /dev/null +++ b/commit/tokenprice/processor.go @@ -0,0 +1,65 @@ +package tokenprice + +import ( + "context" + "fmt" + + mapset "github.com/deckarep/golang-set/v2" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-ccip/shared" +) + +type Processor struct { +} + +func NewProcessor() *Processor { + return &Processor{} +} + +func (w *Processor) Query(ctx context.Context, prevOutcome Outcome) (Query, error) { + return Query{}, nil +} + +func (w *Processor) Observation( + ctx context.Context, + prevOutcome Outcome, + query Query, +) (Observation, error) { + return Observation{}, nil +} + +func (w *Processor) Outcome( + prevOutcome Outcome, + query Query, + aos []shared.AttributedObservation[Observation], +) (Outcome, error) { + return Outcome{}, nil +} + +func (w *Processor) ValidateObservation( + prevOutcome Outcome, + query Query, + ao shared.AttributedObservation[Observation], +) error { + //TODO: Validate token prices + return nil +} + +func validateObservedTokenPrices(tokenPrices []cciptypes.TokenPrice) error { + tokensWithPrice := mapset.NewSet[types.Account]() + for _, t := range tokenPrices { + if tokensWithPrice.Contains(t.TokenID) { + return fmt.Errorf("duplicate token price for token: %s", t.TokenID) + } + tokensWithPrice.Add(t.TokenID) + + if t.Price.IsEmpty() { + return fmt.Errorf("token price must not be empty") + } + } + return nil +} + +var _ shared.PluginProcessor[Query, Observation, Outcome] = &Processor{} diff --git a/commit/tokenprice/types.go b/commit/tokenprice/types.go new file mode 100644 index 000000000..658a3c666 --- /dev/null +++ b/commit/tokenprice/types.go @@ -0,0 +1,23 @@ +package tokenprice + +import ( + "time" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +type Query struct { +} + +type Outcome struct { + TokenPrices []cciptypes.TokenPrice `json:"tokenPrices"` +} + +type Observation struct { + // FeedTokenPrices for tokens from the feeds on the feed chain + FeedTokenPrices []cciptypes.TokenPrice `json:"feedTokenPrices"` + // PriceRegistryTokenUpdates for tokens from the PriceRegistry on the dest chain + PriceRegistryTokenUpdates []cciptypes.TokenPrice `json:"priceRegistryTokenPrices"` + // Observation time + Timestamp time.Time +} diff --git a/commit/tokenprice/validate_test.go b/commit/tokenprice/validate_test.go new file mode 100644 index 000000000..71b25d856 --- /dev/null +++ b/commit/tokenprice/validate_test.go @@ -0,0 +1,66 @@ +package tokenprice + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +func Test_validateObservedTokenPrices(t *testing.T) { + testCases := []struct { + name string + tokenPrices []cciptypes.TokenPrice + expErr bool + }{ + { + name: "empty is valid", + tokenPrices: []cciptypes.TokenPrice{}, + expErr: false, + }, + { + name: "all valid", + tokenPrices: []cciptypes.TokenPrice{ + cciptypes.NewTokenPrice("0x1", big.NewInt(1)), + cciptypes.NewTokenPrice("0x2", big.NewInt(1)), + cciptypes.NewTokenPrice("0x3", big.NewInt(1)), + cciptypes.NewTokenPrice("0xa", big.NewInt(1)), + }, + expErr: false, + }, + { + name: "dup price", + tokenPrices: []cciptypes.TokenPrice{ + cciptypes.NewTokenPrice("0x1", big.NewInt(1)), + cciptypes.NewTokenPrice("0x2", big.NewInt(1)), + cciptypes.NewTokenPrice("0x1", big.NewInt(1)), // dup + cciptypes.NewTokenPrice("0xa", big.NewInt(1)), + }, + expErr: true, + }, + { + name: "nil price", + tokenPrices: []cciptypes.TokenPrice{ + cciptypes.NewTokenPrice("0x1", big.NewInt(1)), + cciptypes.NewTokenPrice("0x2", big.NewInt(1)), + cciptypes.NewTokenPrice("0x3", nil), // nil price + cciptypes.NewTokenPrice("0xa", big.NewInt(1)), + }, + expErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedTokenPrices(tc.tokenPrices) + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + + } +} diff --git a/commit/types.go b/commit/types.go index 22a65e780..324d28b2b 100644 --- a/commit/types.go +++ b/commit/types.go @@ -3,10 +3,10 @@ package commit import ( "encoding/json" "fmt" - "sort" - - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-ccip/commit/chainfee" + "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" + "github.com/smartcontractkit/chainlink-ccip/commit/tokenprice" "github.com/smartcontractkit/chainlink-ccip/plugintypes" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" @@ -27,20 +27,11 @@ func DecodeCommitPluginQuery(encodedQuery []byte) (Query, error) { return q, err } -func NewCommitQuery(rmnOnRampMaxSeqNums []plugintypes.SeqNumChain, merkleRoots []cciptypes.MerkleRootChain) Query { - return Query{ - RmnOnRampMaxSeqNums: rmnOnRampMaxSeqNums, - MerkleRoots: merkleRoots, - } -} - type Observation struct { - MerkleRoots []cciptypes.MerkleRootChain `json:"merkleRoots"` - GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` - TokenPrices []cciptypes.TokenPrice `json:"tokenPrices"` - OnRampMaxSeqNums []plugintypes.SeqNumChain `json:"onRampMaxSeqNums"` - OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` - FChain map[cciptypes.ChainSelector]int `json:"fChain"` + MerkleRootObs merkleroot.Observation `json:"merkleObs"` + TokenPriceObs tokenprice.Observation `json:"tokenObs"` + ChainFeeObs chainfee.Observation `json:"gasObs"` + FChain map[cciptypes.ChainSelector]int `json:"fChain"` } func (obs Observation) Encode() ([]byte, error) { @@ -58,173 +49,16 @@ func DecodeCommitPluginObservation(encodedObservation []byte) (Observation, erro return o, err } -// AggregatedObservation is the aggregation of a list of observations -type AggregatedObservation struct { - // A map from chain selectors to the list of merkle roots observed for each chain - MerkleRoots map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain - - // A map from chain selectors to the list of gas prices observed for each chain - GasPrices map[cciptypes.ChainSelector][]cciptypes.BigInt - - // A map from token IDs to the list of prices observed for each token - TokenPrices map[types.Account][]cciptypes.BigInt - - // A map from chain selectors to the list of OnRamp max sequence numbers observed for each chain - OnRampMaxSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum - - // A map from chain selectors to the list of OffRamp next sequence numbers observed for each chain - OffRampNextSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum - - // A map from chain selectors to the list of f (failure tolerance) observed for each chain - FChain map[cciptypes.ChainSelector][]int -} - -// aggregateObservations takes a list of observations and produces an AggregatedObservation -func aggregateObservations(aos []types.AttributedObservation) AggregatedObservation { - aggObs := AggregatedObservation{ - MerkleRoots: make(map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain), - GasPrices: make(map[cciptypes.ChainSelector][]cciptypes.BigInt), - TokenPrices: make(map[types.Account][]cciptypes.BigInt), - OnRampMaxSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), - OffRampNextSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), - FChain: make(map[cciptypes.ChainSelector][]int), - } - - for _, ao := range aos { - obs, err := DecodeCommitPluginObservation(ao.Observation) - if err != nil { - // TODO: lggr - continue - } - - // MerkleRoots - for _, merkleRoot := range obs.MerkleRoots { - aggObs.MerkleRoots[merkleRoot.ChainSel] = - append(aggObs.MerkleRoots[merkleRoot.ChainSel], merkleRoot) - } - - // GasPrices - for _, gasPriceChain := range obs.GasPrices { - aggObs.GasPrices[gasPriceChain.ChainSel] = - append(aggObs.GasPrices[gasPriceChain.ChainSel], gasPriceChain.GasPrice) - } - - // TokenPrices - for _, tokenPrice := range obs.TokenPrices { - aggObs.TokenPrices[tokenPrice.TokenID] = - append(aggObs.TokenPrices[tokenPrice.TokenID], tokenPrice.Price) - } - - // OnRampMaxSeqNums - for _, seqNumChain := range obs.OnRampMaxSeqNums { - aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel] = - append(aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) - } - - // OffRampNextSeqNums - for _, seqNumChain := range obs.OffRampNextSeqNums { - aggObs.OffRampNextSeqNums[seqNumChain.ChainSel] = - append(aggObs.OffRampNextSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) - } - - // FChain - for chainSel, f := range obs.FChain { - aggObs.FChain[chainSel] = append(aggObs.FChain[chainSel], f) - } - } - - return aggObs -} - -// ConsensusObservation holds the consensus values for all chains across all observations in a round -type ConsensusObservation struct { - // A map from chain selectors to each chain's consensus merkle root - MerkleRoots map[cciptypes.ChainSelector]cciptypes.MerkleRootChain - - // A map from chain selectors to each chain's consensus gas prices - GasPrices map[cciptypes.ChainSelector]cciptypes.BigInt - - // A map from token IDs to each token's consensus price - TokenPrices map[types.Account]cciptypes.BigInt - - // A map from chain selectors to each chain's consensus OnRamp max sequence number - OnRampMaxSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum - - // A map from chain selectors to each chain's consensus OffRamp next sequence number - OffRampNextSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum - - // A map from chain selectors to each chain's consensus f (failure tolerance) - FChain map[cciptypes.ChainSelector]int -} - -// GasPricesArray returns a list of gas prices -func (co ConsensusObservation) GasPricesArray() []cciptypes.GasPriceChain { - gasPrices := make([]cciptypes.GasPriceChain, 0, len(co.GasPrices)) - for chain, gasPrice := range co.GasPrices { - gasPrices = append(gasPrices, cciptypes.NewGasPriceChain(gasPrice.Int, chain)) - } - sort.Slice(gasPrices, func(i, j int) bool { return gasPrices[i].ChainSel < gasPrices[j].ChainSel }) - - return gasPrices -} - -// TokenPricesArray returns a list of token prices -func (co ConsensusObservation) TokenPricesArray() []cciptypes.TokenPrice { - tokenPrices := make([]cciptypes.TokenPrice, 0, len(co.TokenPrices)) - for tokenID, tokenPrice := range co.TokenPrices { - tokenPrices = append(tokenPrices, cciptypes.NewTokenPrice(tokenID, tokenPrice.Int)) - } - sort.Slice(tokenPrices, func(i, j int) bool { return tokenPrices[i].TokenID < tokenPrices[j].TokenID }) - - return tokenPrices -} - -type OutcomeType int - -const ( - ReportIntervalsSelected OutcomeType = iota + 1 - ReportGenerated - ReportEmpty - ReportInFlight - ReportTransmitted - ReportTransmissionFailed -) - type Outcome struct { - OutcomeType OutcomeType `json:"outcomeType"` - RangesSelectedForReport []plugintypes.ChainRange `json:"rangesSelectedForReport"` - RootsToReport []cciptypes.MerkleRootChain `json:"rootsToReport"` - OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` - TokenPrices []cciptypes.TokenPrice `json:"tokenPrices"` - GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` - ReportTransmissionCheckAttempts uint `json:"reportTransmissionCheckAttempts"` -} - -// Sort all fields of the given Outcome -func (o Outcome) sort() { - sort.Slice(o.RangesSelectedForReport, func(i, j int) bool { - return o.RangesSelectedForReport[i].ChainSel < o.RangesSelectedForReport[j].ChainSel - }) - sort.Slice(o.RootsToReport, func(i, j int) bool { - return o.RootsToReport[i].ChainSel < o.RootsToReport[j].ChainSel - }) - sort.Slice(o.OffRampNextSeqNums, func(i, j int) bool { - return o.OffRampNextSeqNums[i].ChainSel < o.OffRampNextSeqNums[j].ChainSel - }) - sort.Slice(o.TokenPrices, func(i, j int) bool { - return o.TokenPrices[i].TokenID < o.TokenPrices[j].TokenID - }) - sort.Slice(o.GasPrices, func(i, j int) bool { - return o.GasPrices[i].ChainSel < o.GasPrices[j].ChainSel - }) + MerkleRootOutcome merkleroot.Outcome `json:"merkleOutcome"` + TokenPriceOutcome tokenprice.Outcome `json:"tokensOutcome"` + ChainFeeOutcome chainfee.Outcome `json:"gasOutcome"` } // Encode encodes an Outcome deterministically func (o Outcome) Encode() ([]byte, error) { - // Sort all lists to ensure deterministic serialization - o.sort() - + o.MerkleRootOutcome.Sort() encodedOutcome, err := json.Marshal(o) if err != nil { return nil, fmt.Errorf("failed to encode Outcome: %w", err) @@ -238,30 +72,3 @@ func DecodeOutcome(b []byte) (Outcome, error) { err := json.Unmarshal(b, &o) return o, err } - -func (o Outcome) NextState() State { - switch o.OutcomeType { - case ReportIntervalsSelected: - return BuildingReport - case ReportGenerated: - return WaitingForReportTransmission - case ReportEmpty: - return SelectingRangesForReport - case ReportInFlight: - return WaitingForReportTransmission - case ReportTransmitted: - return SelectingRangesForReport - case ReportTransmissionFailed: - return SelectingRangesForReport - default: - return SelectingRangesForReport - } -} - -type State int - -const ( - SelectingRangesForReport State = iota + 1 - BuildingReport - WaitingForReportTransmission -) diff --git a/commit/validate_observation.go b/commit/validate_observation.go index 50389a450..4a3103e1a 100644 --- a/commit/validate_observation.go +++ b/commit/validate_observation.go @@ -1,137 +1,60 @@ package commit import ( - "context" "fmt" - mapset "github.com/deckarep/golang-set/v2" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/chainlink-ccip/commit/tokenprice" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - "github.com/smartcontractkit/chainlink-ccip/internal/reader" - "github.com/smartcontractkit/chainlink-ccip/plugintypes" + "github.com/smartcontractkit/chainlink-ccip/commit/chainfee" + "github.com/smartcontractkit/chainlink-ccip/commit/merkleroot" ) // ValidateObservation validates an observation to ensure it is well-formed -func (p *Plugin) ValidateObservation(_ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error { +func (p *Plugin) ValidateObservation( + outCtx ocr3types.OutcomeContext, + _ types.Query, + ao types.AttributedObservation, +) error { obs, err := DecodeCommitPluginObservation(ao.Observation) if err != nil { return fmt.Errorf("failed to decode commit plugin observation: %w", err) } + prevOutcome := p.decodeOutcome(outCtx.PreviousOutcome) if err := validateFChain(obs.FChain); err != nil { return fmt.Errorf("failed to validate FChain: %w", err) } - observerSupportedChains, err := p.chainSupport.SupportedChains(ao.Observer) - if err != nil { - return fmt.Errorf("failed to get supported chains: %w", err) + merkleObs := MerkleRootObservation{ + OracleID: ao.Observer, + Observation: obs.MerkleRootObs, } - supportsDestChain, err := p.chainSupport.SupportsDestChain(ao.Observer) + err = p.merkleRootProcessor.ValidateObservation(prevOutcome.MerkleRootOutcome, merkleroot.Query{}, merkleObs) if err != nil { - return fmt.Errorf("call to supportsDestChain failed: %w", err) - } - - if err := validateObservedMerkleRoots(obs.MerkleRoots, ao.Observer, observerSupportedChains); err != nil { - return fmt.Errorf("failed to validate MerkleRoots: %w", err) - } - - if err := validateObservedOnRampMaxSeqNums(obs.OnRampMaxSeqNums, ao.Observer, observerSupportedChains); err != nil { - return fmt.Errorf("failed to validate OnRampMaxSeqNums: %w", err) - } - - if err := validateObservedOffRampMaxSeqNums(obs.OffRampNextSeqNums, ao.Observer, supportsDestChain); err != nil { - return fmt.Errorf("failed to validate OffRampNextSeqNums: %w", err) - } - - if err := validateObservedTokenPrices(obs.TokenPrices); err != nil { - return fmt.Errorf("failed to validate token prices: %w", err) - } - - if err := validateObservedGasPrices(obs.GasPrices); err != nil { - return fmt.Errorf("failed to validate gas prices: %w", err) - } - - return nil -} - -func validateObservedMerkleRoots( - merkleRoots []cciptypes.MerkleRootChain, - observer commontypes.OracleID, - observerSupportedChains mapset.Set[cciptypes.ChainSelector], -) error { - if len(merkleRoots) == 0 { - return nil - } - - seenChains := mapset.NewSet[cciptypes.ChainSelector]() - for _, root := range merkleRoots { - if !observerSupportedChains.Contains(root.ChainSel) { - return fmt.Errorf("found merkle root for chain %d, but this chain is not supported by Observer %d", - root.ChainSel, observer) - } - - if seenChains.Contains(root.ChainSel) { - return fmt.Errorf("duplicate merkle root for chain %d", root.ChainSel) - } - seenChains.Add(root.ChainSel) - } - - return nil -} - -func validateObservedOnRampMaxSeqNums( - onRampMaxSeqNums []plugintypes.SeqNumChain, - observer commontypes.OracleID, - observerSupportedChains mapset.Set[cciptypes.ChainSelector], -) error { - if len(onRampMaxSeqNums) == 0 { - return nil + return fmt.Errorf("validate merkle roots observation: %w", err) } - seenChains := mapset.NewSet[cciptypes.ChainSelector]() - for _, seqNumChain := range onRampMaxSeqNums { - if !observerSupportedChains.Contains(seqNumChain.ChainSel) { - return fmt.Errorf("found onRampMaxSeqNum for chain %d, but this chain is not supported by Observer %d, "+ - "observerSupportedChains: %v, onRampMaxSeqNums: %v", - seqNumChain.ChainSel, observer, observerSupportedChains, onRampMaxSeqNums) - } - - if seenChains.Contains(seqNumChain.ChainSel) { - return fmt.Errorf("duplicate onRampMaxSeqNum for chain %d", seqNumChain.ChainSel) - } - seenChains.Add(seqNumChain.ChainSel) + tokenObs := TokenPricesObservation{ + OracleID: ao.Observer, + Observation: obs.TokenPriceObs, } - - return nil -} - -func validateObservedOffRampMaxSeqNums( - offRampMaxSeqNums []plugintypes.SeqNumChain, - observer commontypes.OracleID, - supportsDestChain bool, -) error { - if len(offRampMaxSeqNums) == 0 { - return nil + err = p.tokenPriceProcessor.ValidateObservation(prevOutcome.TokenPriceOutcome, tokenprice.Query{}, tokenObs) + if err != nil { + return fmt.Errorf("validate token prices observation: %w", err) } - if !supportsDestChain { - return fmt.Errorf("observer %d does not support dest chain, but has observed %d offRampMaxSeqNums", - observer, len(offRampMaxSeqNums)) + gasObs := ChainFeeObservation{ + OracleID: ao.Observer, + Observation: obs.ChainFeeObs, } - - seenChains := mapset.NewSet[cciptypes.ChainSelector]() - for _, seqNumChain := range offRampMaxSeqNums { - if seenChains.Contains(seqNumChain.ChainSel) { - return fmt.Errorf("duplicate offRampMaxSeqNum for chain %d", seqNumChain.ChainSel) - } - seenChains.Add(seqNumChain.ChainSel) + if err := p.chainFeeProcessor.ValidateObservation(prevOutcome.ChainFeeOutcome, chainfee.Query{}, gasObs); err != nil { + return fmt.Errorf("validate chain fee observation: %w", err) } return nil @@ -146,78 +69,3 @@ func validateFChain(fChain map[cciptypes.ChainSelector]int) error { return nil } - -func validateObservedTokenPrices(tokenPrices []cciptypes.TokenPrice) error { - tokensWithPrice := mapset.NewSet[types.Account]() - for _, t := range tokenPrices { - if tokensWithPrice.Contains(t.TokenID) { - return fmt.Errorf("duplicate token price for token: %s", t.TokenID) - } - tokensWithPrice.Add(t.TokenID) - - if t.Price.IsEmpty() { - return fmt.Errorf("token price must not be empty") - } - } - - return nil -} - -func validateObservedGasPrices(gasPrices []cciptypes.GasPriceChain) error { - // Duplicate gas prices must not appear for the same chain and must not be empty. - gasPriceChains := mapset.NewSet[cciptypes.ChainSelector]() - for _, g := range gasPrices { - if gasPriceChains.Contains(g.ChainSel) { - return fmt.Errorf("duplicate gas price for chain %d", g.ChainSel) - } - gasPriceChains.Add(g.ChainSel) - if g.GasPrice.IsEmpty() { - return fmt.Errorf("gas price must not be empty") - } - } - - return nil -} - -// validateMerkleRootsState merkle roots seq nums validation by comparing with on-chain state. -func validateMerkleRootsState( - ctx context.Context, - lggr logger.Logger, - report cciptypes.CommitPluginReport, - reader reader.CCIP, -) (bool, error) { - reportChains := make([]cciptypes.ChainSelector, 0) - reportMinSeqNums := make(map[cciptypes.ChainSelector]cciptypes.SeqNum) - for _, mr := range report.MerkleRoots { - reportChains = append(reportChains, mr.ChainSel) - reportMinSeqNums[mr.ChainSel] = mr.SeqNumsRange.Start() - } - - if len(reportChains) == 0 { - return true, nil - } - - onchainNextSeqNums, err := reader.NextSeqNum(ctx, reportChains) - if err != nil { - return false, fmt.Errorf("get next sequence numbers: %w", err) - } - if len(onchainNextSeqNums) != len(reportChains) { - return false, fmt.Errorf("critical error: onchainSeqNums length mismatch") - } - - for i, nextSeqNum := range onchainNextSeqNums { - chain := reportChains[i] - reportMinSeqNum, ok := reportMinSeqNums[chain] - if !ok { - return false, fmt.Errorf("critical error: reportSeqNum not found for chain %d", chain) - } - - if reportMinSeqNum != nextSeqNum { - lggr.Warnw("report is not valid due to seq num mismatch", - "chain", chain, "reportMinSeqNum", reportMinSeqNum, "onchainNextSeqNum", nextSeqNum) - return false, nil - } - } - - return true, nil -} diff --git a/commit/validate_observation_test.go b/commit/validate_observation_test.go index b89995b4a..fa64b5e61 100644 --- a/commit/validate_observation_test.go +++ b/commit/validate_observation_test.go @@ -1,183 +1,13 @@ package commit import ( - "context" - "math/big" "testing" - mapset "github.com/deckarep/golang-set/v2" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/libocr/commontypes" "github.com/stretchr/testify/assert" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - - reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" - "github.com/smartcontractkit/chainlink-ccip/plugintypes" ) -func Test_validateObservedMerkleRoots(t *testing.T) { - testCases := []struct { - name string - merkleRoots []cciptypes.MerkleRootChain - observer commontypes.OracleID - observerSupportedChains mapset.Set[cciptypes.ChainSelector] - expErr bool - }{ - { - name: "Chain not supported", - merkleRoots: []cciptypes.MerkleRootChain{ - {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, - {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), - expErr: true, - }, - { - name: "Duplicate chains", - merkleRoots: []cciptypes.MerkleRootChain{ - {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, - {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, - {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{3, 7}, MerkleRoot: [32]byte{1, 2, 3}}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), - expErr: true, - }, - { - name: "Valid offRampMaxSeqNums", - merkleRoots: []cciptypes.MerkleRootChain{ - {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, - {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), - expErr: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := validateObservedMerkleRoots(tc.merkleRoots, tc.observer, tc.observerSupportedChains) - - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - } -} - -func Test_validateObservedOnRampMaxSeqNums(t *testing.T) { - testCases := []struct { - name string - onRampMaxSeqNums []plugintypes.SeqNumChain - observer commontypes.OracleID - observerSupportedChains mapset.Set[cciptypes.ChainSelector] - expErr bool - }{ - { - name: "Chain not supported", - onRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), - expErr: true, - }, - { - name: "Duplicate chains", - onRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - {ChainSel: 2, SeqNum: 33}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), - expErr: true, - }, - { - name: "Valid offRampMaxSeqNums", - onRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - }, - observer: 10, - observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), - expErr: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := validateObservedOnRampMaxSeqNums(tc.onRampMaxSeqNums, tc.observer, tc.observerSupportedChains) - - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - } -} - -func Test_validateObservedOffRampMaxSeqNums(t *testing.T) { - testCases := []struct { - name string - offRampMaxSeqNums []plugintypes.SeqNumChain - observer commontypes.OracleID - supportsDestChain bool - expErr bool - }{ - { - name: "Dest chain not supported", - offRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - }, - observer: 10, - supportsDestChain: false, - expErr: true, - }, - { - name: "Duplicate chains", - offRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - {ChainSel: 2, SeqNum: 33}, - }, - observer: 10, - supportsDestChain: false, - expErr: true, - }, - { - name: "Valid offRampMaxSeqNums", - offRampMaxSeqNums: []plugintypes.SeqNumChain{ - {ChainSel: 1, SeqNum: 10}, - {ChainSel: 2, SeqNum: 20}, - }, - observer: 10, - supportsDestChain: true, - expErr: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := validateObservedOffRampMaxSeqNums(tc.offRampMaxSeqNums, tc.observer, tc.supportsDestChain) - - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - } -} - func Test_validateFChain(t *testing.T) { testCases := []struct { name string @@ -214,188 +44,3 @@ func Test_validateFChain(t *testing.T) { }) } } - -func Test_validateObservedTokenPrices(t *testing.T) { - testCases := []struct { - name string - tokenPrices []cciptypes.TokenPrice - expErr bool - }{ - { - name: "empty is valid", - tokenPrices: []cciptypes.TokenPrice{}, - expErr: false, - }, - { - name: "all valid", - tokenPrices: []cciptypes.TokenPrice{ - cciptypes.NewTokenPrice("0x1", big.NewInt(1)), - cciptypes.NewTokenPrice("0x2", big.NewInt(1)), - cciptypes.NewTokenPrice("0x3", big.NewInt(1)), - cciptypes.NewTokenPrice("0xa", big.NewInt(1)), - }, - expErr: false, - }, - { - name: "dup price", - tokenPrices: []cciptypes.TokenPrice{ - cciptypes.NewTokenPrice("0x1", big.NewInt(1)), - cciptypes.NewTokenPrice("0x2", big.NewInt(1)), - cciptypes.NewTokenPrice("0x1", big.NewInt(1)), // dup - cciptypes.NewTokenPrice("0xa", big.NewInt(1)), - }, - expErr: true, - }, - { - name: "nil price", - tokenPrices: []cciptypes.TokenPrice{ - cciptypes.NewTokenPrice("0x1", big.NewInt(1)), - cciptypes.NewTokenPrice("0x2", big.NewInt(1)), - cciptypes.NewTokenPrice("0x3", nil), // nil price - cciptypes.NewTokenPrice("0xa", big.NewInt(1)), - }, - expErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := validateObservedTokenPrices(tc.tokenPrices) - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - - } -} - -func Test_validateObservedGasPrices(t *testing.T) { - testCases := []struct { - name string - gasPrices []cciptypes.GasPriceChain - expErr bool - }{ - { - name: "empty is valid", - gasPrices: []cciptypes.GasPriceChain{}, - expErr: false, - }, - { - name: "all valid", - gasPrices: []cciptypes.GasPriceChain{ - cciptypes.NewGasPriceChain(big.NewInt(10), 1), - cciptypes.NewGasPriceChain(big.NewInt(20), 2), - cciptypes.NewGasPriceChain(big.NewInt(1312), 3), - }, - expErr: false, - }, - { - name: "duplicate gas price", - gasPrices: []cciptypes.GasPriceChain{ - cciptypes.NewGasPriceChain(big.NewInt(10), 1), - cciptypes.NewGasPriceChain(big.NewInt(20), 2), - cciptypes.NewGasPriceChain(big.NewInt(1312), 1), // notice we already have a gas price for chain 1 - }, - expErr: true, - }, - { - name: "empty gas price", - gasPrices: []cciptypes.GasPriceChain{ - cciptypes.NewGasPriceChain(big.NewInt(10), 1), - cciptypes.NewGasPriceChain(big.NewInt(20), 2), - cciptypes.NewGasPriceChain(nil, 3), // nil - }, - expErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := validateObservedGasPrices(tc.gasPrices) - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - }) - } -} - -func Test_validateMerkleRootsState(t *testing.T) { - testCases := []struct { - name string - reportSeqNums []plugintypes.SeqNumChain - onchainNextSeqNums []cciptypes.SeqNum - expValid bool - expErr bool - }{ - { - name: "happy path", - reportSeqNums: []plugintypes.SeqNumChain{ - plugintypes.NewSeqNumChain(10, 100), - plugintypes.NewSeqNumChain(20, 200), - }, - onchainNextSeqNums: []cciptypes.SeqNum{100, 200}, - expValid: true, - expErr: false, - }, - { - name: "one root is stale", - reportSeqNums: []plugintypes.SeqNumChain{ - plugintypes.NewSeqNumChain(10, 100), - plugintypes.NewSeqNumChain(20, 200), - }, - onchainNextSeqNums: []cciptypes.SeqNum{100, 201}, // <- 200 is already on chain - expValid: false, - expErr: false, - }, - { - name: "one root has gap", - reportSeqNums: []plugintypes.SeqNumChain{ - plugintypes.NewSeqNumChain(10, 101), // <- onchain 99 but we submit 101 instead of 100 - plugintypes.NewSeqNumChain(20, 200), - }, - onchainNextSeqNums: []cciptypes.SeqNum{100, 200}, - expValid: false, - expErr: false, - }, - { - name: "reader returned wrong number of seq nums", - reportSeqNums: []plugintypes.SeqNumChain{ - plugintypes.NewSeqNumChain(10, 100), - plugintypes.NewSeqNumChain(20, 200), - }, - onchainNextSeqNums: []cciptypes.SeqNum{100, 200, 300}, - expValid: false, - expErr: true, - }, - } - - ctx := context.Background() - lggr := logger.Test(t) - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - reader := reader_mock.NewMockCCIP(t) - rep := cciptypes.CommitPluginReport{} - chains := make([]cciptypes.ChainSelector, 0, len(tc.reportSeqNums)) - for _, snc := range tc.reportSeqNums { - rep.MerkleRoots = append(rep.MerkleRoots, cciptypes.MerkleRootChain{ - ChainSel: snc.ChainSel, - SeqNumsRange: cciptypes.NewSeqNumRange(snc.SeqNum, snc.SeqNum+10), - }) - chains = append(chains, snc.ChainSel) - } - reader.On("NextSeqNum", ctx, chains).Return(tc.onchainNextSeqNums, nil) - valid, err := validateMerkleRootsState(ctx, lggr, rep, reader) - if tc.expErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - assert.Equal(t, tc.expValid, valid) - }) - } -} diff --git a/execute/factory.go b/execute/factory.go index 5f3d487e7..c1d7736b5 100644 --- a/execute/factory.go +++ b/execute/factory.go @@ -7,12 +7,13 @@ import ( "google.golang.org/grpc" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/types" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/types/core" + "github.com/smartcontractkit/libocr/commontypes" ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" diff --git a/execute/report/report_test.go b/execute/report/report_test.go index fe4325695..772f6aba6 100644 --- a/execute/report/report_test.go +++ b/execute/report/report_test.go @@ -114,7 +114,8 @@ func TestMustMakeBytes(t *testing.T) { } } -// assertMerkleRoot computes the source messages merkle root, then computes a verification with the proof, then compares +// assertMerkleRoot computes the source messages merkle root, +// then computes a verification with the proof, then compares // the roots. func assertMerkleRoot( t *testing.T, diff --git a/mocks/commit/observer.go b/mocks/commit/merkleroot/observer.go similarity index 65% rename from mocks/commit/observer.go rename to mocks/commit/merkleroot/observer.go index 054040a87..86fa4fe1e 100644 --- a/mocks/commit/observer.go +++ b/mocks/commit/merkleroot/observer.go @@ -1,12 +1,12 @@ // Code generated by mockery v2.43.2. DO NOT EDIT. -package commit +package merkleroot import ( - ccipocr3 "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - context "context" + ccipocr3 "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + mock "github.com/stretchr/testify/mock" plugintypes "github.com/smartcontractkit/chainlink-ccip/plugintypes" @@ -72,54 +72,6 @@ func (_c *MockObserver_ObserveFChain_Call) RunAndReturn(run func() map[ccipocr3. return _c } -// ObserveGasPrices provides a mock function with given fields: ctx -func (_m *MockObserver) ObserveGasPrices(ctx context.Context) []ccipocr3.GasPriceChain { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for ObserveGasPrices") - } - - var r0 []ccipocr3.GasPriceChain - if rf, ok := ret.Get(0).(func(context.Context) []ccipocr3.GasPriceChain); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]ccipocr3.GasPriceChain) - } - } - - return r0 -} - -// MockObserver_ObserveGasPrices_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ObserveGasPrices' -type MockObserver_ObserveGasPrices_Call struct { - *mock.Call -} - -// ObserveGasPrices is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockObserver_Expecter) ObserveGasPrices(ctx interface{}) *MockObserver_ObserveGasPrices_Call { - return &MockObserver_ObserveGasPrices_Call{Call: _e.mock.On("ObserveGasPrices", ctx)} -} - -func (_c *MockObserver_ObserveGasPrices_Call) Run(run func(ctx context.Context)) *MockObserver_ObserveGasPrices_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *MockObserver_ObserveGasPrices_Call) Return(_a0 []ccipocr3.GasPriceChain) *MockObserver_ObserveGasPrices_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockObserver_ObserveGasPrices_Call) RunAndReturn(run func(context.Context) []ccipocr3.GasPriceChain) *MockObserver_ObserveGasPrices_Call { - _c.Call.Return(run) - return _c -} - // ObserveMerkleRoots provides a mock function with given fields: ctx, ranges func (_m *MockObserver) ObserveMerkleRoots(ctx context.Context, ranges []plugintypes.ChainRange) []ccipocr3.MerkleRootChain { ret := _m.Called(ctx, ranges) @@ -217,54 +169,6 @@ func (_c *MockObserver_ObserveOffRampNextSeqNums_Call) RunAndReturn(run func(con return _c } -// ObserveTokenPrices provides a mock function with given fields: ctx -func (_m *MockObserver) ObserveTokenPrices(ctx context.Context) []ccipocr3.TokenPrice { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for ObserveTokenPrices") - } - - var r0 []ccipocr3.TokenPrice - if rf, ok := ret.Get(0).(func(context.Context) []ccipocr3.TokenPrice); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]ccipocr3.TokenPrice) - } - } - - return r0 -} - -// MockObserver_ObserveTokenPrices_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ObserveTokenPrices' -type MockObserver_ObserveTokenPrices_Call struct { - *mock.Call -} - -// ObserveTokenPrices is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockObserver_Expecter) ObserveTokenPrices(ctx interface{}) *MockObserver_ObserveTokenPrices_Call { - return &MockObserver_ObserveTokenPrices_Call{Call: _e.mock.On("ObserveTokenPrices", ctx)} -} - -func (_c *MockObserver_ObserveTokenPrices_Call) Run(run func(ctx context.Context)) *MockObserver_ObserveTokenPrices_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *MockObserver_ObserveTokenPrices_Call) Return(_a0 []ccipocr3.TokenPrice) *MockObserver_ObserveTokenPrices_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockObserver_ObserveTokenPrices_Call) RunAndReturn(run func(context.Context) []ccipocr3.TokenPrice) *MockObserver_ObserveTokenPrices_Call { - _c.Call.Return(run) - return _c -} - // NewMockObserver creates a new instance of MockObserver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockObserver(t interface { diff --git a/mocks/commit/chain_support.go b/mocks/shared/chain_support.go similarity index 99% rename from mocks/commit/chain_support.go rename to mocks/shared/chain_support.go index 959a7d8c2..a205cb03c 100644 --- a/mocks/commit/chain_support.go +++ b/mocks/shared/chain_support.go @@ -1,10 +1,9 @@ // Code generated by mockery v2.43.2. DO NOT EDIT. -package commit +package shared import ( ccipocr3 "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - commontypes "github.com/smartcontractkit/libocr/commontypes" mapset "github.com/deckarep/golang-set/v2" diff --git a/mocks/shared/plugin_processor.go b/mocks/shared/plugin_processor.go new file mode 100644 index 000000000..b58b076b3 --- /dev/null +++ b/mocks/shared/plugin_processor.go @@ -0,0 +1,258 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package shared + +import ( + context "context" + + shared "github.com/smartcontractkit/chainlink-ccip/shared" + mock "github.com/stretchr/testify/mock" +) + +// MockPluginProcessor is an autogenerated mock type for the PluginProcessor type +type MockPluginProcessor[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + mock.Mock +} + +type MockPluginProcessor_Expecter[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + mock *mock.Mock +} + +func (_m *MockPluginProcessor[QueryType, ObservationType, OutcomeType]) EXPECT() *MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType] { + return &MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType]{mock: &_m.Mock} +} + +// Observation provides a mock function with given fields: ctx, prevOutcome, query +func (_m *MockPluginProcessor[QueryType, ObservationType, OutcomeType]) Observation(ctx context.Context, prevOutcome OutcomeType, query QueryType) (ObservationType, error) { + ret := _m.Called(ctx, prevOutcome, query) + + if len(ret) == 0 { + panic("no return value specified for Observation") + } + + var r0 ObservationType + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, OutcomeType, QueryType) (ObservationType, error)); ok { + return rf(ctx, prevOutcome, query) + } + if rf, ok := ret.Get(0).(func(context.Context, OutcomeType, QueryType) ObservationType); ok { + r0 = rf(ctx, prevOutcome, query) + } else { + r0 = ret.Get(0).(ObservationType) + } + + if rf, ok := ret.Get(1).(func(context.Context, OutcomeType, QueryType) error); ok { + r1 = rf(ctx, prevOutcome, query) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPluginProcessor_Observation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Observation' +type MockPluginProcessor_Observation_Call[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + *mock.Call +} + +// Observation is a helper method to define mock.On call +// - ctx context.Context +// - prevOutcome OutcomeType +// - query QueryType +func (_e *MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType]) Observation(ctx interface{}, prevOutcome interface{}, query interface{}) *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType] { + return &MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType]{Call: _e.mock.On("Observation", ctx, prevOutcome, query)} +} + +func (_c *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType]) Run(run func(ctx context.Context, prevOutcome OutcomeType, query QueryType)) *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(OutcomeType), args[2].(QueryType)) + }) + return _c +} + +func (_c *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType]) Return(_a0 ObservationType, _a1 error) *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType]) RunAndReturn(run func(context.Context, OutcomeType, QueryType) (ObservationType, error)) *MockPluginProcessor_Observation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(run) + return _c +} + +// Outcome provides a mock function with given fields: prevOutcome, query, aos +func (_m *MockPluginProcessor[QueryType, ObservationType, OutcomeType]) Outcome(prevOutcome OutcomeType, query QueryType, aos []shared.AttributedObservation[ObservationType]) (OutcomeType, error) { + ret := _m.Called(prevOutcome, query, aos) + + if len(ret) == 0 { + panic("no return value specified for Outcome") + } + + var r0 OutcomeType + var r1 error + if rf, ok := ret.Get(0).(func(OutcomeType, QueryType, []shared.AttributedObservation[ObservationType]) (OutcomeType, error)); ok { + return rf(prevOutcome, query, aos) + } + if rf, ok := ret.Get(0).(func(OutcomeType, QueryType, []shared.AttributedObservation[ObservationType]) OutcomeType); ok { + r0 = rf(prevOutcome, query, aos) + } else { + r0 = ret.Get(0).(OutcomeType) + } + + if rf, ok := ret.Get(1).(func(OutcomeType, QueryType, []shared.AttributedObservation[ObservationType]) error); ok { + r1 = rf(prevOutcome, query, aos) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPluginProcessor_Outcome_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Outcome' +type MockPluginProcessor_Outcome_Call[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + *mock.Call +} + +// Outcome is a helper method to define mock.On call +// - prevOutcome OutcomeType +// - query QueryType +// - aos []shared.AttributedObservation[ObservationType] +func (_e *MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType]) Outcome(prevOutcome interface{}, query interface{}, aos interface{}) *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType] { + return &MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType]{Call: _e.mock.On("Outcome", prevOutcome, query, aos)} +} + +func (_c *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType]) Run(run func(prevOutcome OutcomeType, query QueryType, aos []shared.AttributedObservation[ObservationType])) *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(OutcomeType), args[1].(QueryType), args[2].([]shared.AttributedObservation[ObservationType])) + }) + return _c +} + +func (_c *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType]) Return(_a0 OutcomeType, _a1 error) *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType]) RunAndReturn(run func(OutcomeType, QueryType, []shared.AttributedObservation[ObservationType]) (OutcomeType, error)) *MockPluginProcessor_Outcome_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(run) + return _c +} + +// Query provides a mock function with given fields: ctx, prevOutcome +func (_m *MockPluginProcessor[QueryType, ObservationType, OutcomeType]) Query(ctx context.Context, prevOutcome OutcomeType) (QueryType, error) { + ret := _m.Called(ctx, prevOutcome) + + if len(ret) == 0 { + panic("no return value specified for Query") + } + + var r0 QueryType + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, OutcomeType) (QueryType, error)); ok { + return rf(ctx, prevOutcome) + } + if rf, ok := ret.Get(0).(func(context.Context, OutcomeType) QueryType); ok { + r0 = rf(ctx, prevOutcome) + } else { + r0 = ret.Get(0).(QueryType) + } + + if rf, ok := ret.Get(1).(func(context.Context, OutcomeType) error); ok { + r1 = rf(ctx, prevOutcome) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPluginProcessor_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' +type MockPluginProcessor_Query_Call[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + *mock.Call +} + +// Query is a helper method to define mock.On call +// - ctx context.Context +// - prevOutcome OutcomeType +func (_e *MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType]) Query(ctx interface{}, prevOutcome interface{}) *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType] { + return &MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType]{Call: _e.mock.On("Query", ctx, prevOutcome)} +} + +func (_c *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType]) Run(run func(ctx context.Context, prevOutcome OutcomeType)) *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(OutcomeType)) + }) + return _c +} + +func (_c *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType]) Return(_a0 QueryType, _a1 error) *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType]) RunAndReturn(run func(context.Context, OutcomeType) (QueryType, error)) *MockPluginProcessor_Query_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(run) + return _c +} + +// ValidateObservation provides a mock function with given fields: prevOutcome, query, ao +func (_m *MockPluginProcessor[QueryType, ObservationType, OutcomeType]) ValidateObservation(prevOutcome OutcomeType, query QueryType, ao shared.AttributedObservation[ObservationType]) error { + ret := _m.Called(prevOutcome, query, ao) + + if len(ret) == 0 { + panic("no return value specified for ValidateObservation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(OutcomeType, QueryType, shared.AttributedObservation[ObservationType]) error); ok { + r0 = rf(prevOutcome, query, ao) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockPluginProcessor_ValidateObservation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateObservation' +type MockPluginProcessor_ValidateObservation_Call[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}] struct { + *mock.Call +} + +// ValidateObservation is a helper method to define mock.On call +// - prevOutcome OutcomeType +// - query QueryType +// - ao shared.AttributedObservation[ObservationType] +func (_e *MockPluginProcessor_Expecter[QueryType, ObservationType, OutcomeType]) ValidateObservation(prevOutcome interface{}, query interface{}, ao interface{}) *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType] { + return &MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType]{Call: _e.mock.On("ValidateObservation", prevOutcome, query, ao)} +} + +func (_c *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType]) Run(run func(prevOutcome OutcomeType, query QueryType, ao shared.AttributedObservation[ObservationType])) *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(OutcomeType), args[1].(QueryType), args[2].(shared.AttributedObservation[ObservationType])) + }) + return _c +} + +func (_c *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType]) Return(_a0 error) *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType]) RunAndReturn(run func(OutcomeType, QueryType, shared.AttributedObservation[ObservationType]) error) *MockPluginProcessor_ValidateObservation_Call[QueryType, ObservationType, OutcomeType] { + _c.Call.Return(run) + return _c +} + +// NewMockPluginProcessor creates a new instance of MockPluginProcessor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPluginProcessor[QueryType interface{}, ObservationType interface{}, OutcomeType interface{}](t interface { + mock.TestingT + Cleanup(func()) +}) *MockPluginProcessor[QueryType, ObservationType, OutcomeType] { + mock := &MockPluginProcessor[QueryType, ObservationType, OutcomeType]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pluginconfig/commit_test.go b/pluginconfig/commit_test.go index 3fcd5b5f8..58f3b8352 100644 --- a/pluginconfig/commit_test.go +++ b/pluginconfig/commit_test.go @@ -4,11 +4,12 @@ import ( "math/big" "testing" - commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) func TestCommitPluginConfigValidate(t *testing.T) { diff --git a/pluginconfig/execute_test.go b/pluginconfig/execute_test.go index 4d9c759ae..3991ecc75 100644 --- a/pluginconfig/execute_test.go +++ b/pluginconfig/execute_test.go @@ -3,8 +3,9 @@ package pluginconfig import ( "testing" - commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/stretchr/testify/require" + + commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" ) func TestExecuteOffchainConfig_Validate(t *testing.T) { diff --git a/plugintypes/commit_test.go b/plugintypes/commit_test.go index 14dca5ec9..e95e6d12f 100644 --- a/plugintypes/commit_test.go +++ b/plugintypes/commit_test.go @@ -5,9 +5,10 @@ import ( "math/big" "testing" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) func TestCommitPluginObservation_EncodeAndDecode(t *testing.T) { diff --git a/commit/chain_support.go b/shared/chain_support.go similarity index 83% rename from commit/chain_support.go rename to shared/chain_support.go index f56145f4e..2689e5ed7 100644 --- a/commit/chain_support.go +++ b/shared/chain_support.go @@ -1,4 +1,4 @@ -package commit +package shared import ( "fmt" @@ -31,11 +31,27 @@ type ChainSupport interface { type CCIPChainSupport struct { lggr logger.Logger homeChain reader.HomeChain - oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID + oracleIDToP2PID map[commontypes.OracleID]libocrtypes.PeerID nodeID commontypes.OracleID destChain cciptypes.ChainSelector } +func NewCCIPChainSupport( + lggr logger.Logger, + homeChain reader.HomeChain, + oracleIDToP2PID map[commontypes.OracleID]libocrtypes.PeerID, + nodeID commontypes.OracleID, + destChain cciptypes.ChainSelector, +) CCIPChainSupport { + return CCIPChainSupport{ + lggr: lggr, + homeChain: homeChain, + oracleIDToP2PID: oracleIDToP2PID, + nodeID: nodeID, + destChain: destChain, + } +} + func (c CCIPChainSupport) KnownSourceChainsSlice() ([]cciptypes.ChainSelector, error) { allChainsSet, err := c.homeChain.GetKnownCCIPChains() if err != nil { @@ -53,9 +69,9 @@ func (c CCIPChainSupport) KnownSourceChainsSlice() ([]cciptypes.ChainSelector, e // SupportedChains returns the set of chains that the given Oracle is configured to access func (c CCIPChainSupport) SupportedChains(oracleID commontypes.OracleID) (mapset.Set[cciptypes.ChainSelector], error) { - p2pID, exists := c.oracleIDToP2pID[oracleID] + p2pID, exists := c.oracleIDToP2PID[oracleID] if !exists { - return nil, fmt.Errorf("oracle ID %d not found in oracleIDToP2pID", c.nodeID) + return nil, fmt.Errorf("oracle ID %d not found in oracleIDToP2PID", c.nodeID) } supportedChains, err := c.homeChain.GetSupportedChainsForPeer(p2pID) if err != nil { @@ -73,9 +89,9 @@ func (c CCIPChainSupport) SupportsDestChain(oracle commontypes.OracleID) (bool, return false, fmt.Errorf("get chain config: %w", err) } - peerID, ok := c.oracleIDToP2pID[oracle] + peerID, ok := c.oracleIDToP2PID[oracle] if !ok { - return false, fmt.Errorf("oracle ID %d not found in oracleIDToP2pID", oracle) + return false, fmt.Errorf("oracle ID %d not found in oracleIDToP2PID", oracle) } return destChainConfig.SupportedNodes.Contains(peerID), nil diff --git a/commit/chain_support_test.go b/shared/chain_support_test.go similarity index 96% rename from commit/chain_support_test.go rename to shared/chain_support_test.go index 982ade924..3ff7cf08d 100644 --- a/commit/chain_support_test.go +++ b/shared/chain_support_test.go @@ -1,16 +1,17 @@ -package commit +package shared import ( "fmt" "testing" mapset "github.com/deckarep/golang-set/v2" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + reader2 "github.com/smartcontractkit/chainlink-ccip/internal/reader" "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" ) @@ -52,7 +53,7 @@ func TestCCIPChainSupport_SupportedChains(t *testing.T) { cs := &CCIPChainSupport{ lggr: lggr, homeChain: homeChainReader, - oracleIDToP2pID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, + oracleIDToP2PID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, } t.Run("happy path", func(t *testing.T) { @@ -83,7 +84,7 @@ func TestCCIPChainSupport_SupportsDestChain(t *testing.T) { lggr: lggr, homeChain: homeChainReader, destChain: dstChain, - oracleIDToP2pID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, + oracleIDToP2PID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, } t.Run("happy path", func(t *testing.T) { diff --git a/shared/plugin_processor.go b/shared/plugin_processor.go new file mode 100644 index 000000000..c1e14eb6b --- /dev/null +++ b/shared/plugin_processor.go @@ -0,0 +1,44 @@ +package shared + +import ( + "context" + + "github.com/smartcontractkit/libocr/commontypes" +) + +type AttributedObservation[ObservationType any] struct { + OracleID commontypes.OracleID + Observation ObservationType +} + +// PluginProcessor is to encapsulate logic for multiple processors under a OCR plugin. +// This makes it easier to manage and test when there are multiple logical components of a single plugin. +// Some of them will implement state machines (e.g. merkleroot), +// others might implement simpler logic. (e.g. tokenprices, chainfee) +// The OCR plugin becomes a coordinator/collector of these processors. +// Example Pseudo code: +// +// type OCRPlugin struct { +// merkleRootsProcessor +// tokenPriceProcessor +// chainFeeProcessor +// } +// +// // Observation excludes error handling for brevity. +// func (p *OCRPlugin) Observation() ocrtype.Observation { +// return ocrtype.Observation { +// merkleRoots: p.merkleRootsProcessor.Observation(...) +// tokenPrices: p.tokenPriceProcessor.Observation(...) +// chainFees: p.chainFeeProcessor.Observation(...) +// } +// } +// +// Notice all interface functions are using prevOutcome instead of outCtx. +// We're interested in the prevOutcome, and it makes it easier to have all decoding on the top level (OCR plugin), +// otherwise there might be cyclic dependencies or just complicating the code more. +type PluginProcessor[QueryType any, ObservationType any, OutcomeType any] interface { + Query(ctx context.Context, prevOutcome OutcomeType) (QueryType, error) + Observation(ctx context.Context, prevOutcome OutcomeType, query QueryType) (ObservationType, error) + ValidateObservation(prevOutcome OutcomeType, query QueryType, ao AttributedObservation[ObservationType]) error + Outcome(prevOutcome OutcomeType, query QueryType, aos []AttributedObservation[ObservationType]) (OutcomeType, error) +}