From ed5091cd2a9ae3aeb5915e1ed901d3333ba6090b Mon Sep 17 00:00:00 2001 From: Will Winder Date: Mon, 1 Jul 2024 15:01:01 -0400 Subject: [PATCH] OCR3 execute report generation. --- execute/factory.go | 4 +- execute/plugin.go | 315 ++++++++++++++++++-- execute/plugin_functions.go | 56 ++++ execute/plugin_test.go | 524 ++++++++++++++++++++++++++++++++-- go.mod | 2 +- go.sum | 4 + internal/mocks/reportcodec.go | 16 ++ 7 files changed, 882 insertions(+), 39 deletions(-) diff --git a/execute/factory.go b/execute/factory.go index a75139d12..5a466b7af 100644 --- a/execute/factory.go +++ b/execute/factory.go @@ -47,10 +47,12 @@ func (p PluginFactory) NewReportingPlugin( config ocr3types.ReportingPluginConfig, ) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) { return NewPlugin( - context.Background(), config, cciptypes.ExecutePluginConfig{}, nil, + nil, + nil, + nil, ), ocr3types.ReportingPluginInfo{}, nil } diff --git a/execute/plugin.go b/execute/plugin.go index cb252d0cf..d27fb06c7 100644 --- a/execute/plugin.go +++ b/execute/plugin.go @@ -2,42 +2,59 @@ package execute import ( "context" + "errors" "fmt" "slices" "sort" "sync/atomic" "time" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + "github.com/smartcontractkit/chainlink-ccip/internal/libs/slicelib" ) +// maxReportSizeBytes that should be returned as an execution report payload. +const maxReportSizeBytes = 250_000 + // Plugin implements the main ocr3 plugin logic. type Plugin struct { - reportingCfg ocr3types.ReportingPluginConfig - cfg cciptypes.ExecutePluginConfig - ccipReader cciptypes.CCIPReader + reportingCfg ocr3types.ReportingPluginConfig + cfg cciptypes.ExecutePluginConfig + ccipReader cciptypes.CCIPReader + reportCodec cciptypes.ExecutePluginCodec + msgHasher cciptypes.MessageHasher + tokenDataReader TokenDataReader - //commitRootsCache cache.CommitsRootsCache lastReportTS *atomic.Int64 + + lggr logger.Logger } func NewPlugin( - _ context.Context, reportingCfg ocr3types.ReportingPluginConfig, cfg cciptypes.ExecutePluginConfig, ccipReader cciptypes.CCIPReader, + reportCodec cciptypes.ExecutePluginCodec, + msgHasher cciptypes.MessageHasher, + lggr logger.Logger, ) *Plugin { lastReportTS := &atomic.Int64{} lastReportTS.Store(time.Now().Add(-cfg.MessageVisibilityInterval).UnixMilli()) + // TODO: initialize tokenDataReader. + return &Plugin{ reportingCfg: reportingCfg, cfg: cfg, ccipReader: ccipReader, + reportCodec: reportCodec, + msgHasher: msgHasher, lastReportTS: lastReportTS, + lggr: lggr, } } @@ -199,6 +216,248 @@ func (p *Plugin) ObservationQuorum(outctx ocr3types.OutcomeContext, query types. return ocr3types.QuorumFPlusOne, nil } +// TokenDataReader is an interface for reading extra token data from an async process. +// TODO: Build a token data reading process. +type TokenDataReader interface { + ReadTokenData(ctx context.Context, srcChain cciptypes.ChainSelector, num cciptypes.SeqNum) ([][]byte, error) +} + +// buildSingleChainReportMaxSize generates the largest report which fits into maxSizeBytes. +// See buildSingleChainReport for more details about how a report is built. +func buildSingleChainReportMaxSize( + ctx context.Context, + lggr logger.Logger, + hasher cciptypes.MessageHasher, + tokenDataReader TokenDataReader, + encoder cciptypes.ExecutePluginCodec, + report cciptypes.ExecutePluginCommitDataWithMessages, + maxSizeBytes int, +) (cciptypes.ExecutePluginReportSingleChain, int, cciptypes.ExecutePluginCommitDataWithMessages, error) { + finalReport, encodedSize, err := + buildSingleChainReport(ctx, lggr, hasher, tokenDataReader, encoder, report, 0) + if err != nil { + return cciptypes.ExecutePluginReportSingleChain{}, + 0, + cciptypes.ExecutePluginCommitDataWithMessages{}, + fmt.Errorf("unable to build a single chain report (max): %w", err) + } + + // return fully executed report + if encodedSize <= maxSizeBytes { + report = markNewMessagesExecuted(finalReport, report) + return finalReport, encodedSize, report, nil + } + + var searchErr error + idx := sort.Search(len(report.Messages), func(mid int) bool { + if searchErr != nil { + return false + } + finalReport2, encodedSize2, err := + buildSingleChainReport(ctx, lggr, hasher, tokenDataReader, encoder, report, mid) + if searchErr != nil { + searchErr = fmt.Errorf("unable to build a single chain report (messages %d): %w", mid, err) + } + + if (encodedSize2) <= maxSizeBytes { + // mid is a valid report size, try something bigger next iteration. + finalReport = finalReport2 + encodedSize = encodedSize2 + return false // not full + } + return true // full + }) + if searchErr != nil { + return cciptypes.ExecutePluginReportSingleChain{}, 0, cciptypes.ExecutePluginCommitDataWithMessages{}, searchErr + } + + // No messages fit into the report. + if idx <= 0 { + return cciptypes.ExecutePluginReportSingleChain{}, + 0, + cciptypes.ExecutePluginCommitDataWithMessages{}, + errNothingExecuted + } + + report = markNewMessagesExecuted(finalReport, report) + return finalReport, encodedSize, report, nil +} + +// buildSingleChainReport converts the on-chain event data stored in cciptypes.ExecutePluginCommitDataWithMessages into +// the final on-chain report format. +// +// The hasher and encoding codec are provided as arguments to allow for chain-specific formats to be used. +// +// The maxMessages argument is used to limit the number of messages that are included in the report. If maxMessages is +// set to 0, all messages will be included. This allows the caller to create smaller reports if needed. +func buildSingleChainReport( + ctx context.Context, + lggr logger.Logger, + hasher cciptypes.MessageHasher, + tokenDataReader TokenDataReader, + encoder cciptypes.ExecutePluginCodec, + report cciptypes.ExecutePluginCommitDataWithMessages, + maxMessages int, +) (cciptypes.ExecutePluginReportSingleChain, int, error) { + // TODO: maxMessages selects messages in FIFO order which may not yield the optimal message size. One message with a + // maximum data size could push the report over a size limit even if several smaller messages could have fit. + if maxMessages == 0 { + maxMessages = len(report.Messages) + } + + tree, err := constructMerkleTree(ctx, hasher, report) + if err != nil { + return cciptypes.ExecutePluginReportSingleChain{}, 0, + fmt.Errorf("unable to construct merkle tree from messages: %w", err) + } + lggr.Debugw( + "constructing merkle tree", + "sourceChain", report.SourceChain, + "treeLeaves", len(report.Messages)) + numMsgs := len(report.Messages) + + // Iterate sequence range and executed messages to select messages to execute. + var toExecute []int + var offchainTokenData [][][]byte + var msgInRoot []cciptypes.CCIPMsg + executedIdx := 0 + for i := 0; i < numMsgs && len(toExecute) <= maxMessages; i++ { + seqNum := report.SequenceNumberRange.Start() + cciptypes.SeqNum(i) + // Skip messages which are already executed + if executedIdx < len(report.ExecutedMessages) && report.ExecutedMessages[executedIdx] == seqNum { + executedIdx++ + } else { + msg := report.Messages[i] + tokenData, err := tokenDataReader.ReadTokenData(context.Background(), report.SourceChain, msg.SeqNum) + if err != nil { + // TODO: skip message instead of failing the whole thing. + // that might mean moving the token data reading out of the loop. + lggr.Infow( + "unable to read token data", + "source-chain", report.SourceChain, + "seq-num", msg.SeqNum, + "error", err) + return cciptypes.ExecutePluginReportSingleChain{}, 0, fmt.Errorf( + "unable to read token data for message %d: %w", msg.SeqNum, err) + } + + lggr.Debugw( + "read token data", + "source-chain", report.SourceChain, + "seq-num", msg.SeqNum, + "data", tokenData) + offchainTokenData = append(offchainTokenData, tokenData) + toExecute = append(toExecute, i) + msgInRoot = append(msgInRoot, msg) + } + } + + lggr.Infow( + "selected messages from commit report for execution", + "sourceChain", report.SourceChain, + "commitRoot", report.MerkleRoot.String(), + "numMessages", numMsgs, + "toExecute", len(toExecute)) + proof, err := tree.Prove(toExecute) + if err != nil { + return cciptypes.ExecutePluginReportSingleChain{}, 0, + fmt.Errorf("unable to prove messages for report %s: %w", report.MerkleRoot.String(), err) + } + + var proofsCast []cciptypes.Bytes32 + for _, p := range proof.Hashes { + proofsCast = append(proofsCast, p) + } + + finalReport := cciptypes.ExecutePluginReportSingleChain{ + SourceChainSelector: report.SourceChain, + Messages: msgInRoot, + OffchainTokenData: offchainTokenData, + Proofs: proofsCast, + ProofFlagBits: cciptypes.BigInt{Int: slicelib.BoolsToBitFlags(proof.SourceFlags)}, + } + + // Note: ExecutePluginReport is a strict array of data, so wrapping the final report + // does not add any additional overhead to the size being computed here. + + // Compute the size of the encoded report. + encoded, err := encoder.Encode( + ctx, + cciptypes.ExecutePluginReport{ + ChainReports: []cciptypes.ExecutePluginReportSingleChain{finalReport}, + }, + ) + if err != nil { + lggr.Errorw("unable to encode report", "err", err, "report", finalReport) + return cciptypes.ExecutePluginReportSingleChain{}, 0, fmt.Errorf("unable to encode report: %w", err) + } + + return finalReport, len(encoded), nil +} + +// selectReport takes a list of reports in execution order and selects the first reports that fit within the +// maxReportSizeBytes. Individual messages in a commit report may be skipped for various reasons, for example if an +// out-of-order execution is detected or the message requires additional off-chain metadata which is not yet available. +// If there is not enough space in the final report, it may be partially executed by searching for a subset of messages +// which can fit in the final report. +func selectReport( + ctx context.Context, + lggr logger.Logger, + hasher cciptypes.MessageHasher, + encoder cciptypes.ExecutePluginCodec, + tokenDataReader TokenDataReader, + reports []cciptypes.ExecutePluginCommitDataWithMessages, + maxReportSizeBytes int, +) ([]cciptypes.ExecutePluginReportSingleChain, []cciptypes.ExecutePluginCommitDataWithMessages, error) { + // TODO: It may be desirable for this entire function to be an interface so that + // different selection algorithms can be used. + + // count number of fully executed reports so that they can be removed after iterating the reports. + fullyExecuted := 0 + accumulatedSize := 0 + var finalReports []cciptypes.ExecutePluginReportSingleChain + for reportIdx, report := range reports { + execReport, encodedSize, updatedReport, err := + buildSingleChainReportMaxSize(ctx, lggr, hasher, tokenDataReader, encoder, + report, maxReportSizeBytes-accumulatedSize) + // No messages fit into the report, stop adding more reports. + if errors.Is(err, errNothingExecuted) { + break + } + if err != nil { + return nil, nil, fmt.Errorf("unable to build single chain report: %w", err) + } + reports[reportIdx] = updatedReport + accumulatedSize += encodedSize + finalReports = append(finalReports, execReport) + + // partially executed report detected, stop adding more reports. + // TODO: do not break if messages were intentionally skipped. + if len(updatedReport.Messages) != len(updatedReport.ExecutedMessages) { + break + } + fullyExecuted++ + } + + // Remove reports that are about to be executed. + if fullyExecuted == len(reports) { + reports = nil + } else { + reports = reports[fullyExecuted:] + } + + lggr.Infow( + "selected commit reports for execution report", + "numReports", len(finalReports), + "size", accumulatedSize, + "incompleteReports", len(reports), + "maxSize", maxReportSizeBytes) + + return finalReports, reports, nil +} + +// Outcome collects the reports from the two phases and constructs the final outcome. Part of the outcome is a fully +// formed report that will be encoded for final transmission in the reporting phase. func (p *Plugin) Outcome( outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, ) (ocr3types.Outcome, error) { @@ -226,16 +485,16 @@ func (p *Plugin) Outcome( mergedMessageObservations) // flatten commit reports and sort by timestamp. - var reports []cciptypes.ExecutePluginCommitDataWithMessages + var commitReports []cciptypes.ExecutePluginCommitDataWithMessages for _, report := range observation.CommitReports { - reports = append(reports, report...) + commitReports = append(commitReports, report...) } - sort.Slice(reports, func(i, j int) bool { - return reports[i].Timestamp.Before(reports[j].Timestamp) + sort.Slice(commitReports, func(i, j int) bool { + return commitReports[i].Timestamp.Before(commitReports[j].Timestamp) }) - // add messages to their reports. - for _, report := range reports { + // add messages to their commitReports. + for _, report := range commitReports { report.Messages = nil for i := report.SequenceNumberRange.Start(); i <= report.SequenceNumberRange.End(); i++ { if msg, ok := observation.Messages[report.SourceChain][i]; ok { @@ -244,15 +503,39 @@ func (p *Plugin) Outcome( } } - // TODO: select reports and messages for the final exec report. - // TODO: may only need the proofs for the final exec report rather than the report and the messages. + // TODO: this function should be pure, a context should not be needed. + outcomeReports, commitReports, err := + selectReport(context.Background(), p.lggr, p.msgHasher, p.reportCodec, p.tokenDataReader, + commitReports, maxReportSizeBytes) + if err != nil { + return ocr3types.Outcome{}, fmt.Errorf("unable to extract proofs: %w", err) + } + + execReport := cciptypes.ExecutePluginReport{ + ChainReports: outcomeReports, + } - return cciptypes.NewExecutePluginOutcome(reports).Encode() + return cciptypes.NewExecutePluginOutcome(commitReports, execReport).Encode() } func (p *Plugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[[]byte], error) { + decodedOutcome, err := cciptypes.DecodeExecutePluginOutcome(outcome) + if err != nil { + return nil, err + } - panic("implement me") + // TODO: this function should be pure, a context should not be needed. + encoded, err := p.reportCodec.Encode(context.Background(), decodedOutcome.Report) + if err != nil { + return nil, err + } + + report := []ocr3types.ReportWithInfo[[]byte]{{ + Report: encoded, + Info: nil, + }} + + return report, nil } func (p *Plugin) ShouldAcceptAttestedReport( diff --git a/execute/plugin_functions.go b/execute/plugin_functions.go index 522cf7ad2..5f72ae394 100644 --- a/execute/plugin_functions.go +++ b/execute/plugin_functions.go @@ -1,6 +1,7 @@ package execute import ( + "context" "errors" "fmt" "sort" @@ -8,6 +9,8 @@ import ( mapset "github.com/deckarep/golang-set/v2" "golang.org/x/crypto/sha3" + "github.com/smartcontractkit/chainlink-common/pkg/hashutil" + "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" @@ -315,3 +318,56 @@ func mergeCommitObservations( return results, nil } + +// markNewMessagesExecuted compares an execute plugin report with the commit report metadata and marks the new messages +// as executed. +func markNewMessagesExecuted( + execReport cciptypes.ExecutePluginReportSingleChain, report cciptypes.ExecutePluginCommitDataWithMessages, +) cciptypes.ExecutePluginCommitDataWithMessages { + // Mark new messages executed. + for i := 0; i < len(execReport.Messages); i++ { + report.ExecutedMessages = + append(report.ExecutedMessages, execReport.Messages[i].SeqNum) + } + sort.Slice( + report.ExecutedMessages, + func(i, j int) bool { return report.ExecutedMessages[i] < report.ExecutedMessages[j] }) + + return report +} + +// constructMerkleTree creates the merkle tree object from the messages in the report. +func constructMerkleTree( + ctx context.Context, + hasher cciptypes.MessageHasher, + report cciptypes.ExecutePluginCommitDataWithMessages, +) (*merklemulti.Tree[[32]byte], error) { + // Ensure we have the expected number of messages + numMsgs := int(report.SequenceNumberRange.End() - report.SequenceNumberRange.Start() + 1) + if numMsgs != len(report.Messages) { + return nil, fmt.Errorf( + "malformed report %s, unexpected number of messages: expected %d, got %d", + report.MerkleRoot.String(), numMsgs, len(report.Messages)) + } + + treeLeaves := make([][32]byte, 0) + for _, msg := range report.Messages { + if !report.SequenceNumberRange.Contains(msg.SeqNum) { + return nil, fmt.Errorf( + "malformed report, message %s sequence number %d outside of report range %s", + report.MerkleRoot.String(), msg.SeqNum, report.SequenceNumberRange) + } + if report.SourceChain != msg.SourceChain { + return nil, fmt.Errorf("malformed report, message %s for unexpected source chain: expected %d, got %d", + report.MerkleRoot.String(), report.SourceChain, msg.SourceChain) + } + leaf, err := hasher.Hash(ctx, msg) + if err != nil { + return nil, fmt.Errorf("unable to hash message (%d, %d): %w", msg.SourceChain, msg.SeqNum, err) + } + treeLeaves = append(treeLeaves, leaf) + } + + // TODO: Do not hard code the hash function, it should be derived from the message hasher. + return merklemulti.NewTree(hashutil.NewKeccak(), treeLeaves) +} diff --git a/execute/plugin_test.go b/execute/plugin_test.go index 50ad4642a..0a5be313c 100644 --- a/execute/plugin_test.go +++ b/execute/plugin_test.go @@ -2,34 +2,24 @@ package execute import ( "context" - "encoding/json" - "math" + "crypto/rand" + "encoding/hex" + "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-ccip/internal/mocks" - + "github.com/smartcontractkit/chainlink-common/pkg/hashutil" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" -) - -func TestSomethingCool(t *testing.T) { - - foo := map[cciptypes.ChainSelector]int{ - cciptypes.ChainSelector(1): 1, - cciptypes.ChainSelector(math.MaxUint64): 1, - } - js, _ := json.Marshal(foo) - t.Log(string(js)) - - b := []byte(`{"1":1,"18446744073709551615":1}`) - var bar map[cciptypes.ChainSelector]int - assert.NoError(t, json.Unmarshal(b, &bar)) - t.Log(bar) -} + "github.com/smartcontractkit/chainlink-ccip/internal/libs/slicelib" + "github.com/smartcontractkit/chainlink-ccip/internal/mocks" +) func Test_getPendingExecutedReports(t *testing.T) { tests := []struct { @@ -40,7 +30,6 @@ func Test_getPendingExecutedReports(t *testing.T) { want1 time.Time wantErr assert.ErrorAssertionFunc }{ - // TODO: Add test cases. { name: "empty", reports: nil, @@ -166,3 +155,496 @@ func Test_getPendingExecutedReports(t *testing.T) { }) } } + +// TODO: better than this +type tdr struct{} + +func (t tdr) ReadTokenData( + ctx context.Context, srcChain cciptypes.ChainSelector, num cciptypes.SeqNum) ([][]byte, error, +) { + return nil, nil +} + +// mustRandID generates a random hex ID value. +func mustRandID() string { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes)[:32] +} + +func breakCommitReport( + commitReport cciptypes.ExecutePluginCommitDataWithMessages, +) cciptypes.ExecutePluginCommitDataWithMessages { + commitReport.Messages = append(commitReport.Messages, cciptypes.CCIPMsg{ + CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + ID: mustRandID(), + SourceChain: cciptypes.ChainSelector(1), + SeqNum: cciptypes.SeqNum(999), + MsgHash: cciptypes.Bytes32{}, + }, + }) + return commitReport +} + +func makeTestCommitReport( + numMessages, srcChain, firstSeqNum, block int, timestamp int64, executed []cciptypes.SeqNum, +) cciptypes.ExecutePluginCommitDataWithMessages { + for _, e := range executed { + if e < cciptypes.SeqNum(firstSeqNum) || e > cciptypes.SeqNum(firstSeqNum+numMessages-1) { + panic("executed message out of range") + } + } + var messages []cciptypes.CCIPMsg + for i := 0; i < numMessages; i++ { + messages = append(messages, cciptypes.CCIPMsg{ + CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + ID: mustRandID(), + SourceChain: cciptypes.ChainSelector(srcChain), + SeqNum: cciptypes.SeqNum(i + firstSeqNum), + MsgHash: cciptypes.Bytes32{}, + }, + Nonce: uint64(i), + }) + } + + sequenceNumberRange := + cciptypes.NewSeqNumRange(cciptypes.SeqNum(firstSeqNum), cciptypes.SeqNum(firstSeqNum+numMessages-1)) + return cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SourceChain: cciptypes.ChainSelector(srcChain), + SequenceNumberRange: sequenceNumberRange, + Timestamp: time.UnixMilli(timestamp), + BlockNum: uint64(block), + ExecutedMessages: executed, + }, + Messages: messages, + } + +} + +// assertMerkleRoot computes the source messages merkle root, then computes a verification with the proof, then compares +// the roots. +func assertMerkleRoot( + t *testing.T, + hasher cciptypes.MessageHasher, + execReport cciptypes.ExecutePluginReportSingleChain, + commitReport cciptypes.ExecutePluginCommitDataWithMessages, +) { + keccak := hashutil.NewKeccak() + // Generate merkle root from commit report messages + var leafHashes [][32]byte + for _, msg := range commitReport.Messages { + hash, err := hasher.Hash(context.Background(), msg) + require.NoError(t, err) + leafHashes = append(leafHashes, hash) + } + tree, err := merklemulti.NewTree(keccak, leafHashes) + require.NoError(t, err) + merkleRoot := tree.Root() + + // Generate merkle root from exec report messages and proofj + ctx := context.Background() + var leaves [][32]byte + for _, msg := range execReport.Messages { + hash, err := hasher.Hash(ctx, msg) + require.NoError(t, err) + leaves = append(leaves, hash) + } + proofCast := make([][32]byte, len(execReport.Proofs)) + for i, p := range execReport.Proofs { + copy(proofCast[i][:], p[:32]) + } + var proof merklemulti.Proof[[32]byte] + proof.Hashes = proofCast + proof.SourceFlags = slicelib.BitFlagsToBools(execReport.ProofFlagBits.Int, len(leaves)+len(proofCast)-1) + recomputedMerkleRoot, err := merklemulti.VerifyComputeRoot(hashutil.NewKeccak(), + leaves, + proof) + assert.NoError(t, err) + assert.NotNil(t, recomputedMerkleRoot) + + // Compare them + assert.Equal(t, merkleRoot, recomputedMerkleRoot) +} + +func Test_selectReport(t *testing.T) { + hasher := mocks.NewMessageHasher() + codec := mocks.NewExecutePluginJSONReportCodec() + lggr := logger.Test(t) + var tokenDataReader tdr + + type args struct { + reports []cciptypes.ExecutePluginCommitDataWithMessages + maxReportSize int + } + tests := []struct { + name string + args args + expectedExecReports int + expectedCommitReports int + expectedExecThings []int + lastReportExecuted []cciptypes.SeqNum + wantErr string + }{ + { + name: "empty report", + args: args{ + reports: []cciptypes.ExecutePluginCommitDataWithMessages{}, + }, + expectedExecReports: 0, + expectedCommitReports: 0, + }, + { + name: "half report", + args: args{ + maxReportSize: 2200, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, nil), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 1, + expectedExecThings: []int{5}, + lastReportExecuted: []cciptypes.SeqNum{100, 101, 102, 103, 104}, + }, + { + name: "full report", + args: args{ + maxReportSize: 10000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, nil), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 0, + expectedExecThings: []int{10}, + }, + { + name: "two reports", + args: args{ + maxReportSize: 15000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, nil), + makeTestCommitReport(20, 2, 100, 999, 10101010101, nil), + }, + }, + expectedExecReports: 2, + expectedCommitReports: 0, + expectedExecThings: []int{10, 20}, + }, + { + name: "one and half reports", + args: args{ + maxReportSize: 8000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, nil), + makeTestCommitReport(20, 2, 100, 999, 10101010101, nil), + }, + }, + expectedExecReports: 2, + expectedCommitReports: 1, + expectedExecThings: []int{10, 10}, + lastReportExecuted: []cciptypes.SeqNum{100, 101, 102, 103, 104, 105, 106, 107, 108, 109}, + }, + { + name: "exactly one report", + args: args{ + maxReportSize: 3900, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, nil), + makeTestCommitReport(20, 2, 100, 999, 10101010101, nil), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 1, + expectedExecThings: []int{10}, + lastReportExecuted: []cciptypes.SeqNum{}, + }, + { + name: "execute remainder of partially executed report", + args: args{ + maxReportSize: 2500, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, []cciptypes.SeqNum{100, 101, 102, 103, 104}), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 0, + expectedExecThings: []int{5}, + }, + { + name: "partially execute remainder of partially executed report", + args: args{ + maxReportSize: 2000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, []cciptypes.SeqNum{100, 101, 102, 103, 104}), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 1, + expectedExecThings: []int{4}, + lastReportExecuted: []cciptypes.SeqNum{100, 101, 102, 103, 104, 105, 106, 107, 108}, + }, + { + name: "execute remainder of sparsely executed report", + args: args{ + maxReportSize: 2500, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, []cciptypes.SeqNum{100, 102, 104, 106, 108}), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 0, + expectedExecThings: []int{5}, + }, + { + name: "partially execute remainder of partially executed sparse report", + args: args{ + maxReportSize: 2000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + makeTestCommitReport(10, 1, 100, 999, 10101010101, []cciptypes.SeqNum{100, 102, 104, 106, 108}), + }, + }, + expectedExecReports: 1, + expectedCommitReports: 1, + expectedExecThings: []int{4}, + lastReportExecuted: []cciptypes.SeqNum{100, 101, 102, 103, 104, 105, 106, 107, 108}, + }, + { + name: "broken report", + args: args{ + maxReportSize: 10000, + reports: []cciptypes.ExecutePluginCommitDataWithMessages{ + breakCommitReport(makeTestCommitReport(10, 1, 101, 1000, 10101010102, nil)), + }, + }, + wantErr: "unable to build a single chain report", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + execReports, commitReports, err := + selectReport(ctx, lggr, hasher, codec, tokenDataReader, tt.args.reports, tt.args.maxReportSize) + if tt.wantErr != "" { + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + require.Len(t, execReports, tt.expectedExecReports) + require.Len(t, commitReports, tt.expectedCommitReports) + for i, execReport := range execReports { + assert.Len(t, execReport.Messages, tt.expectedExecThings[i]) + assert.Len(t, execReport.OffchainTokenData, tt.expectedExecThings[i]) + assert.NotEmptyf(t, execReport.Proofs, "Proof should not be empty.") + assertMerkleRoot(t, hasher, execReport, tt.args.reports[i]) + } + // If the last report is partially executed, the executed messages can be checked. + if len(execReports) > 0 && len(tt.lastReportExecuted) > 0 { + lastReport := commitReports[len(commitReports)-1] + assert.ElementsMatch(t, tt.lastReportExecuted, lastReport.ExecutedMessages) + } + }) + } +} + +type badHasher struct{} + +func (bh badHasher) Hash(context.Context, cciptypes.CCIPMsg) (cciptypes.Bytes32, error) { + return cciptypes.Bytes32{}, fmt.Errorf("bad hasher") +} + +type badTokenDataReader struct{} + +func (btdr badTokenDataReader) ReadTokenData( + _ context.Context, _ cciptypes.ChainSelector, _ cciptypes.SeqNum, +) ([][]byte, error) { + return nil, fmt.Errorf("bad token data reader") +} + +type badCodec struct{} + +func (bc badCodec) Encode(ctx context.Context, report cciptypes.ExecutePluginReport) ([]byte, error) { + return nil, fmt.Errorf("bad codec") +} + +func (bc badCodec) Decode(ctx context.Context, bytes []byte) (cciptypes.ExecutePluginReport, error) { + return cciptypes.ExecutePluginReport{}, fmt.Errorf("bad codec") +} + +func Test_buildSingleChainReport_Errors(t *testing.T) { + lggr := logger.Test(t) + + type args struct { + report cciptypes.ExecutePluginCommitDataWithMessages + maxMessages int + hasher cciptypes.MessageHasher + tokenDataReader TokenDataReader + codec cciptypes.ExecutePluginCodec + } + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "wrong number of messages", + wantErr: "unexpected number of messages: expected 1, got 2", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(100)), + }, + Messages: []cciptypes.CCIPMsg{ + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{}}, + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{}}, + }, + }, + }, + }, + { + name: "wrong sequence numbers", + wantErr: "sequence number 102 outside of report range [100 -> 101]", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(101)), + }, + Messages: []cciptypes.CCIPMsg{ + { + CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SeqNum: cciptypes.SeqNum(100), + }, + }, + { + CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SeqNum: cciptypes.SeqNum(102), + }, + }, + }, + }, + }, + }, + { + name: "source mismatch", + wantErr: "unexpected source chain: expected 1111, got 2222", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SourceChain: 1111, + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(100)), + }, + Messages: []cciptypes.CCIPMsg{ + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SourceChain: 2222, + SeqNum: cciptypes.SeqNum(100), + }}, + }, + }, + hasher: badHasher{}, + }, + }, + { + name: "bad hasher", + wantErr: "unable to hash message (1234567, 100): bad hasher", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SourceChain: 1234567, + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(100)), + }, + Messages: []cciptypes.CCIPMsg{ + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SourceChain: 1234567, + SeqNum: cciptypes.SeqNum(100), + }}, + }, + }, + hasher: badHasher{}, + }, + }, + { + name: "bad token data reader", + wantErr: "unable to read token data for message 100: bad token data reader", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SourceChain: 1234567, + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(100)), + }, + Messages: []cciptypes.CCIPMsg{ + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SourceChain: 1234567, + SeqNum: cciptypes.SeqNum(100), + }}, + }, + }, + tokenDataReader: badTokenDataReader{}, + }, + }, + { + name: "bad codec", + wantErr: "unable to encode report: bad codec", + args: args{ + report: cciptypes.ExecutePluginCommitDataWithMessages{ + ExecutePluginCommitData: cciptypes.ExecutePluginCommitData{ + SourceChain: 1234567, + SequenceNumberRange: cciptypes.NewSeqNumRange(cciptypes.SeqNum(100), cciptypes.SeqNum(100)), + }, + Messages: []cciptypes.CCIPMsg{ + {CCIPMsgBaseDetails: cciptypes.CCIPMsgBaseDetails{ + SourceChain: 1234567, + SeqNum: cciptypes.SeqNum(100), + }}, + }, + }, + codec: badCodec{}, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Select hasher mock. + var resolvedHasher cciptypes.MessageHasher + if tt.args.hasher != nil { + resolvedHasher = tt.args.hasher + } else { + resolvedHasher = mocks.NewMessageHasher() + } + + // Select token data reader mock. + var resolvedTokenDataReader TokenDataReader + if tt.args.tokenDataReader != nil { + resolvedTokenDataReader = tt.args.tokenDataReader + } else { + resolvedTokenDataReader = tdr{} + } + + // Select codec mock. + var resolvedCodec cciptypes.ExecutePluginCodec + if tt.args.codec != nil { + resolvedCodec = tt.args.codec + } else { + resolvedCodec = mocks.NewExecutePluginJSONReportCodec() + } + + ctx := context.Background() + execReport, size, err := buildSingleChainReport( + ctx, lggr, resolvedHasher, resolvedTokenDataReader, resolvedCodec, tt.args.report, tt.args.maxMessages) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + fmt.Println(execReport, size, err) + }) + } +} diff --git a/go.mod b/go.mod index 06f7be7ee..fcb0b1a86 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21.7 require ( github.com/deckarep/golang-set/v2 v2.6.0 - github.com/smartcontractkit/chainlink-common v0.1.7-0.20240625074419-c278d083facf + github.com/smartcontractkit/chainlink-common v0.1.7-0.20240626200607-030cd3975e55 github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.24.0 diff --git a/go.sum b/go.sum index c71295378..8bc997bc9 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,12 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6Ng github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240624191350-e000707a4cf7 h1:3GjsV3Daa9POHEfUgjC0pbnxsL/iL/nbhVaju1TZpVU= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240624191350-e000707a4cf7/go.mod h1:L32xvCpk84Nglit64OhySPMP1tM3TTBK7Tw0qZl7Sd4= github.com/smartcontractkit/chainlink-common v0.1.7-0.20240625074419-c278d083facf h1:d9AS/K8RSVG64USb20N/U7RaPOsYPcmuLGJq7iE+caM= github.com/smartcontractkit/chainlink-common v0.1.7-0.20240625074419-c278d083facf/go.mod h1:L32xvCpk84Nglit64OhySPMP1tM3TTBK7Tw0qZl7Sd4= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240626200607-030cd3975e55 h1:bKcZAqP+UuXijzswFr4WssWuuhaXFkJF8Jbg8vhOWPY= +github.com/smartcontractkit/chainlink-common v0.1.7-0.20240626200607-030cd3975e55/go.mod h1:L32xvCpk84Nglit64OhySPMP1tM3TTBK7Tw0qZl7Sd4= github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c h1:lIyMbTaF2H0Q71vkwZHX/Ew4KF2BxiKhqEXwF8rn+KI= github.com/smartcontractkit/libocr v0.0.0-20240419185742-fd3cab206b2c/go.mod h1:fb1ZDVXACvu4frX3APHZaEBp0xi1DIm34DcA0CwTsZM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/mocks/reportcodec.go b/internal/mocks/reportcodec.go index b29c93603..2b686efdf 100644 --- a/internal/mocks/reportcodec.go +++ b/internal/mocks/reportcodec.go @@ -22,3 +22,19 @@ func (c CommitPluginJSONReportCodec) Decode(ctx context.Context, bytes []byte) ( err := json.Unmarshal(bytes, &report) return report, err } + +type ExecutePluginJSONReportCodec struct{} + +func NewExecutePluginJSONReportCodec() *ExecutePluginJSONReportCodec { + return &ExecutePluginJSONReportCodec{} +} + +func (c ExecutePluginJSONReportCodec) Encode(_ context.Context, report cciptypes.ExecutePluginReport) ([]byte, error) { + return json.Marshal(report) +} + +func (c ExecutePluginJSONReportCodec) Decode(_ context.Context, bytes []byte) (cciptypes.ExecutePluginReport, error) { + report := cciptypes.ExecutePluginReport{} + err := json.Unmarshal(bytes, &report) + return report, err +}