diff --git a/.github/actions/setup-postgres/.env b/.github/actions/setup-postgres/.env new file mode 100644 index 000000000..47ed8d9bc --- /dev/null +++ b/.github/actions/setup-postgres/.env @@ -0,0 +1,5 @@ +POSTGRES_USER=postgres +POSTGRES_OPTIONS="-c max_connections=1000 -c shared_buffers=2GB -c log_lock_waits=true" +POSTGRES_PASSWORD=postgres +POSTGRES_DB=chainlink_test +POSTGRES_HOST_AUTH_METHOD=trust diff --git a/.github/actions/setup-postgres/action.yml b/.github/actions/setup-postgres/action.yml new file mode 100644 index 000000000..45bfba596 --- /dev/null +++ b/.github/actions/setup-postgres/action.yml @@ -0,0 +1,18 @@ +name: Setup Postgresql +description: Setup postgres docker container via docker-compose, allowing usage of a custom command, see https://github.com/orgs/community/discussions/26688 +inputs: + base-path: + description: Path to the base of the repo + required: false + default: . +runs: + using: composite + steps: + - name: Start postgres service + run: docker compose up -d + shell: bash + working-directory: ${{ inputs.base-path }}/.github/actions/setup-postgres + - name: Wait for postgres service to be healthy + run: ./wait-for-healthy-postgres.sh + shell: bash + working-directory: ${{ inputs.base-path }}/.github/actions/setup-postgres diff --git a/.github/actions/setup-postgres/bin/pg_dump b/.github/actions/setup-postgres/bin/pg_dump new file mode 100755 index 000000000..d8135ad82 --- /dev/null +++ b/.github/actions/setup-postgres/bin/pg_dump @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# +# This script acts as a docker replacement around pg_dump so that developers can +# run DB involved tests locally without having postgres installed. +# +# Installation: +# - Make sure that your PATH doesn't already contain a postgres installation +# - Add this script to your PATH +# +# Usage: +# You should be able to setup your test db via: +# - go build -o chainlink.test . # Build the chainlink binary to run test db prep commands from +# - export CL_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/chainlink_test?sslmode=disable" +# - pushd .github/actions/setup-postgres/ # Navigate to the setup-postgres action so we can spin up a docker postgres +# instance +# - docker compose up # Spin up postgres +# - ./chainlink.test local db preparetest # Run the db migration, which will shell out to our pg_dump wrapper too. +# - popd +# - go test -timeout 30s ./core/services/workflows/... -v # Run tests that use the database + +cd "$(dirname "$0")" || exit + +docker compose exec -T postgres pg_dump "$@" diff --git a/.github/actions/setup-postgres/docker-compose.yml b/.github/actions/setup-postgres/docker-compose.yml new file mode 100644 index 000000000..23f8d82b9 --- /dev/null +++ b/.github/actions/setup-postgres/docker-compose.yml @@ -0,0 +1,15 @@ +name: gha_postgres +services: + postgres: + ports: + - "5432:5432" + container_name: cl_pg + image: postgres:14-alpine + command: postgres ${POSTGRES_OPTIONS} + env_file: + - .env + healthcheck: + test: "pg_isready -d ${POSTGRES_DB} -U ${POSTGRES_USER}" + interval: 2s + timeout: 5s + retries: 5 diff --git a/.github/actions/setup-postgres/wait-for-healthy-postgres.sh b/.github/actions/setup-postgres/wait-for-healthy-postgres.sh new file mode 100755 index 000000000..438cfbaff --- /dev/null +++ b/.github/actions/setup-postgres/wait-for-healthy-postgres.sh @@ -0,0 +1,25 @@ +#!/bin/bash +RETRIES=10 + +until [ $RETRIES -eq 0 ]; do + DOCKER_OUTPUT=$(docker compose ps postgres --status running --format json) + JSON_TYPE=$(echo "$DOCKER_OUTPUT" | jq -r 'type') + + if [ "$JSON_TYPE" == "array" ]; then + HEALTH_STATUS=$(echo "$DOCKER_OUTPUT" | jq -r '.[0].Health') + elif [ "$JSON_TYPE" == "object" ]; then + HEALTH_STATUS=$(echo "$DOCKER_OUTPUT" | jq -r '.Health') + else + HEALTH_STATUS="Unknown JSON type: $JSON_TYPE" + fi + + echo "postgres health status: $HEALTH_STATUS" + if [ "$HEALTH_STATUS" == "healthy" ]; then + exit 0 + fi + + echo "Waiting for postgres server, $((RETRIES--)) remaining attempts..." + sleep 2 +done + +exit 1 diff --git a/.github/workflows/ccip-integration-test.yml b/.github/workflows/ccip-integration-test.yml new file mode 100644 index 000000000..15dc0e3e3 --- /dev/null +++ b/.github/workflows/ccip-integration-test.yml @@ -0,0 +1,68 @@ +name: "Run CCIP OCR3 Integration Test" + +on: + pull_request: + push: + branches: + - 'ccip-develop' + +jobs: + integration-test-ccip-ocr3: + env: + # We explicitly have this env var not be "CL_DATABASE_URL" to avoid having it be used by core related tests + # when they should not be using it, while still allowing us to DRY up the setup + DB_URL: postgresql://postgres:postgres@localhost:5432/chainlink_test?sslmode=disable + + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.22.5'] + steps: + - name: Checkout the repo + uses: actions/checkout@9bb56186c3b09b4f86b1c65136769dd318469633 # v4.1.2 + - name: Setup Go ${{ matrix.go-version }} + uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 + with: + go-version: ${{ matrix.go-version }} + - name: Display Go version + run: go version + - name: Clone CCIP repo + run: | + git clone https://github.com/smartcontractkit/ccip.git + cd ccip + git fetch + git checkout ccip-develop + - name: Update chainlink-ccip dependency in ccip + run: | + cd ccip + go get github.com/smartcontractkit/chainlink-ccip@${{ github.event.pull_request.head.sha }} + make gomodtidy + - name: Setup Postgres + uses: ./.github/actions/setup-postgres + - name: Download Go vendor packages + run: | + cd ccip + go mod download + - name: Build binary + run: | + cd ccip + go build -o ccip.test . + - name: Setup DB + run: | + cd ccip + ./ccip.test local db preparetest + env: + CL_DATABASE_URL: ${{ env.DB_URL }} + - name: Run ccip ocr3 integration test + run: | + cd ccip + go test -v -timeout 3m -run "^TestIntegration_OCR3Nodes$" ./core/capabilities/ccip/ccip_integration_tests + EXITCODE=${PIPESTATUS[0]} + if [ $EXITCODE -ne 0 ]; then + echo "Integration test failed" + else + echo "Integration test passed!" + fi + exit $EXITCODE + env: + CL_DATABASE_URL: ${{ env.DB_URL }} diff --git a/commit/plugin.go b/commit/plugin.go index 5c8f69469..76f58ffb2 100644 --- a/commit/plugin.go +++ b/commit/plugin.go @@ -226,11 +226,11 @@ func (p *Plugin) ValidateObservation( return fmt.Errorf("validate observer %d reading eligibility: %w", ao.Observer, err) } - if err := validateObservedTokenPrices(obs.TokenPrices); err != nil { + if err := ValidateObservedTokenPrices(obs.TokenPrices); err != nil { return fmt.Errorf("validate token prices: %w", err) } - if err := validateObservedGasPrices(obs.GasPrices); err != nil { + if err := ValidateObservedGasPrices(obs.GasPrices); err != nil { return fmt.Errorf("validate gas prices: %w", err) } @@ -347,7 +347,7 @@ func (p *Plugin) ShouldTransmitAcceptedReport( return false, fmt.Errorf("decode commit plugin report: %w", err) } - isValid, err := validateMerkleRootsState(ctx, p.lggr, decodedReport, p.ccipReader) + isValid, err := ValidateMerkleRootsState(ctx, p.lggr, decodedReport, p.ccipReader) if !isValid { return false, nil } diff --git a/commit/plugin_e2e_test.go b/commit/plugin_e2e_test.go index c9e56302f..ac87c5959 100644 --- a/commit/plugin_e2e_test.go +++ b/commit/plugin_e2e_test.go @@ -16,6 +16,7 @@ import ( libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/chainlink-ccip/chainconfig" @@ -433,6 +434,10 @@ func setupHomeChainPoller(lggr logger.Logger, chainConfigInfos []reader.ChainCon // to prevent linting error because of logging after finishing tests, we close the poller after each test, having // lower polling interval make it catch up faster 10*time.Millisecond, + types.BoundContract{ + Address: "0xCCIPConfigFakeAddress", + Name: consts.ContractNameCCIPConfig, + }, ) return homeChain diff --git a/commit/plugin_functions.go b/commit/plugin_functions.go index cca9751c0..650ce36e9 100644 --- a/commit/plugin_functions.go +++ b/commit/plugin_functions.go @@ -559,7 +559,7 @@ func validateObserverReadingEligibility( return nil } -func validateObservedTokenPrices(tokenPrices []cciptypes.TokenPrice) error { +func ValidateObservedTokenPrices(tokenPrices []cciptypes.TokenPrice) error { tokensWithPrice := mapset.NewSet[types.Account]() for _, t := range tokenPrices { if tokensWithPrice.Contains(t.TokenID) { @@ -575,7 +575,7 @@ func validateObservedTokenPrices(tokenPrices []cciptypes.TokenPrice) error { return nil } -func validateObservedGasPrices(gasPrices []cciptypes.GasPriceChain) error { +func ValidateObservedGasPrices(gasPrices []cciptypes.GasPriceChain) error { // Duplicate gas prices must not appear for the same chain and must not be empty. gasPriceChains := mapset.NewSet[cciptypes.ChainSelector]() for _, g := range gasPrices { @@ -591,8 +591,8 @@ func validateObservedGasPrices(gasPrices []cciptypes.GasPriceChain) error { return nil } -// validateMerkleRootsState merkle roots seq nums validation by comparing with on-chain state. -func validateMerkleRootsState( +// ValidateMerkleRootsState merkle roots seq nums validation by comparing with on-chain state. +func ValidateMerkleRootsState( ctx context.Context, lggr logger.Logger, report cciptypes.CommitPluginReport, diff --git a/commit/plugin_functions_test.go b/commit/plugin_functions_test.go index 5fc2c84f7..87e5c6b69 100644 --- a/commit/plugin_functions_test.go +++ b/commit/plugin_functions_test.go @@ -647,7 +647,7 @@ func Test_validateObservedTokenPrices(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := validateObservedTokenPrices(tc.tokenPrices) + err := ValidateObservedTokenPrices(tc.tokenPrices) if tc.expErr { assert.Error(t, err) return @@ -700,7 +700,7 @@ func Test_validateObservedGasPrices(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := validateObservedGasPrices(tc.gasPrices) + err := ValidateObservedGasPrices(tc.gasPrices) if tc.expErr { assert.Error(t, err) return @@ -1546,7 +1546,7 @@ func Test_validateMerkleRootsState(t *testing.T) { chains = append(chains, snc.ChainSel) } reader.On("NextSeqNum", ctx, chains).Return(tc.onchainNextSeqNums, nil) - valid, err := validateMerkleRootsState(ctx, lggr, rep, reader) + valid, err := ValidateMerkleRootsState(ctx, lggr, rep, reader) if tc.expErr { assert.Error(t, err) return diff --git a/commitrmnocb/README.md b/commitrmnocb/README.md new file mode 100644 index 000000000..826b54091 --- /dev/null +++ b/commitrmnocb/README.md @@ -0,0 +1,53 @@ +# OCR3 Commit Plugin + +## Context +The purpose of the OCR3 Commit Plugin is to write reports to a configured destination chain. These reports +contain metadata of cross-chain messages, from a set of source chains, that can be executed on the destination chain. + +## Commit Plugin Design + +The plugin is implemented as a state machine, and moves from state to state each round. There are 3 states: +1. SelectingIntervalsForReport + - Determine intervals to be included in the next report +2. BuildingReport + - Build a report from the intervals determined in the previous round +3. WaitingForReportTransmission + - Check if the maximum committed sequence numbers on the dest chain have changed since generating the most + recent report, i.e. check if the report has been committed. + - If the maximum committed sequence numbers have changed (i.e. the report has been committed) or the maximum + number of check attempts have been exhausted, move to the SelectingIntervalsForReport state and generate a new + report. + - If the maximum committed sequence numbers have _not_ changed (i.e. the report is still in-flight) and the + maximum number of check attempts are not been exhausted, move to the WaitingForReportTransmission state in order + to check again. + +This approach leads to a clear separation of concerns and addresses the complications that can arise if a report +is not successfully transmitted (as we explicitly only continue once we know the previous report has been committed). +In this design, full messages are no longer in the observations, only merkle roots and intervals are. This reduces the +size of observations, which reduces bandwidth and improves performance. + +This is the state machine diagram. States are in boxes, outcomes are within arrows. + + Start + | + V + ------------------------------- + | SelectingIntervalsForReport | <---------| + ------------------------------- | + | | + ReportIntervalsSelected | + | | + V | + ------------------ | + | BuildingReport | -- ReportEmpty --->| + ------------------ | + | ReportTransmitted + ReportGenerated or + | ReportNotTransmitted + V | + -------------------------------- | + | WaitingForReportTransmission | -------->| + -------------------------------- + | ^ + | | + ReportNotYetTransmitted diff --git a/commitrmnocb/chain_support.go b/commitrmnocb/chain_support.go new file mode 100644 index 000000000..f271c0882 --- /dev/null +++ b/commitrmnocb/chain_support.go @@ -0,0 +1,73 @@ +package commitrmnocb + +import ( + "fmt" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/internal/libs/slicelib" + "github.com/smartcontractkit/chainlink-ccip/internal/reader" +) + +// ChainSupport contains functions that enable an oracle to determine which chains are accessible by itself and +// other oracles +type ChainSupport interface { + // SupportedChains returns the set of chains that the given Oracle is configured to access + SupportedChains(oracleID commontypes.OracleID) (mapset.Set[cciptypes.ChainSelector], error) + + // SupportsDestChain returns true if the given oracle supports the dest chain, returns false otherwise + SupportsDestChain(oracle commontypes.OracleID) (bool, error) + + // KnownSourceChainsSlice returns a list of all known source chains + KnownSourceChainsSlice() ([]cciptypes.ChainSelector, error) +} + +type CCIPChainSupport struct { + lggr logger.Logger + homeChain reader.HomeChain + oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID + nodeID commontypes.OracleID + destChain cciptypes.ChainSelector +} + +func (c CCIPChainSupport) KnownSourceChainsSlice() ([]cciptypes.ChainSelector, error) { + knownSourceChains, err := c.homeChain.GetKnownCCIPChains() + if err != nil { + c.lggr.Errorw("error getting known chains", "err", err) + return nil, fmt.Errorf("error getting known chains: %w", err) + } + knownSourceChainsSlice := knownSourceChains.ToSlice() + return slicelib.Filter(knownSourceChainsSlice, func(ch cciptypes.ChainSelector) bool { return ch != c.destChain }), nil +} + +// SupportedChains returns the set of chains that the given Oracle is configured to access +func (c CCIPChainSupport) SupportedChains(oracleID commontypes.OracleID) (mapset.Set[cciptypes.ChainSelector], error) { + p2pID, exists := c.oracleIDToP2pID[oracleID] + if !exists { + return nil, fmt.Errorf("oracle ID %d not found in oracleIDToP2pID", c.nodeID) + } + supportedChains, err := c.homeChain.GetSupportedChainsForPeer(p2pID) + if err != nil { + c.lggr.Warnw("error getting supported chains", err) + return mapset.NewSet[cciptypes.ChainSelector](), fmt.Errorf("error getting supported chains: %w", err) + } + + return supportedChains, nil +} + +// SupportsDestChain returns true if the given oracle supports the dest chain, returns false otherwise +func (c CCIPChainSupport) SupportsDestChain(oracle commontypes.OracleID) (bool, error) { + destChainConfig, err := c.homeChain.GetChainConfig(c.destChain) + if err != nil { + return false, fmt.Errorf("get chain config: %w", err) + } + return destChainConfig.SupportedNodes.Contains(c.oracleIDToP2pID[oracle]), nil +} + +// Interface compliance check +var _ ChainSupport = (*CCIPChainSupport)(nil) diff --git a/commitrmnocb/factory.go b/commitrmnocb/factory.go new file mode 100644 index 000000000..6fe4996b2 --- /dev/null +++ b/commitrmnocb/factory.go @@ -0,0 +1,164 @@ +package commitrmnocb + +import ( + "context" + "errors" + "fmt" + "math/big" + + "google.golang.org/grpc" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" + "github.com/smartcontractkit/chainlink-common/pkg/types" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + + "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/pluginconfig" +) + +const maxReportTransmissionCheckAttempts = 5 +const maxQueryLength = 1024 * 1024 // 1MB + +// PluginFactoryConstructor implements common OCR3ReportingPluginClient and is used for initializing a plugin factory +// and a validation service. +type PluginFactoryConstructor struct{} + +func NewPluginFactoryConstructor() *PluginFactoryConstructor { + return &PluginFactoryConstructor{} +} +func (p PluginFactoryConstructor) NewReportingPluginFactory( + ctx context.Context, + config core.ReportingPluginServiceConfig, + grpcProvider grpc.ClientConnInterface, + pipelineRunner core.PipelineRunnerService, + telemetry core.TelemetryService, + errorLog core.ErrorLog, + capRegistry core.CapabilitiesRegistry, + keyValueStore core.KeyValueStore, + relayerSet core.RelayerSet, +) (core.OCR3ReportingPluginFactory, error) { + return nil, errors.New("unimplemented") +} + +func (p PluginFactoryConstructor) NewValidationService(ctx context.Context) (core.ValidationService, error) { + panic("implement me") +} + +// PluginFactory implements common ReportingPluginFactory and is used for (re-)initializing commit plugin instances. +type PluginFactory struct { + lggr logger.Logger + ocrConfig reader.OCR3ConfigWithMeta + commitCodec cciptypes.CommitPluginCodec + msgHasher cciptypes.MessageHasher + homeChainReader reader.HomeChain + contractReaders map[cciptypes.ChainSelector]types.ContractReader + chainWriters map[cciptypes.ChainSelector]types.ChainWriter +} + +func NewPluginFactory( + lggr logger.Logger, + ocrConfig reader.OCR3ConfigWithMeta, + commitCodec cciptypes.CommitPluginCodec, + msgHasher cciptypes.MessageHasher, + homeChainReader reader.HomeChain, + contractReaders map[cciptypes.ChainSelector]types.ContractReader, + chainWriters map[cciptypes.ChainSelector]types.ChainWriter, +) *PluginFactory { + return &PluginFactory{ + lggr: lggr, + ocrConfig: ocrConfig, + commitCodec: commitCodec, + msgHasher: msgHasher, + homeChainReader: homeChainReader, + contractReaders: contractReaders, + chainWriters: chainWriters, + } +} + +func (p *PluginFactory) NewReportingPlugin(config ocr3types.ReportingPluginConfig, +) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) { + offchainConfig, err := pluginconfig.DecodeCommitOffchainConfig(config.OffchainConfig) + if err != nil { + return nil, ocr3types.ReportingPluginInfo{}, fmt.Errorf("failed to decode commit offchain config: %w", err) + } + + if err = offchainConfig.Validate(); err != nil { + return nil, ocr3types.ReportingPluginInfo{}, fmt.Errorf("failed to validate commit offchain config: %w", err) + } + + var oracleIDToP2PID = make(map[commontypes.OracleID]ragep2ptypes.PeerID) + for oracleID, p2pID := range p.ocrConfig.Config.P2PIds { + oracleIDToP2PID[commontypes.OracleID(oracleID)] = p2pID + } + + onChainTokenPricesReader := reader.NewOnchainTokenPricesReader( + reader.TokenPriceConfig{ // TODO: Inject config + StaticPrices: map[ocr2types.Account]big.Int{}, + }, + nil, // TODO: Inject this + ) + ccipReader := reader.NewCCIPChainReader( + p.lggr, + p.contractReaders, + p.chainWriters, + p.ocrConfig.Config.ChainSelector, + ) + return NewPlugin( + context.Background(), + config.OracleID, + oracleIDToP2PID, + pluginconfig.CommitPluginConfig{ + DestChain: p.ocrConfig.Config.ChainSelector, + NewMsgScanBatchSize: merklemulti.MaxNumberTreeLeaves, + MaxReportTransmissionCheckAttempts: maxReportTransmissionCheckAttempts, + OffchainConfig: offchainConfig, + }, + ccipReader, + onChainTokenPricesReader, + p.commitCodec, + p.msgHasher, + p.lggr, + p.homeChainReader, + config, + ), ocr3types.ReportingPluginInfo{ + Name: "CCIPRoleCommit", + Limits: ocr3types.ReportingPluginLimits{ + MaxQueryLength: maxQueryLength, + MaxObservationLength: 20_000, // 20kB + MaxOutcomeLength: 10_000, // 10kB + MaxReportLength: 10_000, // 10kB + MaxReportCount: 10, + }, + }, nil +} + +func (p PluginFactory) Name() string { + panic("implement me") +} + +func (p PluginFactory) Start(ctx context.Context) error { + panic("implement me") +} + +func (p PluginFactory) Close() error { + panic("implement me") +} + +func (p PluginFactory) Ready() error { + panic("implement me") +} + +func (p PluginFactory) HealthReport() map[string]error { + panic("implement me") +} + +// Interface compatibility checks. +var _ core.OCR3ReportingPluginClient = &PluginFactoryConstructor{} +var _ core.OCR3ReportingPluginFactory = &PluginFactory{} diff --git a/commitrmnocb/metrics.go b/commitrmnocb/metrics.go new file mode 100644 index 000000000..ee93ab520 --- /dev/null +++ b/commitrmnocb/metrics.go @@ -0,0 +1 @@ +package commitrmnocb diff --git a/commitrmnocb/observation.go b/commitrmnocb/observation.go new file mode 100644 index 000000000..6068d54ee --- /dev/null +++ b/commitrmnocb/observation.go @@ -0,0 +1,221 @@ +package commitrmnocb + +import ( + "context" + "encoding/hex" + "fmt" + "sort" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/hashutil" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +func (p *Plugin) ObservationQuorum(_ ocr3types.OutcomeContext, _ types.Query) (ocr3types.Quorum, error) { + // Across all chains we require at least 2F+1 observations. + return ocr3types.QuorumTwoFPlusOne, nil +} + +func (p *Plugin) Observation( + ctx context.Context, outCtx ocr3types.OutcomeContext, _ types.Query, +) (types.Observation, error) { + previousOutcome, nextState := p.decodeOutcome(outCtx.PreviousOutcome) + + observation := Observation{} + switch nextState { + case SelectingRangesForReport: + offRampNextSeqNums := p.observer.ObserveOffRampNextSeqNums(ctx) + observation = Observation{ + // TODO: observe OnRamp max seq nums. The use of offRampNextSeqNums here effectively disables batching, + // e.g. the ranges selected for each chain will be [x, x] (e.g. [46, 46]), which means reports will only + // contain one message per chain. Querying the OnRamp contract requires changes to reader.CCIP, which will + // need to be done in a future change. + OnRampMaxSeqNums: offRampNextSeqNums, + OffRampNextSeqNums: offRampNextSeqNums, + FChain: p.observer.ObserveFChain(), + } + + case BuildingReport: + observation = Observation{ + MerkleRoots: p.observer.ObserveMerkleRoots(ctx, previousOutcome.RangesSelectedForReport), + GasPrices: p.observer.ObserveGasPrices(ctx), + TokenPrices: p.observer.ObserveTokenPrices(ctx), + FChain: p.observer.ObserveFChain(), + } + + case WaitingForReportTransmission: + observation = Observation{ + OffRampNextSeqNums: p.observer.ObserveOffRampNextSeqNums(ctx), + FChain: p.observer.ObserveFChain(), + } + + default: + p.lggr.Errorw("Unexpected state", "state", nextState) + return observation.Encode() + } + + p.lggr.Infow("Observation", "observation", observation) + return observation.Encode() +} + +type Observer interface { + // ObserveOffRampNextSeqNums observes the next sequence numbers for each source chain from the OffRamp + ObserveOffRampNextSeqNums(ctx context.Context) []plugintypes.SeqNumChain + + // ObserveMerkleRoots computes the merkle roots for the given sequence number ranges + ObserveMerkleRoots(ctx context.Context, ranges []plugintypes.ChainRange) []cciptypes.MerkleRootChain + + ObserveTokenPrices(ctx context.Context) []cciptypes.TokenPrice + + ObserveGasPrices(ctx context.Context) []cciptypes.GasPriceChain + + ObserveFChain() map[cciptypes.ChainSelector]int +} + +type ObserverImpl struct { + lggr logger.Logger + homeChain reader.HomeChain + nodeID commontypes.OracleID + chainSupport ChainSupport + ccipReader reader.CCIP + msgHasher cciptypes.MessageHasher +} + +// ObserveOffRampNextSeqNums observes the next sequence numbers for each source chain from the OffRamp +func (o ObserverImpl) ObserveOffRampNextSeqNums(ctx context.Context) []plugintypes.SeqNumChain { + supportsDestChain, err := o.chainSupport.SupportsDestChain(o.nodeID) + if err != nil { + o.lggr.Warnw("call to SupportsDestChain failed", "err", err) + return nil + } + + if !supportsDestChain { + return nil + } + + sourceChains, err := o.chainSupport.KnownSourceChainsSlice() + if err != nil { + o.lggr.Warnw("call to KnownSourceChainsSlice failed", "err", err) + return nil + } + offRampNextSeqNums, err := o.ccipReader.NextSeqNum(ctx, sourceChains) + if err != nil { + o.lggr.Warnw("call to NextSeqNum failed", "err", err) + return nil + } + + if len(offRampNextSeqNums) != len(sourceChains) { + o.lggr.Errorf("call to NextSeqNum returned unexpected number of seq nums, got %d, expected %d", + len(offRampNextSeqNums), len(sourceChains)) + return nil + } + + result := make([]plugintypes.SeqNumChain, len(sourceChains)) + for i := range sourceChains { + result[i] = plugintypes.SeqNumChain{ChainSel: sourceChains[i], SeqNum: offRampNextSeqNums[i]} + } + + return result +} + +// ObserveMerkleRoots computes the merkle roots for the given sequence number ranges +func (o ObserverImpl) ObserveMerkleRoots( + ctx context.Context, + ranges []plugintypes.ChainRange, +) []cciptypes.MerkleRootChain { + var roots []cciptypes.MerkleRootChain + supportedChains, err := o.chainSupport.SupportedChains(o.nodeID) + if err != nil { + o.lggr.Warnw("call to supportedChains failed", "err", err) + return nil + } + + for _, chainRange := range ranges { + if supportedChains.Contains(chainRange.ChainSel) { + msgs, err := o.ccipReader.MsgsBetweenSeqNums(ctx, chainRange.ChainSel, chainRange.SeqNumRange) + if err != nil { + o.lggr.Warnw("call to MsgsBetweenSeqNums failed", "err", err) + continue + } + root, err := o.computeMerkleRoot(ctx, msgs) + if err != nil { + o.lggr.Warnw("call to computeMerkleRoot failed", "err", err) + continue + } + merkleRoot := cciptypes.MerkleRootChain{ + ChainSel: chainRange.ChainSel, + SeqNumsRange: chainRange.SeqNumRange, + MerkleRoot: root, + } + roots = append(roots, merkleRoot) + } + } + + return roots +} + +// computeMerkleRoot computes the merkle root of a list of messages +func (o ObserverImpl) computeMerkleRoot(ctx context.Context, msgs []cciptypes.Message) (cciptypes.Bytes32, error) { + var hashes [][32]byte + sort.Slice(msgs, func(i, j int) bool { return msgs[i].Header.SequenceNumber < msgs[j].Header.SequenceNumber }) + + for i, msg := range msgs { + // Assert there are no sequence number gaps in msgs + if i > 0 { + if msg.Header.SequenceNumber != msgs[i-1].Header.SequenceNumber+1 { + return [32]byte{}, fmt.Errorf("found non-consecutive sequence numbers when computing merkle root, "+ + "gap between sequence nums %d and %d, messages: %v", msgs[i-1].Header.SequenceNumber, + msg.Header.SequenceNumber, msgs) + } + } + + msgHash, err := o.msgHasher.Hash(ctx, msg) + if err != nil { + msgID := hex.EncodeToString(msg.Header.MessageID[:]) + o.lggr.Warnw("failed to hash message", "msg", msg, "msg_id", msgID, "err", err) + return cciptypes.Bytes32{}, fmt.Errorf("failed to hash message with id %s: %w", msgID, err) + } + + hashes = append(hashes, msgHash) + } + + // TODO: Do not hard code the hash function, it should be derived from the message hasher + tree, err := merklemulti.NewTree(hashutil.NewKeccak(), hashes) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to construct merkle tree from %d leaves: %w", len(hashes), err) + } + + root := tree.Root() + o.lggr.Infow("computeMerkleRoot: Computed merkle root", "root", hex.EncodeToString(root[:])) + + return root, nil +} + +func (o ObserverImpl) ObserveTokenPrices(ctx context.Context) []cciptypes.TokenPrice { + return []cciptypes.TokenPrice{} +} + +func (o ObserverImpl) ObserveGasPrices(ctx context.Context) []cciptypes.GasPriceChain { + return []cciptypes.GasPriceChain{} +} + +func (o ObserverImpl) ObserveFChain() map[cciptypes.ChainSelector]int { + fChain, err := o.homeChain.GetFChain() + if err != nil { + // TODO: metrics + o.lggr.Warnw("call to GetFChain failed", "err", err) + return map[cciptypes.ChainSelector]int{} + } + return fChain +} + +// Interface compliance check +var _ Observer = (*ObserverImpl)(nil) diff --git a/commitrmnocb/observation_test.go b/commitrmnocb/observation_test.go new file mode 100644 index 000000000..50ec5b5c5 --- /dev/null +++ b/commitrmnocb/observation_test.go @@ -0,0 +1,592 @@ +package commitrmnocb + +import ( + "context" + "encoding/hex" + "fmt" + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/internal/mocks" + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +func Test_Observation(t *testing.T) { + merkleRoots := []cciptypes.MerkleRootChain{ + { + ChainSel: 1, + SeqNumsRange: [2]cciptypes.SeqNum{5, 78}, + MerkleRoot: [32]byte{1}, + }, + } + gasPrices := []cciptypes.GasPriceChain{ + { + GasPrice: cciptypes.NewBigIntFromInt64(99), + ChainSel: 8, + }, + } + tokenPrices := []cciptypes.TokenPrice{ + { + TokenID: "token23", + Price: cciptypes.NewBigIntFromInt64(80761), + }, + } + offRampNextSeqNums := []plugintypes.SeqNumChain{ + { + ChainSel: 456, + SeqNum: 9987, + }, + } + fChain := map[cciptypes.ChainSelector]int{ + 872: 3, + } + + testCases := []struct { + name string + previousOutcome Outcome + merkleRoots []cciptypes.MerkleRootChain + gasPrices []cciptypes.GasPriceChain + tokenPrices []cciptypes.TokenPrice + offRampNextSeqNums []plugintypes.SeqNumChain + fChain map[cciptypes.ChainSelector]int + expObs Observation + }{ + { + name: "SelectingRangesForReport observation", + previousOutcome: Outcome{ + OutcomeType: ReportTransmitted, + }, + merkleRoots: merkleRoots, + gasPrices: gasPrices, + tokenPrices: tokenPrices, + offRampNextSeqNums: offRampNextSeqNums, + fChain: fChain, + expObs: Observation{ + OnRampMaxSeqNums: offRampNextSeqNums, + OffRampNextSeqNums: offRampNextSeqNums, + FChain: fChain, + }, + }, + { + name: "BuildingReport observation", + previousOutcome: Outcome{ + OutcomeType: ReportIntervalsSelected, + }, + merkleRoots: merkleRoots, + gasPrices: gasPrices, + tokenPrices: tokenPrices, + offRampNextSeqNums: offRampNextSeqNums, + fChain: fChain, + expObs: Observation{ + MerkleRoots: merkleRoots, + GasPrices: gasPrices, + TokenPrices: tokenPrices, + FChain: fChain, + }, + }, + { + name: "WaitingForReportTransmission observation", + previousOutcome: Outcome{ + OutcomeType: ReportInFlight, + }, + merkleRoots: merkleRoots, + gasPrices: gasPrices, + tokenPrices: tokenPrices, + offRampNextSeqNums: offRampNextSeqNums, + fChain: fChain, + expObs: Observation{ + OffRampNextSeqNums: offRampNextSeqNums, + FChain: fChain, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + observer := mocks.NewObserver() + observer.On( + "ObserveOffRampNextSeqNums", ctx, + ).Return(tc.offRampNextSeqNums) + observer.On( + "ObserveMerkleRoots", ctx, mock.Anything, + ).Return(tc.merkleRoots) + observer.On( + "ObserveTokenPrices", ctx, + ).Return(tc.tokenPrices) + observer.On( + "ObserveGasPrices", ctx, + ).Return(tc.gasPrices) + observer.On("ObserveFChain").Return(tc.fChain) + + p := Plugin{ + lggr: logger.Test(t), + observer: observer, + } + + previousOutcomeEncoded, err := tc.previousOutcome.Encode() + assert.NoError(t, err) + + result, err := p.Observation( + ctx, + ocr3types.OutcomeContext{PreviousOutcome: previousOutcomeEncoded}, + types.Query{}, + ) + assert.NoError(t, err) + + actualObs, err := DecodeCommitPluginObservation(result) + assert.NoError(t, err) + + assert.Equal(t, tc.expObs, actualObs) + }) + } +} + +func Test_ObserveOffRampNextSeqNums(t *testing.T) { + testCases := []struct { + name string + supportsDestChain bool + supportsDestChainError error + knownSourceChains []cciptypes.ChainSelector + knownSourceChainsError error + nextSeqNums []cciptypes.SeqNum + nextSeqNumsError error + expResult []plugintypes.SeqNumChain + }{ + { + name: "Happy path", + supportsDestChain: true, + supportsDestChainError: nil, + knownSourceChains: []cciptypes.ChainSelector{4, 7, 19}, + knownSourceChainsError: nil, + nextSeqNums: []cciptypes.SeqNum{345, 608, 7713}, + nextSeqNumsError: nil, + expResult: []plugintypes.SeqNumChain{ + plugintypes.NewSeqNumChain(4, 345), + plugintypes.NewSeqNumChain(7, 608), + plugintypes.NewSeqNumChain(19, 7713), + }, + }, + { + name: "nil is returned when supportsDestChain is false", + supportsDestChain: false, + supportsDestChainError: nil, + knownSourceChains: []cciptypes.ChainSelector{4, 7, 19}, + knownSourceChainsError: nil, + nextSeqNums: []cciptypes.SeqNum{345, 608, 7713}, + nextSeqNumsError: nil, + expResult: nil, + }, + { + name: "nil is returned when supportsDestChain errors", + supportsDestChain: true, + supportsDestChainError: fmt.Errorf("error"), + knownSourceChains: []cciptypes.ChainSelector{4, 7, 19}, + knownSourceChainsError: nil, + nextSeqNums: []cciptypes.SeqNum{345, 608, 7713}, + nextSeqNumsError: nil, + expResult: nil, + }, + { + name: "nil is returned when knownSourceChains errors", + supportsDestChain: true, + supportsDestChainError: nil, + knownSourceChains: []cciptypes.ChainSelector{4, 7, 19}, + knownSourceChainsError: fmt.Errorf("error"), + nextSeqNums: []cciptypes.SeqNum{345, 608, 7713}, + nextSeqNumsError: nil, + expResult: nil, + }, + { + name: "nil is returned when nextSeqNums returns incorrect number of seq nums", + supportsDestChain: true, + supportsDestChainError: nil, + knownSourceChains: []cciptypes.ChainSelector{4, 7, 19}, + knownSourceChainsError: nil, + nextSeqNums: []cciptypes.SeqNum{345, 608}, + nextSeqNumsError: nil, + expResult: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + var nodeID commontypes.OracleID = 1 + reader := mocks.NewCCIPReader() + reader.On( + "NextSeqNum", ctx, tc.knownSourceChains, + ).Return(tc.nextSeqNums, tc.nextSeqNumsError) + + chainSupport := mocks.NewChainSupport() + chainSupport.On( + "SupportsDestChain", nodeID, + ).Return(tc.supportsDestChain, tc.supportsDestChainError) + chainSupport.On( + "KnownSourceChainsSlice", + ).Return(tc.knownSourceChains, tc.knownSourceChainsError) + + o := ObserverImpl{ + nodeID: nodeID, + lggr: logger.Test(t), + msgHasher: mocks.NewMessageHasher(), + ccipReader: reader, + chainSupport: chainSupport, + } + + assert.Equal(t, tc.expResult, o.ObserveOffRampNextSeqNums(ctx)) + }) + } +} + +func Test_ObserveMerkleRoots(t *testing.T) { + testCases := []struct { + name string + ranges []plugintypes.ChainRange + supportedChains mapset.Set[cciptypes.ChainSelector] + supportedChainsFails bool + msgsBetweenSeqNums map[cciptypes.ChainSelector][]cciptypes.Message + msgsBetweenSeqNumsErrors map[cciptypes.ChainSelector]error + expMerkleRoots map[cciptypes.ChainSelector]string + }{ + { + name: "Success single chain", + ranges: []plugintypes.ChainRange{ + { + ChainSel: 8, + SeqNumRange: cciptypes.SeqNumRange{10, 11}, + }, + }, + supportedChains: mapset.NewSet[cciptypes.ChainSelector](8), + supportedChainsFails: false, + msgsBetweenSeqNums: map[cciptypes.ChainSelector][]cciptypes.Message{ + 8: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 10}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1b"), + SequenceNumber: 11}, + }, + }, + }, + msgsBetweenSeqNumsErrors: map[cciptypes.ChainSelector]error{}, + expMerkleRoots: map[cciptypes.ChainSelector]string{ + 8: "5b81aaf37240df67f3ab0e845f30e29f35fdf9169e2517c436c1c0c11224c97b", + }, + }, + { + name: "Success multiple chains", + ranges: []plugintypes.ChainRange{ + { + ChainSel: 8, + SeqNumRange: cciptypes.SeqNumRange{10, 11}, + }, + { + ChainSel: 15, + SeqNumRange: cciptypes.SeqNumRange{53, 55}, + }, + }, + supportedChains: mapset.NewSet[cciptypes.ChainSelector](8, 15), + supportedChainsFails: false, + msgsBetweenSeqNums: map[cciptypes.ChainSelector][]cciptypes.Message{ + 8: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 10}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1b"), + SequenceNumber: 11}}, + }, + 15: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x2a"), + SequenceNumber: 53}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x2b"), + SequenceNumber: 54}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x2c"), + SequenceNumber: 55}}, + }, + }, + msgsBetweenSeqNumsErrors: map[cciptypes.ChainSelector]error{}, + expMerkleRoots: map[cciptypes.ChainSelector]string{ + 8: "5b81aaf37240df67f3ab0e845f30e29f35fdf9169e2517c436c1c0c11224c97b", + 15: "c7685b1be19745f244da890574cf554d75a3feeaf0e1181541c594d77ac1d6c4", + }, + }, + { + name: "Unsupported chain does not return a merkle root", + ranges: []plugintypes.ChainRange{ + { + ChainSel: 8, + SeqNumRange: cciptypes.SeqNumRange{10, 11}, + }, + { + // Unsupported chain + ChainSel: 12, + SeqNumRange: cciptypes.SeqNumRange{50, 60}, + }, + }, + supportedChains: mapset.NewSet[cciptypes.ChainSelector](8), + supportedChainsFails: false, + msgsBetweenSeqNums: map[cciptypes.ChainSelector][]cciptypes.Message{ + 8: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 10}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1b"), + SequenceNumber: 11}, + }, + }, + }, + msgsBetweenSeqNumsErrors: map[cciptypes.ChainSelector]error{}, + expMerkleRoots: map[cciptypes.ChainSelector]string{ + 8: "5b81aaf37240df67f3ab0e845f30e29f35fdf9169e2517c436c1c0c11224c97b", + }, + }, + { + name: "Call to supportedChains fails", + ranges: []plugintypes.ChainRange{ + { + ChainSel: 8, + SeqNumRange: cciptypes.SeqNumRange{10, 11}, + }, + }, + supportedChains: mapset.NewSet[cciptypes.ChainSelector](8), + supportedChainsFails: true, + msgsBetweenSeqNums: map[cciptypes.ChainSelector][]cciptypes.Message{ + 8: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 10}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1b"), + SequenceNumber: 11}, + }, + }, + }, + msgsBetweenSeqNumsErrors: map[cciptypes.ChainSelector]error{}, + expMerkleRoots: nil, + }, + { + name: "msgsBetweenSeqNums fails for a chain", + ranges: []plugintypes.ChainRange{ + { + ChainSel: 8, + SeqNumRange: cciptypes.SeqNumRange{10, 11}, + }, + { + ChainSel: 12, + SeqNumRange: cciptypes.SeqNumRange{50, 60}, + }, + }, + supportedChains: mapset.NewSet[cciptypes.ChainSelector](8), + supportedChainsFails: false, + msgsBetweenSeqNums: map[cciptypes.ChainSelector][]cciptypes.Message{ + 8: {{ + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 10}, + }, { + Header: cciptypes.RampMessageHeader{ + MessageID: mustNewMessageID("0x1b"), + SequenceNumber: 11}}, + }, + 12: {}, + }, + msgsBetweenSeqNumsErrors: map[cciptypes.ChainSelector]error{ + 12: fmt.Errorf("error"), + }, + expMerkleRoots: map[cciptypes.ChainSelector]string{ + 8: "5b81aaf37240df67f3ab0e845f30e29f35fdf9169e2517c436c1c0c11224c97b", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + var nodeID commontypes.OracleID = 1 + reader := mocks.NewCCIPReader() + for _, r := range tc.ranges { + var err error + if e, exists := tc.msgsBetweenSeqNumsErrors[r.ChainSel]; exists { + err = e + } + reader.On( + "MsgsBetweenSeqNums", ctx, r.ChainSel, r.SeqNumRange, + ).Return(tc.msgsBetweenSeqNums[r.ChainSel], err) + } + + chainSupport := mocks.NewChainSupport() + if tc.supportedChainsFails { + chainSupport.On("SupportedChains", nodeID).Return( + mapset.NewSet[cciptypes.ChainSelector](), fmt.Errorf("error"), + ) + } else { + chainSupport.On("SupportedChains", nodeID).Return(tc.supportedChains, nil) + } + + o := ObserverImpl{ + nodeID: nodeID, + lggr: logger.Test(t), + msgHasher: mocks.NewMessageHasher(), + ccipReader: reader, + chainSupport: chainSupport, + } + + roots := o.ObserveMerkleRoots(ctx, tc.ranges) + if tc.expMerkleRoots == nil { + assert.Nil(t, roots) + } else { + for _, root := range roots { + assert.Equal(t, tc.expMerkleRoots[root.ChainSel], hex.EncodeToString(root.MerkleRoot[:])) + } + } + }) + } +} + +func Test_computeMerkleRoot(t *testing.T) { + testCases := []struct { + name string + messageHeaders []cciptypes.RampMessageHeader + messageHasher cciptypes.MessageHasher + expMerkleRoot string + expErr bool + }{ + { + name: "Single message success", + messageHeaders: []cciptypes.RampMessageHeader{ + { + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 112, + }}, + messageHasher: mocks.NewMessageHasher(), + expMerkleRoot: "1a00000000000000000000000000000000000000000000000000000000000000", + expErr: false, + }, + { + name: "Multiple messages success", + messageHeaders: []cciptypes.RampMessageHeader{ + { + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 112, + }, + { + MessageID: mustNewMessageID("0x23"), + SequenceNumber: 113, + }, + { + MessageID: mustNewMessageID("0x87"), + SequenceNumber: 114, + }}, + messageHasher: mocks.NewMessageHasher(), + expMerkleRoot: "94c7e711e6f2acf41dca598ced55b6925e55aaed83520dc5ea6cbc054344564b", + expErr: false, + }, + { + name: "Sequence number gap", + messageHeaders: []cciptypes.RampMessageHeader{ + { + MessageID: mustNewMessageID("0x10"), + SequenceNumber: 34, + }, + { + MessageID: mustNewMessageID("0x12"), + SequenceNumber: 36, + }}, + messageHasher: mocks.NewMessageHasher(), + expMerkleRoot: "", + expErr: true, + }, + { + name: "Empty messages", + messageHeaders: []cciptypes.RampMessageHeader{}, + messageHasher: mocks.NewMessageHasher(), + expMerkleRoot: "", + expErr: true, + }, + { + name: "Bad hasher", + messageHeaders: []cciptypes.RampMessageHeader{ + { + MessageID: mustNewMessageID("0x1a"), + SequenceNumber: 112, + }, + { + MessageID: mustNewMessageID("0x23"), + SequenceNumber: 113, + }, + { + MessageID: mustNewMessageID("0x87"), + SequenceNumber: 114, + }}, + messageHasher: NewBadMessageHasher(), + expMerkleRoot: "94c7e711e6f2acf41dca598ced55b6925e55aaed83520dc5ea6cbc054344564b", + expErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := ObserverImpl{ + lggr: logger.Test(t), + msgHasher: tc.messageHasher, + } + + msgs := make([]cciptypes.Message, 0) + for _, h := range tc.messageHeaders { + msgs = append(msgs, cciptypes.Message{Header: h}) + } + + rootBytes, err := p.computeMerkleRoot(context.Background(), msgs) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + rootString := hex.EncodeToString(rootBytes[:]) + assert.Equal(t, tc.expMerkleRoot, rootString) + }) + } +} + +func mustNewMessageID(msgIDHex string) cciptypes.Bytes32 { + msgID, err := cciptypes.NewBytes32FromString(msgIDHex) + if err != nil { + panic(err) + } + return msgID +} + +type BadMessageHasher struct{} + +func NewBadMessageHasher() *BadMessageHasher { + return &BadMessageHasher{} +} + +// Always returns an error +func (m *BadMessageHasher) Hash(ctx context.Context, msg cciptypes.Message) (cciptypes.Bytes32, error) { + return cciptypes.Bytes32{}, fmt.Errorf("failed to hash") +} diff --git a/commitrmnocb/outcome.go b/commitrmnocb/outcome.go new file mode 100644 index 000000000..bea98f47c --- /dev/null +++ b/commitrmnocb/outcome.go @@ -0,0 +1,335 @@ +package commitrmnocb + +import ( + "fmt" + "sort" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "golang.org/x/exp/maps" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +// Outcome depending on the current state, either: +// - chooses the seq num ranges for the next round +// - builds a report +// - checks for the transmission of a previous report +func (p *Plugin) Outcome( + outCtx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, +) (ocr3types.Outcome, error) { + previousOutcome, nextState := p.decodeOutcome(outCtx.PreviousOutcome) + commitQuery := Query{} + + consensusObservation, err := getConsensusObservation(p.lggr, p.reportingCfg.F, p.cfg.DestChain, aos) + if err != nil { + return ocr3types.Outcome{}, err + } + + outcome := Outcome{} + + switch nextState { + case SelectingRangesForReport: + outcome = ReportRangesOutcome(commitQuery, consensusObservation) + + case BuildingReport: + outcome = buildReport(commitQuery, consensusObservation) + + case WaitingForReportTransmission: + outcome = checkForReportTransmission( + p.lggr, p.cfg.MaxReportTransmissionCheckAttempts, previousOutcome, consensusObservation) + + default: + p.lggr.Warnw("Unexpected state in Outcome", "state", nextState) + return outcome.Encode() + } + + p.lggr.Infow("Commit Plugin Outcome", "outcome", outcome, "oid", p.nodeID) + return outcome.Encode() +} + +// ReportRangesOutcome determines the sequence number ranges for each chain to build a report from in the next round +// TODO: ensure each range is below a limit +func ReportRangesOutcome( + query Query, + consensusObservation ConsensusObservation, +) Outcome { + rangesToReport := make([]plugintypes.ChainRange, 0) + + rmnOnRampMaxSeqNumsMap := make(map[cciptypes.ChainSelector]cciptypes.SeqNum) + for _, seqNumChain := range query.RmnOnRampMaxSeqNums { + rmnOnRampMaxSeqNumsMap[seqNumChain.ChainSel] = seqNumChain.SeqNum + } + + observedOnRampMaxSeqNumsMap := consensusObservation.OnRampMaxSeqNums + observedOffRampNextSeqNumsMap := consensusObservation.OffRampNextSeqNums + + for chainSel, offRampNextSeqNum := range observedOffRampNextSeqNumsMap { + onRampMaxSeqNum, exists := observedOnRampMaxSeqNumsMap[chainSel] + if !exists { + continue + } + + if rmnOnRampMaxSeqNum, exists := rmnOnRampMaxSeqNumsMap[chainSel]; exists { + onRampMaxSeqNum = min(onRampMaxSeqNum, rmnOnRampMaxSeqNum) + } + + if offRampNextSeqNum <= onRampMaxSeqNum { + chainRange := plugintypes.ChainRange{ + ChainSel: chainSel, + SeqNumRange: [2]cciptypes.SeqNum{offRampNextSeqNum, onRampMaxSeqNum}, + } + rangesToReport = append(rangesToReport, chainRange) + } + } + + outcome := Outcome{ + OutcomeType: ReportIntervalsSelected, + RangesSelectedForReport: rangesToReport, + } + + return outcome +} + +// Given a set of observed merkle roots, gas prices and token prices, and roots from RMN, construct a report +// to transmit on-chain +func buildReport( + _ Query, + consensusObservation ConsensusObservation, +) Outcome { + roots := maps.Values(consensusObservation.MerkleRoots) + + outcomeType := ReportGenerated + if len(roots) == 0 { + outcomeType = ReportEmpty + } + + outcome := Outcome{ + OutcomeType: outcomeType, + RootsToReport: roots, + GasPrices: consensusObservation.GasPricesArray(), + TokenPrices: consensusObservation.TokenPricesArray(), + } + + return outcome +} + +// checkForReportTransmission checks if the OffRamp has an updated set of max seq nums compared to the seq nums that +// were observed when the most recent report was generated. If an update to these max seq sums is detected, it means +// that the previous report has been transmitted, and we output ReportTransmitted to dictate that a new report +// generation phase should begin. If no update is detected, and we've exhausted our check attempts, output +// ReportTransmissionFailed to signify we stop checking for updates and start a new report generation phase. If no +// update is detected, and we haven't exhausted our check attempts, output ReportInFlight to signify that we should +// check again next round. +func checkForReportTransmission( + lggr logger.Logger, + maxReportTransmissionCheckAttempts uint, + previousOutcome Outcome, + consensusObservation ConsensusObservation, +) Outcome { + + offRampUpdated := false + for _, previousSeqNumChain := range previousOutcome.OffRampNextSeqNums { + if currentSeqNum, exists := consensusObservation.OffRampNextSeqNums[previousSeqNumChain.ChainSel]; exists { + if previousSeqNumChain.SeqNum != currentSeqNum { + offRampUpdated = true + break + } + } + } + + if offRampUpdated { + return Outcome{ + OutcomeType: ReportTransmitted, + } + } + + if previousOutcome.ReportTransmissionCheckAttempts+1 >= maxReportTransmissionCheckAttempts { + lggr.Warnw("Failed to detect report transmission") + return Outcome{ + OutcomeType: ReportTransmissionFailed, + } + } + + return Outcome{ + OutcomeType: ReportInFlight, + OffRampNextSeqNums: previousOutcome.OffRampNextSeqNums, + ReportTransmissionCheckAttempts: previousOutcome.ReportTransmissionCheckAttempts + 1, + } +} + +// getConsensusObservation Combine the list of observations into a single consensus observation +func getConsensusObservation( + lggr logger.Logger, + F int, + destChain cciptypes.ChainSelector, + aos []types.AttributedObservation, +) (ConsensusObservation, error) { + aggObs := aggregateObservations(aos) + fChains := fChainConsensus(lggr, F, aggObs.FChain) + + fDestChain, exists := fChains[destChain] + if !exists { + return ConsensusObservation{}, + fmt.Errorf("no consensus value for fDestChain, destChain: %d", destChain) + } + + consensusObs := ConsensusObservation{ + MerkleRoots: merkleRootConsensus(lggr, aggObs.MerkleRoots, fChains), + // TODO: use consensus of observed gas prices + GasPrices: make(map[cciptypes.ChainSelector]cciptypes.BigInt), + // TODO: use consensus of observed token prices + TokenPrices: make(map[types.Account]cciptypes.BigInt), + OnRampMaxSeqNums: onRampMaxSeqNumsConsensus(lggr, aggObs.OnRampMaxSeqNums, fChains), + OffRampNextSeqNums: offRampMaxSeqNumsConsensus(lggr, aggObs.OffRampNextSeqNums, fDestChain), + FChain: fChains, + } + + return consensusObs, nil +} + +// Given a mapping from chains to a list of merkle roots, return a mapping from chains to a single consensus merkle +// root. The consensus merkle root for a given chain is the merkle root with the most observations that was observed at +// least fChain times. +func merkleRootConsensus( + lggr logger.Logger, + rootsByChain map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain, + fChains map[cciptypes.ChainSelector]int, +) map[cciptypes.ChainSelector]cciptypes.MerkleRootChain { + consensus := make(map[cciptypes.ChainSelector]cciptypes.MerkleRootChain) + + for chain, roots := range rootsByChain { + if fChain, exists := fChains[chain]; exists { + root, count := mostFrequentElem(roots) + + if count <= fChain { + // TODO: metrics + lggr.Warnf("failed to reach consensus on a merkle root for chain %d "+ + "because no single merkle root was observed more than the expected fChain (%d) times, found "+ + "merkle root %d observed by only %d oracles, all observed merkle roots: %v", + chain, fChain, root, count, roots) + } else { + consensus[chain] = root + } + } else { + // TODO: metrics + lggr.Warnf("merkleRootConsensus: fChain not found for chain %d", chain) + } + } + + return consensus +} + +// Given a mapping from chains to a list of max seq nums on their corresponding OnRamp, return a mapping from chains +// to a single max seq num. The consensus max seq num for a given chain is the f'th lowest max seq num if the number +// of max seq num observations is greater or equal than 2f+1, where f is the FChain of the corresponding source chain. +func onRampMaxSeqNumsConsensus( + lggr logger.Logger, + onRampMaxSeqNumsByChain map[cciptypes.ChainSelector][]cciptypes.SeqNum, + fChains map[cciptypes.ChainSelector]int, +) map[cciptypes.ChainSelector]cciptypes.SeqNum { + consensus := make(map[cciptypes.ChainSelector]cciptypes.SeqNum) + + for chain, onRampMaxSeqNums := range onRampMaxSeqNumsByChain { + if fChain, exists := fChains[chain]; exists { + if len(onRampMaxSeqNums) < 2*fChain+1 { + // TODO: metrics + lggr.Warnf("could not reach consensus on onRampMaxSeqNums for chain %d "+ + "because we did not receive more than 2fChain+1 observed sequence numbers, 2fChain+1: %d, "+ + "len(onRampMaxSeqNums): %d, onRampMaxSeqNums: %v", + chain, 2*fChain+1, len(onRampMaxSeqNums), onRampMaxSeqNums) + } else { + sort.Slice(onRampMaxSeqNums, func(i, j int) bool { return onRampMaxSeqNums[i] < onRampMaxSeqNums[j] }) + consensus[chain] = onRampMaxSeqNums[fChain] + } + } else { + // TODO: metrics + lggr.Warnf("could not reach consensus on onRampMaxSeqNums for chain %d "+ + "because there was no consensus fChain value for this chain", chain) + } + } + + return consensus +} + +// Given a mapping from chains to a list of max seq nums on the OffRamp, return a mapping from chains +// to a single max seq num. The consensus max seq num for a given chain is the max seq num with the most observations +// that was observed at least f times, where f is the FChain of the dest chain. +func offRampMaxSeqNumsConsensus( + lggr logger.Logger, + offRampMaxSeqNumsByChain map[cciptypes.ChainSelector][]cciptypes.SeqNum, + fDestChain int, +) map[cciptypes.ChainSelector]cciptypes.SeqNum { + consensus := make(map[cciptypes.ChainSelector]cciptypes.SeqNum) + + for chain, offRampMaxSeqNums := range offRampMaxSeqNumsByChain { + seqNum, count := mostFrequentElem(offRampMaxSeqNums) + if count <= fDestChain { + // TODO: metrics + lggr.Warnf("could not reach consensus on offRampMaxSeqNums for chain %d "+ + "because we did not receive a sequence number that was observed by at least fChain (%d) oracles, "+ + "offRampMaxSeqNums: %v", chain, fDestChain, offRampMaxSeqNums) + } else { + consensus[chain] = seqNum + } + } + + return consensus +} + +// Given a mapping from chains to a list of FChain values for each chain, return a mapping from chains +// to a single FChain. The consensus FChain for a given chain is the FChain with the most observations +// that was observed at least f times, where f is the F of the DON (p.reportingCfg.F). +func fChainConsensus( + lggr logger.Logger, + F int, + fChainValues map[cciptypes.ChainSelector][]int, +) map[cciptypes.ChainSelector]int { + consensus := make(map[cciptypes.ChainSelector]int) + + for chain, fValues := range fChainValues { + fChain, count := mostFrequentElem(fValues) + if count < F { + // TODO: metrics + lggr.Warnf("failed to reach consensus on fChain values for chain %d because no single fChain "+ + "value was observed more than the expected %d times, found fChain value %d observed by only %d oracles, "+ + "fChain values: %v", + chain, F, fChain, count, fValues) + } + + consensus[chain] = fChain + } + + return consensus +} + +// Given a list of elems, return the elem that occurs most frequently and how often it occurs +func mostFrequentElem[T comparable](elems []T) (T, int) { + var mostFrequentElem T + + counts := counts(elems) + maxCount := 0 + + for elem, count := range counts { + if count > maxCount { + mostFrequentElem = elem + maxCount = count + } + } + + return mostFrequentElem, maxCount +} + +// Given a list of elems, return a map from elems to how often they occur in the given list +func counts[T comparable](elems []T) map[T]int { + m := make(map[T]int) + for _, elem := range elems { + m[elem]++ + } + + return m +} diff --git a/commitrmnocb/plugin.go b/commitrmnocb/plugin.go new file mode 100644 index 000000000..0f234008b --- /dev/null +++ b/commitrmnocb/plugin.go @@ -0,0 +1,137 @@ +package commitrmnocb + +import ( + "context" + "fmt" + "time" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" + "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/pluginconfig" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +type Plugin struct { + nodeID commontypes.OracleID + oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID + cfg pluginconfig.CommitPluginConfig + ccipReader reader.CCIP + readerSyncer *plugincommon.BackgroundReaderSyncer + tokenPricesReader reader.TokenPrices + reportCodec cciptypes.CommitPluginCodec + lggr logger.Logger + homeChain reader.HomeChain + reportingCfg ocr3types.ReportingPluginConfig + chainSupport ChainSupport + observer Observer +} + +func NewPlugin( + _ context.Context, + nodeID commontypes.OracleID, + oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID, + cfg pluginconfig.CommitPluginConfig, + ccipReader reader.CCIP, + tokenPricesReader reader.TokenPrices, + reportCodec cciptypes.CommitPluginCodec, + msgHasher cciptypes.MessageHasher, + lggr logger.Logger, + homeChain reader.HomeChain, + reportingCfg ocr3types.ReportingPluginConfig, +) *Plugin { + readerSyncer := plugincommon.NewBackgroundReaderSyncer( + lggr, + ccipReader, + syncTimeout(cfg.SyncTimeout), + syncFrequency(cfg.SyncFrequency), + ) + if err := readerSyncer.Start(context.Background()); err != nil { + lggr.Errorw("error starting background reader syncer", "err", err) + } + + chainSupport := CCIPChainSupport{ + lggr: lggr, + homeChain: homeChain, + oracleIDToP2pID: oracleIDToP2pID, + nodeID: nodeID, + destChain: cfg.DestChain, + } + + observer := ObserverImpl{ + lggr: lggr, + homeChain: homeChain, + nodeID: nodeID, + chainSupport: chainSupport, + ccipReader: ccipReader, + msgHasher: msgHasher, + } + + return &Plugin{ + nodeID: nodeID, + oracleIDToP2pID: oracleIDToP2pID, + lggr: lggr, + cfg: cfg, + tokenPricesReader: tokenPricesReader, + ccipReader: ccipReader, + homeChain: homeChain, + readerSyncer: readerSyncer, + reportCodec: reportCodec, + reportingCfg: reportingCfg, + chainSupport: chainSupport, + observer: observer, + } +} + +func (p *Plugin) Close() error { + timeout := 10 * time.Second + ctx, cf := context.WithTimeout(context.Background(), timeout) + defer cf() + + if err := p.readerSyncer.Close(); err != nil { + p.lggr.Errorw("error closing reader syncer", "err", err) + } + + if err := p.ccipReader.Close(ctx); err != nil { + return fmt.Errorf("close ccip reader: %w", err) + } + + return nil +} + +func (p *Plugin) decodeOutcome(outcome ocr3types.Outcome) (Outcome, State) { + if len(outcome) == 0 { + return Outcome{}, SelectingRangesForReport + } + + decodedOutcome, err := DecodeOutcome(outcome) + if err != nil { + p.lggr.Errorw("Failed to decode Outcome", "outcome", outcome, "err", err) + return Outcome{}, SelectingRangesForReport + } + + return decodedOutcome, decodedOutcome.NextState() +} + +func syncFrequency(configuredValue time.Duration) time.Duration { + if configuredValue.Milliseconds() == 0 { + return 10 * time.Second + } + return configuredValue +} + +func syncTimeout(configuredValue time.Duration) time.Duration { + if configuredValue.Milliseconds() == 0 { + return 3 * time.Second + } + return configuredValue +} + +// Interface compatibility checks. +var _ ocr3types.ReportingPlugin[[]byte] = &Plugin{} diff --git a/commitrmnocb/query.go b/commitrmnocb/query.go new file mode 100644 index 000000000..471e0e228 --- /dev/null +++ b/commitrmnocb/query.go @@ -0,0 +1,12 @@ +package commitrmnocb + +import ( + "context" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" +) + +func (p *Plugin) Query(_ context.Context, outCtx ocr3types.OutcomeContext) (types.Query, error) { + return types.Query{}, nil +} diff --git a/commitrmnocb/report.go b/commitrmnocb/report.go new file mode 100644 index 000000000..4c424fb01 --- /dev/null +++ b/commitrmnocb/report.go @@ -0,0 +1,86 @@ +package commitrmnocb + +import ( + "context" + "encoding/hex" + "fmt" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/commit" +) + +func (p *Plugin) Reports(seqNr uint64, outcomeBytes ocr3types.Outcome) ([]ocr3types.ReportWithInfo[[]byte], error) { + outcome, err := DecodeOutcome(outcomeBytes) + if err != nil { + // TODO: metrics + p.lggr.Errorw("failed to decode Outcome", "outcomeBytes", outcomeBytes, "err", err) + return nil, fmt.Errorf("failed to decode Outcome (%s): %w", hex.EncodeToString(outcomeBytes), err) + } + + // Reports are only generated from "ReportGenerated" outcomes + if outcome.OutcomeType != ReportGenerated { + return []ocr3types.ReportWithInfo[[]byte]{}, nil + } + + rep := cciptypes.NewCommitPluginReport(outcome.RootsToReport, outcome.TokenPrices, outcome.GasPrices) + + encodedReport, err := p.reportCodec.Encode(context.Background(), rep) + if err != nil { + return nil, fmt.Errorf("encode commit plugin report: %w", err) + } + + return []ocr3types.ReportWithInfo[[]byte]{{Report: encodedReport, Info: nil}}, nil +} + +func (p *Plugin) ShouldAcceptAttestedReport( + ctx context.Context, u uint64, r ocr3types.ReportWithInfo[[]byte], +) (bool, error) { + decodedReport, err := p.reportCodec.Decode(ctx, r.Report) + if err != nil { + return false, fmt.Errorf("decode commit plugin report: %w", err) + } + + isEmpty := decodedReport.IsEmpty() + if isEmpty { + p.lggr.Infow("skipping empty report") + return false, nil + } + + return true, nil +} + +func (p *Plugin) ShouldTransmitAcceptedReport( + ctx context.Context, u uint64, r ocr3types.ReportWithInfo[[]byte], +) (bool, error) { + isWriter, err := p.chainSupport.SupportsDestChain(p.nodeID) + if err != nil { + return false, fmt.Errorf("can't know if it's a writer: %w", err) + } + if !isWriter { + p.lggr.Infow("not a writer, skipping report transmission") + return false, nil + } + + decodedReport, err := p.reportCodec.Decode(ctx, r.Report) + if err != nil { + return false, fmt.Errorf("decode commit plugin report: %w", err) + } + + isValid, err := commit.ValidateMerkleRootsState(ctx, p.lggr, decodedReport, p.ccipReader) + if !isValid { + return false, nil + } + if err != nil { + return false, fmt.Errorf("validate merkle roots state: %w", err) + } + + p.lggr.Infow("transmitting report", + "roots", len(decodedReport.MerkleRoots), + "tokenPriceUpdates", len(decodedReport.PriceUpdates.TokenPriceUpdates), + "gasPriceUpdates", len(decodedReport.PriceUpdates.GasPriceUpdates), + ) + return true, nil +} diff --git a/commitrmnocb/types.go b/commitrmnocb/types.go new file mode 100644 index 000000000..d23fb2255 --- /dev/null +++ b/commitrmnocb/types.go @@ -0,0 +1,265 @@ +package commitrmnocb + +import ( + "encoding/json" + "fmt" + "sort" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-ccip/plugintypes" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +type Query struct { + RmnOnRampMaxSeqNums []plugintypes.SeqNumChain + MerkleRoots []cciptypes.MerkleRootChain +} + +func (q Query) Encode() ([]byte, error) { + return json.Marshal(q) +} + +func DecodeCommitPluginQuery(encodedQuery []byte) (Query, error) { + q := Query{} + err := json.Unmarshal(encodedQuery, &q) + return q, err +} + +func NewCommitQuery(rmnOnRampMaxSeqNums []plugintypes.SeqNumChain, merkleRoots []cciptypes.MerkleRootChain) Query { + return Query{ + RmnOnRampMaxSeqNums: rmnOnRampMaxSeqNums, + MerkleRoots: merkleRoots, + } +} + +type Observation struct { + MerkleRoots []cciptypes.MerkleRootChain `json:"merkleRoots"` + GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` + TokenPrices []cciptypes.TokenPrice `json:"tokenPrices"` + OnRampMaxSeqNums []plugintypes.SeqNumChain `json:"onRampMaxSeqNums"` + OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` + FChain map[cciptypes.ChainSelector]int `json:"fChain"` +} + +func (obs Observation) Encode() ([]byte, error) { + encodedObservation, err := json.Marshal(obs) + if err != nil { + return nil, fmt.Errorf("failed to encode Observation: %w", err) + } + + return encodedObservation, nil +} + +func DecodeCommitPluginObservation(encodedObservation []byte) (Observation, error) { + o := Observation{} + err := json.Unmarshal(encodedObservation, &o) + return o, err +} + +// AggregatedObservation is the aggregation of a list of observations +type AggregatedObservation struct { + // A map from chain selectors to the list of merkle roots observed for each chain + MerkleRoots map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain + + // A map from chain selectors to the list of gas prices observed for each chain + GasPrices map[cciptypes.ChainSelector][]cciptypes.BigInt + + // A map from token IDs to the list of prices observed for each token + TokenPrices map[types.Account][]cciptypes.BigInt + + // A map from chain selectors to the list of OnRamp max sequence numbers observed for each chain + OnRampMaxSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum + + // A map from chain selectors to the list of OffRamp next sequence numbers observed for each chain + OffRampNextSeqNums map[cciptypes.ChainSelector][]cciptypes.SeqNum + + // A map from chain selectors to the list of f (failure tolerance) observed for each chain + FChain map[cciptypes.ChainSelector][]int +} + +// aggregateObservations takes a list of observations and produces an AggregatedObservation +func aggregateObservations(aos []types.AttributedObservation) AggregatedObservation { + aggObs := AggregatedObservation{ + MerkleRoots: make(map[cciptypes.ChainSelector][]cciptypes.MerkleRootChain), + GasPrices: make(map[cciptypes.ChainSelector][]cciptypes.BigInt), + TokenPrices: make(map[types.Account][]cciptypes.BigInt), + OnRampMaxSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), + OffRampNextSeqNums: make(map[cciptypes.ChainSelector][]cciptypes.SeqNum), + FChain: make(map[cciptypes.ChainSelector][]int), + } + + for _, ao := range aos { + obs, err := DecodeCommitPluginObservation(ao.Observation) + if err != nil { + // TODO: lggr + continue + } + + // MerkleRoots + for _, merkleRoot := range obs.MerkleRoots { + aggObs.MerkleRoots[merkleRoot.ChainSel] = + append(aggObs.MerkleRoots[merkleRoot.ChainSel], merkleRoot) + } + + // GasPrices + for _, gasPriceChain := range obs.GasPrices { + aggObs.GasPrices[gasPriceChain.ChainSel] = + append(aggObs.GasPrices[gasPriceChain.ChainSel], gasPriceChain.GasPrice) + } + + // TokenPrices + for _, tokenPrice := range obs.TokenPrices { + aggObs.TokenPrices[tokenPrice.TokenID] = + append(aggObs.TokenPrices[tokenPrice.TokenID], tokenPrice.Price) + } + + // OnRampMaxSeqNums + for _, seqNumChain := range obs.OnRampMaxSeqNums { + aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel] = + append(aggObs.OnRampMaxSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) + } + + // OffRampNextSeqNums + for _, seqNumChain := range obs.OffRampNextSeqNums { + aggObs.OffRampNextSeqNums[seqNumChain.ChainSel] = + append(aggObs.OffRampNextSeqNums[seqNumChain.ChainSel], seqNumChain.SeqNum) + } + + // FChain + for chainSel, f := range obs.FChain { + aggObs.FChain[chainSel] = append(aggObs.FChain[chainSel], f) + } + } + + return aggObs +} + +// ConsensusObservation holds the consensus values for all chains across all observations in a round +type ConsensusObservation struct { + // A map from chain selectors to each chain's consensus merkle root + MerkleRoots map[cciptypes.ChainSelector]cciptypes.MerkleRootChain + + // A map from chain selectors to each chain's consensus gas prices + GasPrices map[cciptypes.ChainSelector]cciptypes.BigInt + + // A map from token IDs to each token's consensus price + TokenPrices map[types.Account]cciptypes.BigInt + + // A map from chain selectors to each chain's consensus OnRamp max sequence number + OnRampMaxSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum + + // A map from chain selectors to each chain's consensus OffRamp next sequence number + OffRampNextSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum + + // A map from chain selectors to each chain's consensus f (failure tolerance) + FChain map[cciptypes.ChainSelector]int +} + +// GasPricesArray returns a list of gas prices +func (co ConsensusObservation) GasPricesArray() []cciptypes.GasPriceChain { + gasPrices := make([]cciptypes.GasPriceChain, 0, len(co.GasPrices)) + for chain, gasPrice := range co.GasPrices { + gasPrices = append(gasPrices, cciptypes.NewGasPriceChain(gasPrice.Int, chain)) + } + + return gasPrices +} + +// TokenPricesArray returns a list of token prices +func (co ConsensusObservation) TokenPricesArray() []cciptypes.TokenPrice { + tokenPrices := make([]cciptypes.TokenPrice, 0, len(co.TokenPrices)) + for tokenID, tokenPrice := range co.TokenPrices { + tokenPrices = append(tokenPrices, cciptypes.NewTokenPrice(tokenID, tokenPrice.Int)) + } + + return tokenPrices +} + +type OutcomeType int + +const ( + ReportIntervalsSelected OutcomeType = iota + 1 + ReportGenerated + ReportEmpty + ReportInFlight + ReportTransmitted + ReportTransmissionFailed +) + +type Outcome struct { + OutcomeType OutcomeType `json:"outcomeType"` + RangesSelectedForReport []plugintypes.ChainRange `json:"rangesSelectedForReport"` + RootsToReport []cciptypes.MerkleRootChain `json:"rootsToReport"` + OffRampNextSeqNums []plugintypes.SeqNumChain `json:"offRampNextSeqNums"` + TokenPrices []cciptypes.TokenPrice `json:"tokenPrices"` + GasPrices []cciptypes.GasPriceChain `json:"gasPrices"` + ReportTransmissionCheckAttempts uint `json:"reportTransmissionCheckAttempts"` +} + +// Sort all fields of the given Outcome +func (o Outcome) sort() { + sort.Slice(o.RangesSelectedForReport, func(i, j int) bool { + return o.RangesSelectedForReport[i].ChainSel < o.RangesSelectedForReport[j].ChainSel + }) + sort.Slice(o.RootsToReport, func(i, j int) bool { + return o.RootsToReport[i].ChainSel < o.RootsToReport[j].ChainSel + }) + sort.Slice(o.OffRampNextSeqNums, func(i, j int) bool { + return o.OffRampNextSeqNums[i].ChainSel < o.OffRampNextSeqNums[j].ChainSel + }) + sort.Slice(o.TokenPrices, func(i, j int) bool { + return o.TokenPrices[i].TokenID < o.TokenPrices[j].TokenID + }) + sort.Slice(o.GasPrices, func(i, j int) bool { + return o.GasPrices[i].ChainSel < o.GasPrices[j].ChainSel + }) +} + +// Encode encodes an Outcome deterministically +func (o Outcome) Encode() ([]byte, error) { + + // Sort all lists to ensure deterministic serialization + o.sort() + + encodedOutcome, err := json.Marshal(o) + if err != nil { + return nil, fmt.Errorf("failed to encode Outcome: %w", err) + } + + return encodedOutcome, nil +} + +func DecodeOutcome(b []byte) (Outcome, error) { + o := Outcome{} + err := json.Unmarshal(b, &o) + return o, err +} + +func (o Outcome) NextState() State { + switch o.OutcomeType { + case ReportIntervalsSelected: + return BuildingReport + case ReportGenerated: + return WaitingForReportTransmission + case ReportEmpty: + return SelectingRangesForReport + case ReportInFlight: + return WaitingForReportTransmission + case ReportTransmitted: + return SelectingRangesForReport + case ReportTransmissionFailed: + return SelectingRangesForReport + default: + return SelectingRangesForReport + } +} + +type State int + +const ( + SelectingRangesForReport State = iota + 1 + BuildingReport + WaitingForReportTransmission +) diff --git a/commitrmnocb/validate_observation.go b/commitrmnocb/validate_observation.go new file mode 100644 index 000000000..7d7181ebb --- /dev/null +++ b/commitrmnocb/validate_observation.go @@ -0,0 +1,146 @@ +package commitrmnocb + +import ( + "fmt" + + mapset "github.com/deckarep/golang-set/v2" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-ccip/commit" + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +// ValidateObservation validates an observation to ensure it is well-formed +func (p *Plugin) ValidateObservation(_ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error { + obs, err := DecodeCommitPluginObservation(ao.Observation) + if err != nil { + return fmt.Errorf("failed to decode commit plugin observation: %w", err) + } + + if err := validateFChain(obs.FChain); err != nil { + return fmt.Errorf("failed to validate FChain: %w", err) + } + + observerSupportedChains, err := p.chainSupport.SupportedChains(ao.Observer) + if err != nil { + return fmt.Errorf("failed to get supported chains: %w", err) + } + + supportsDestChain, err := p.chainSupport.SupportsDestChain(ao.Observer) + if err != nil { + return fmt.Errorf("call to supportsDestChain failed: %w", err) + } + + if err := validateObservedMerkleRoots(obs.MerkleRoots, ao.Observer, observerSupportedChains); err != nil { + return fmt.Errorf("failed to validate MerkleRoots: %w", err) + } + + if err := validateObservedOnRampMaxSeqNums(obs.OnRampMaxSeqNums, ao.Observer, observerSupportedChains); err != nil { + return fmt.Errorf("failed to validate OnRampMaxSeqNums: %w", err) + } + + if err := validateObservedOffRampMaxSeqNums(obs.OffRampNextSeqNums, ao.Observer, supportsDestChain); err != nil { + return fmt.Errorf("failed to validate OffRampNextSeqNums: %w", err) + } + + if err := commit.ValidateObservedTokenPrices(obs.TokenPrices); err != nil { + return fmt.Errorf("failed to validate token prices: %w", err) + } + + if err := commit.ValidateObservedGasPrices(obs.GasPrices); err != nil { + return fmt.Errorf("failed to validate gas prices: %w", err) + } + + return nil +} + +func validateObservedMerkleRoots( + merkleRoots []cciptypes.MerkleRootChain, + observer commontypes.OracleID, + observerSupportedChains mapset.Set[cciptypes.ChainSelector], +) error { + if len(merkleRoots) == 0 { + return nil + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, root := range merkleRoots { + if !observerSupportedChains.Contains(root.ChainSel) { + return fmt.Errorf("found merkle root for chain %d, but this chain is not supported by Observer %d", + root.ChainSel, observer) + } + + if seenChains.Contains(root.ChainSel) { + return fmt.Errorf("duplicate merkle root for chain %d", root.ChainSel) + } + seenChains.Add(root.ChainSel) + } + + return nil +} + +func validateObservedOnRampMaxSeqNums( + onRampMaxSeqNums []plugintypes.SeqNumChain, + observer commontypes.OracleID, + observerSupportedChains mapset.Set[cciptypes.ChainSelector], +) error { + if len(onRampMaxSeqNums) == 0 { + return nil + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, seqNumChain := range onRampMaxSeqNums { + if !observerSupportedChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("found onRampMaxSeqNum for chain %d, but this chain is not supported by Observer %d, "+ + "observerSupportedChains: %v, onRampMaxSeqNums: %v", + seqNumChain.ChainSel, observer, observerSupportedChains, onRampMaxSeqNums) + } + + if seenChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("duplicate onRampMaxSeqNum for chain %d", seqNumChain.ChainSel) + } + seenChains.Add(seqNumChain.ChainSel) + } + + return nil +} + +func validateObservedOffRampMaxSeqNums( + offRampMaxSeqNums []plugintypes.SeqNumChain, + observer commontypes.OracleID, + supportsDestChain bool, +) error { + if len(offRampMaxSeqNums) == 0 { + return nil + } + + if !supportsDestChain { + return fmt.Errorf("observer %d does not support dest chain, but has observed %d offRampMaxSeqNums", + observer, len(offRampMaxSeqNums)) + } + + seenChains := mapset.NewSet[cciptypes.ChainSelector]() + for _, seqNumChain := range offRampMaxSeqNums { + if seenChains.Contains(seqNumChain.ChainSel) { + return fmt.Errorf("duplicate offRampMaxSeqNum for chain %d", seqNumChain.ChainSel) + } + seenChains.Add(seqNumChain.ChainSel) + } + + return nil +} + +func validateFChain(fChain map[cciptypes.ChainSelector]int) error { + for _, f := range fChain { + if f < 0 { + return fmt.Errorf("fChain %d is negative", f) + } + } + + return nil +} diff --git a/commitrmnocb/validate_observation_test.go b/commitrmnocb/validate_observation_test.go new file mode 100644 index 000000000..68d3d6249 --- /dev/null +++ b/commitrmnocb/validate_observation_test.go @@ -0,0 +1,212 @@ +package commitrmnocb + +import ( + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/stretchr/testify/assert" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +func Test_validateObservedMerkleRoots(t *testing.T) { + testCases := []struct { + name string + merkleRoots []cciptypes.MerkleRootChain + observer commontypes.OracleID + observerSupportedChains mapset.Set[cciptypes.ChainSelector] + expErr bool + }{ + { + name: "Chain not supported", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), + expErr: true, + }, + { + name: "Duplicate chains", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{3, 7}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + merkleRoots: []cciptypes.MerkleRootChain{ + {ChainSel: 1, SeqNumsRange: [2]cciptypes.SeqNum{10, 20}, MerkleRoot: [32]byte{1, 2, 3}}, + {ChainSel: 2, SeqNumsRange: [2]cciptypes.SeqNum{24, 45}, MerkleRoot: [32]byte{1, 2, 3}}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedMerkleRoots(tc.merkleRoots, tc.observer, tc.observerSupportedChains) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateObservedOnRampMaxSeqNums(t *testing.T) { + testCases := []struct { + name string + onRampMaxSeqNums []plugintypes.SeqNumChain + observer commontypes.OracleID + observerSupportedChains mapset.Set[cciptypes.ChainSelector] + expErr bool + }{ + { + name: "Chain not supported", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 3, 4, 5), + expErr: true, + }, + { + name: "Duplicate chains", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + {ChainSel: 2, SeqNum: 33}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + onRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + observerSupportedChains: mapset.NewSet[cciptypes.ChainSelector](1, 2), + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedOnRampMaxSeqNums(tc.onRampMaxSeqNums, tc.observer, tc.observerSupportedChains) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateObservedOffRampMaxSeqNums(t *testing.T) { + testCases := []struct { + name string + offRampMaxSeqNums []plugintypes.SeqNumChain + observer commontypes.OracleID + supportsDestChain bool + expErr bool + }{ + { + name: "Dest chain not supported", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + supportsDestChain: false, + expErr: true, + }, + { + name: "Duplicate chains", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + {ChainSel: 2, SeqNum: 33}, + }, + observer: 10, + supportsDestChain: false, + expErr: true, + }, + { + name: "Valid offRampMaxSeqNums", + offRampMaxSeqNums: []plugintypes.SeqNumChain{ + {ChainSel: 1, SeqNum: 10}, + {ChainSel: 2, SeqNum: 20}, + }, + observer: 10, + supportsDestChain: true, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateObservedOffRampMaxSeqNums(tc.offRampMaxSeqNums, tc.observer, tc.supportsDestChain) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func Test_validateFChain(t *testing.T) { + testCases := []struct { + name string + fChain map[cciptypes.ChainSelector]int + expErr bool + }{ + { + name: "FChain contains negative values", + fChain: map[cciptypes.ChainSelector]int{ + 1: 11, + 2: -4, + }, + expErr: true, + }, + { + name: "FChain valid", + fChain: map[cciptypes.ChainSelector]int{ + 12: 6, + 7: 9, + }, + expErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateFChain(tc.fChain) + + if tc.expErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/execute/exectypes/outcome.go b/execute/exectypes/outcome.go index 6c39ad107..20dda0053 100644 --- a/execute/exectypes/outcome.go +++ b/execute/exectypes/outcome.go @@ -7,8 +7,47 @@ import ( cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) +type PluginState string + +const ( + // Unknown is the zero value, this allows a "Next" state transition for uninitialized values (i.e. the first round). + Unknown PluginState = "" + + // GetCommitReports is the first step, it is used to select commit reports from the destination chain. + GetCommitReports PluginState = "GetCommitReports" + + // GetMessages is the second step, given a set of commit reports it fetches the associated messages. + GetMessages PluginState = "GetMessages" + + // Filter is the final step, any additional destination data is collected to complete the execution report. + Filter PluginState = "Filter" +) + +// Next returns the next state for the plugin. The Unknown state is used to transition from uninitialized values. +func (p PluginState) Next() PluginState { + switch p { + case GetCommitReports: + return GetMessages + + case GetMessages: + // TODO: go to Filter after GetMessages + return GetCommitReports + + case Unknown: + fallthrough + case Filter: + return GetCommitReports + + default: + panic("unexpected execute plugin state") + } +} + // Outcome is the outcome of the ExecutePlugin. type Outcome struct { + // State that the outcome was generated for. + State PluginState + // PendingCommitReports are the oldest reports with pending commits. The slice is // sorted from oldest to newest. PendingCommitReports []CommitData `json:"commitReports"` @@ -17,28 +56,27 @@ type Outcome struct { Report cciptypes.ExecutePluginReport `json:"report"` } +// IsEmpty returns true if the outcome has no pending commit reports or chain reports. func (o Outcome) IsEmpty() bool { return len(o.PendingCommitReports) == 0 && len(o.Report.ChainReports) == 0 } +// NewOutcome creates a new Outcome with the pending commit reports and the chain reports sorted. func NewOutcome( + state PluginState, pendingCommits []CommitData, report cciptypes.ExecutePluginReport, ) Outcome { - return newSortedOutcome(pendingCommits, report) -} - -// Encode encodes the outcome by first sorting the pending commit reports and the chain reports -// and then JSON marshalling. -// The encoding MUST be deterministic. -func (o Outcome) Encode() ([]byte, error) { - // We sort again here in case construction is not via the constructor. - return json.Marshal(newSortedOutcome(o.PendingCommitReports, o.Report)) + return newSortedOutcome(state, pendingCommits, report) } +// newSortedOutcome ensures canonical ordering of the outcome. +// TODO: handle canonicalization in the encoder. func newSortedOutcome( + state PluginState, pendingCommits []CommitData, - report cciptypes.ExecutePluginReport) Outcome { + report cciptypes.ExecutePluginReport, +) Outcome { pendingCommitsCP := append([]CommitData{}, pendingCommits...) reportCP := append([]cciptypes.ExecutePluginReportSingleChain{}, report.ChainReports...) sort.Slice( @@ -52,11 +90,21 @@ func newSortedOutcome( return reportCP[i].SourceChainSelector < reportCP[j].SourceChainSelector }) return Outcome{ + State: state, PendingCommitReports: pendingCommitsCP, Report: cciptypes.ExecutePluginReport{ChainReports: reportCP}, } } +// Encode encodes the outcome by first sorting the pending commit reports and the chain reports +// and then JSON marshalling. +// The encoding MUST be deterministic. +func (o Outcome) Encode() ([]byte, error) { + // We sort again here in case construction is not via the constructor. + return json.Marshal(newSortedOutcome(o.State, o.PendingCommitReports, o.Report)) +} + +// DecodeOutcome decodes the outcome from JSON. func DecodeOutcome(b []byte) (Outcome, error) { o := Outcome{} err := json.Unmarshal(b, &o) diff --git a/execute/exectypes/outcome_test.go b/execute/exectypes/outcome_test.go new file mode 100644 index 000000000..35dd376b7 --- /dev/null +++ b/execute/exectypes/outcome_test.go @@ -0,0 +1,53 @@ +package exectypes + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPluginState_Next(t *testing.T) { + tests := []struct { + name string + p PluginState + want PluginState + isPanic bool + }{ + { + name: "Zero value", + p: Unknown, + want: GetCommitReports, + }, + { + name: "Phase 1 to 2", + p: GetCommitReports, + want: GetMessages, + }, + { + name: "Phase 2 to 1", + p: GetMessages, + want: GetCommitReports, + }, + { + name: "panic", + p: PluginState("ElToroLoco"), + isPanic: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.isPanic { + require.Panics(t, func() { + tt.p.Next() + }) + return + } + + if got := tt.p.Next(); got != tt.want { + t.Errorf("Next() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/execute/factory.go b/execute/factory.go index d8fa360d3..5f3d487e7 100644 --- a/execute/factory.go +++ b/execute/factory.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/libocr/commontypes" ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" + "github.com/smartcontractkit/chainlink-ccip/execute/exectypes" "github.com/smartcontractkit/chainlink-ccip/execute/internal/gas" "github.com/smartcontractkit/chainlink-ccip/internal/reader" "github.com/smartcontractkit/chainlink-ccip/pluginconfig" @@ -54,6 +55,7 @@ type PluginFactory struct { msgHasher cciptypes.MessageHasher homeChainReader reader.HomeChain estimateProvider gas.EstimateProvider + tokenDataReader exectypes.TokenDataReader contractReaders map[cciptypes.ChainSelector]types.ContractReader chainWriters map[cciptypes.ChainSelector]types.ChainWriter } @@ -64,6 +66,7 @@ func NewPluginFactory( execCodec cciptypes.ExecutePluginCodec, msgHasher cciptypes.MessageHasher, homeChainReader reader.HomeChain, + tokenDataReader exectypes.TokenDataReader, estimateProvider gas.EstimateProvider, contractReaders map[cciptypes.ChainSelector]types.ContractReader, chainWriters map[cciptypes.ChainSelector]types.ChainWriter, @@ -77,6 +80,7 @@ func NewPluginFactory( estimateProvider: estimateProvider, contractReaders: contractReaders, chainWriters: chainWriters, + tokenDataReader: tokenDataReader, } } @@ -115,7 +119,7 @@ func (p PluginFactory) NewReportingPlugin( p.execCodec, p.msgHasher, p.homeChainReader, - nil, // TODO: token data reader + p.tokenDataReader, p.estimateProvider, p.lggr, ), ocr3types.ReportingPluginInfo{ diff --git a/execute/plugin.go b/execute/plugin.go index c4609f74a..7a75fcf35 100644 --- a/execute/plugin.go +++ b/execute/plugin.go @@ -164,69 +164,90 @@ func (p *Plugin) Observation( if err != nil { return types.Observation{}, fmt.Errorf("unable to decode previous outcome: %w", err) } + p.lggr.Infow("decoded previous outcome", "previousOutcome", previousOutcome) } - fetchFrom := time.Now().Add(-p.cfg.OffchainConfig.MessageVisibilityInterval.Duration()).UTC() - p.lggr.Infow("decoded previous outcome", "previousOutcome", previousOutcome) + state := previousOutcome.State.Next() + switch state { + case exectypes.GetCommitReports: + fetchFrom := time.Now().Add(-p.cfg.OffchainConfig.MessageVisibilityInterval.Duration()).UTC() - // Phase 1: Gather commit reports from the destination chain and determine which messages are required to build a - // valid execution report. - var groupedCommits exectypes.CommitObservations - supportsDest, err := p.supportsDestChain() - if err != nil { - return types.Observation{}, fmt.Errorf("unable to determine if the destination chain is supported: %w", err) - } - if supportsDest { - groupedCommits, err = getPendingExecutedReports(ctx, p.ccipReader, p.cfg.DestChain, fetchFrom, p.lggr) + // Phase 1: Gather commit reports from the destination chain and determine which messages are required to build + // a valid execution report. + supportsDest, err := p.supportsDestChain() if err != nil { - return types.Observation{}, err + return types.Observation{}, fmt.Errorf("unable to determine if the destination chain is supported: %w", err) } + if supportsDest { + groupedCommits, err := getPendingExecutedReports(ctx, p.ccipReader, p.cfg.DestChain, fetchFrom, p.lggr) + if err != nil { + return types.Observation{}, err + } - // TODO: truncate grouped commits to a maximum observation size. - // Cache everything which is not executed. - } - - // Phase 2: Gather messages from the source chains and build the execution report. - messages := make(exectypes.MessageObservations) - if len(previousOutcome.PendingCommitReports) == 0 { - p.lggr.Debug("TODO: No reports to execute. This is expected after a cold start.") - // No reports to execute. - // This is expected after a cold start. - } else { - commitReportCache := make(map[cciptypes.ChainSelector][]exectypes.CommitData) - for _, report := range previousOutcome.PendingCommitReports { - commitReportCache[report.SourceChain] = append(commitReportCache[report.SourceChain], report) + // TODO: truncate grouped to a maximum observation size? + return exectypes.NewObservation(groupedCommits, nil).Encode() } - for selector, reports := range commitReportCache { - if len(reports) == 0 { - continue + // No observation for non-dest readers. + return types.Observation{}, nil + case exectypes.GetMessages: + // Phase 2: Gather messages from the source chains and build the execution report. + messages := make(exectypes.MessageObservations) + if len(previousOutcome.PendingCommitReports) == 0 { + p.lggr.Debug("TODO: No reports to execute. This is expected after a cold start.") + // No reports to execute. + // This is expected after a cold start. + } else { + commitReportCache := make(map[cciptypes.ChainSelector][]exectypes.CommitData) + for _, report := range previousOutcome.PendingCommitReports { + commitReportCache[report.SourceChain] = append(commitReportCache[report.SourceChain], report) } - ranges, err := computeRanges(reports) - if err != nil { - return types.Observation{}, err - } + for selector, reports := range commitReportCache { + if len(reports) == 0 { + continue + } - // Read messages for each range. - for _, seqRange := range ranges { - msgs, err := p.ccipReader.MsgsBetweenSeqNums(ctx, selector, seqRange) + ranges, err := computeRanges(reports) if err != nil { - return nil, err + return types.Observation{}, err } - for _, msg := range msgs { - if _, ok := messages[selector]; !ok { - messages[selector] = make(map[cciptypes.SeqNum]cciptypes.Message) + + // Read messages for each range. + for _, seqRange := range ranges { + msgs, err := p.ccipReader.MsgsBetweenSeqNums(ctx, selector, seqRange) + if err != nil { + return nil, err + } + for _, msg := range msgs { + if _, ok := messages[selector]; !ok { + messages[selector] = make(map[cciptypes.SeqNum]cciptypes.Message) + } + messages[selector][msg.Header.SequenceNumber] = msg } - messages[selector][msg.Header.SequenceNumber] = msg } } } - } - // TODO: Fire off messages for an attestation check service. + // Regroup the commit reports back into the observation format. + // TODO: use same format for Observation and Outcome. + groupedCommits := make(exectypes.CommitObservations) + for _, report := range previousOutcome.PendingCommitReports { + if _, ok := groupedCommits[report.SourceChain]; !ok { + groupedCommits[report.SourceChain] = []exectypes.CommitData{} + } + groupedCommits[report.SourceChain] = append(groupedCommits[report.SourceChain], report) + } + + // TODO: Fire off messages for an attestation check service. + return exectypes.NewObservation(groupedCommits, messages).Encode() - return exectypes.NewObservation(groupedCommits, messages).Encode() + case exectypes.Filter: + // TODO: pass the previous two through, add in the nonces. + return types.Observation{}, fmt.Errorf("unknown state") + default: + return types.Observation{}, fmt.Errorf("unknown state") + } } func (p *Plugin) ValidateObservation( @@ -320,6 +341,18 @@ func selectReport( func (p *Plugin) Outcome( outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, ) (ocr3types.Outcome, error) { + var previousOutcome exectypes.Outcome + if outctx.PreviousOutcome != nil { + var err error + previousOutcome, err = exectypes.DecodeOutcome(outctx.PreviousOutcome) + if err != nil { + return nil, fmt.Errorf("unable to decode previous outcome: %w", err) + } + } + + ///////////////////////////////////////////// + // Decode the observations and merge them. // + ///////////////////////////////////////////// decodedObservations, err := decodeAttributedObservations(aos) if err != nil { return ocr3types.Outcome{}, fmt.Errorf("unable to decode observations: %w", err) @@ -359,6 +392,10 @@ func (p *Plugin) Outcome( mergedCommitObservations, mergedMessageObservations) + ////////////////////////// + // common preprocessing // + ////////////////////////// + // flatten commit reports and sort by timestamp. var commitReports []exectypes.CommitData for _, report := range observation.CommitReports { @@ -372,46 +409,57 @@ func (p *Plugin) Outcome( fmt.Sprintf("[oracle %d] exec outcome: commit reports", p.reportingCfg.OracleID), "commitReports", commitReports) - // add messages to their commitReports. - for i, report := range commitReports { - report.Messages = nil - for i := report.SequenceNumberRange.Start(); i <= report.SequenceNumberRange.End(); i++ { - if msg, ok := observation.Messages[report.SourceChain][i]; ok { - report.Messages = append(report.Messages, msg) + state := previousOutcome.State.Next() + switch state { + case exectypes.GetCommitReports: + outcome := exectypes.NewOutcome(state, commitReports, cciptypes.ExecutePluginReport{}) + return outcome.Encode() + case exectypes.GetMessages: + // add messages to their commitReports. + for i, report := range commitReports { + report.Messages = nil + for i := report.SequenceNumberRange.Start(); i <= report.SequenceNumberRange.End(); i++ { + if msg, ok := observation.Messages[report.SourceChain][i]; ok { + report.Messages = append(report.Messages, msg) + } } + commitReports[i].Messages = report.Messages } - commitReports[i].Messages = report.Messages - } - // TODO: this function should be pure, a context should not be needed. - outcomeReports, commitReports, err := - selectReport( - context.Background(), - p.lggr, p.msgHasher, - p.reportCodec, - p.tokenDataReader, - p.estimateProvider, - commitReports, - maxReportSizeBytes, - p.cfg.OffchainConfig.BatchGasLimit) - if err != nil { - return ocr3types.Outcome{}, fmt.Errorf("unable to extract proofs: %w", err) - } + // TODO: this function should be pure, a context should not be needed. + outcomeReports, commitReports, err := + selectReport( + context.Background(), + p.lggr, p.msgHasher, + p.reportCodec, + p.tokenDataReader, + p.estimateProvider, + commitReports, + maxReportSizeBytes, + p.cfg.OffchainConfig.BatchGasLimit) + if err != nil { + return ocr3types.Outcome{}, fmt.Errorf("unable to extract proofs: %w", err) + } - execReport := cciptypes.ExecutePluginReport{ - ChainReports: outcomeReports, - } + execReport := cciptypes.ExecutePluginReport{ + ChainReports: outcomeReports, + } - outcome := exectypes.NewOutcome(commitReports, execReport) - if outcome.IsEmpty() { - return nil, nil - } + outcome := exectypes.NewOutcome(state, commitReports, execReport) + if outcome.IsEmpty() { + return nil, nil + } - p.lggr.Infow( - fmt.Sprintf("[oracle %d] exec outcome: generated outcome", p.reportingCfg.OracleID), - "outcome", outcome) + p.lggr.Infow( + fmt.Sprintf("[oracle %d] exec outcome: generated outcome", p.reportingCfg.OracleID), + "outcome", outcome) - return outcome.Encode() + return outcome.Encode() + case exectypes.Filter: + panic("not implemented") + default: + panic("unknown state") + } } func (p *Plugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[[]byte], error) { diff --git a/execute/plugin_e2e_test.go b/execute/plugin_e2e_test.go index e14025582..b58612bb5 100644 --- a/execute/plugin_e2e_test.go +++ b/execute/plugin_e2e_test.go @@ -14,6 +14,7 @@ import ( commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/smartcontractkit/chainlink-ccip/chainconfig" @@ -112,6 +113,10 @@ func setupHomeChainPoller(lggr logger.Logger, chainConfigInfos []reader.ChainCon // to prevent linting error because of logging after finishing tests, we close the poller after each test, having // lower polling interval make it catch up faster time.Minute, + types.BoundContract{ + Address: "0xCCIPConfigFakeAddress", + Name: consts.ContractNameCCIPConfig, + }, ) return homeChain diff --git a/execute/plugin_test.go b/execute/plugin_test.go index 2b0fe365d..3119940ce 100644 --- a/execute/plugin_test.go +++ b/execute/plugin_test.go @@ -406,7 +406,7 @@ func TestPlugin_Reports_UnableToEncode(t *testing.T) { codec.On("Encode", mock.Anything, mock.Anything). Return(nil, fmt.Errorf("test error")) p := &Plugin{reportCodec: codec} - report, err := exectypes.NewOutcome(nil, cciptypes.ExecutePluginReport{}).Encode() + report, err := exectypes.NewOutcome(exectypes.Unknown, nil, cciptypes.ExecutePluginReport{}).Encode() require.NoError(t, err) _, err = p.Reports(0, report) diff --git a/internal/mocks/chain_support.go b/internal/mocks/chain_support.go new file mode 100644 index 000000000..7c9ed5e75 --- /dev/null +++ b/internal/mocks/chain_support.go @@ -0,0 +1,36 @@ +package mocks + +import ( + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/stretchr/testify/mock" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" +) + +type ChainSupport struct { + *mock.Mock +} + +func NewChainSupport() *ChainSupport { + return &ChainSupport{ + Mock: &mock.Mock{}, + } +} + +func (c ChainSupport) KnownSourceChainsSlice() ([]cciptypes.ChainSelector, error) { + args := c.Called() + return args.Get(0).([]cciptypes.ChainSelector), args.Error(1) +} + +// SupportedChains returns the set of chains that the given Oracle is configured to access +func (c ChainSupport) SupportedChains(oracleID commontypes.OracleID) (mapset.Set[cciptypes.ChainSelector], error) { + args := c.Called(oracleID) + return args.Get(0).(mapset.Set[cciptypes.ChainSelector]), args.Error(1) +} + +// SupportsDestChain returns true if the given oracle supports the dest chain, returns false otherwise +func (c ChainSupport) SupportsDestChain(oracleID commontypes.OracleID) (bool, error) { + args := c.Called(oracleID) + return args.Get(0).(bool), args.Error(1) +} diff --git a/internal/mocks/observer.go b/internal/mocks/observer.go new file mode 100644 index 000000000..2eb470fb1 --- /dev/null +++ b/internal/mocks/observer.go @@ -0,0 +1,49 @@ +package mocks + +import ( + "context" + + "github.com/stretchr/testify/mock" + + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + + "github.com/smartcontractkit/chainlink-ccip/plugintypes" +) + +type Observer struct { + *mock.Mock +} + +func NewObserver() *Observer { + return &Observer{ + Mock: &mock.Mock{}, + } +} + +func (o Observer) ObserveOffRampNextSeqNums(ctx context.Context) []plugintypes.SeqNumChain { + args := o.Called(ctx) + return args.Get(0).([]plugintypes.SeqNumChain) +} + +func (o Observer) ObserveMerkleRoots( + ctx context.Context, + ranges []plugintypes.ChainRange, +) []cciptypes.MerkleRootChain { + args := o.Called(ctx, ranges) + return args.Get(0).([]cciptypes.MerkleRootChain) +} + +func (o Observer) ObserveTokenPrices(ctx context.Context) []cciptypes.TokenPrice { + args := o.Called(ctx) + return args.Get(0).([]cciptypes.TokenPrice) +} + +func (o Observer) ObserveGasPrices(ctx context.Context) []cciptypes.GasPriceChain { + args := o.Called(ctx) + return args.Get(0).([]cciptypes.GasPriceChain) +} + +func (o Observer) ObserveFChain() map[cciptypes.ChainSelector]int { + args := o.Called() + return args.Get(0).(map[cciptypes.ChainSelector]int) +} diff --git a/internal/plugincommon/ccipreader_test.go b/internal/plugincommon/ccipreader_test.go index 5e52eaabb..b311a8280 100644 --- a/internal/plugincommon/ccipreader_test.go +++ b/internal/plugincommon/ccipreader_test.go @@ -52,7 +52,7 @@ func TestBackgroundReaderSyncer(t *testing.T) { assert.NoError(t, err, "start success") assert.Eventually(t, func() bool { return mockReader.AssertExpectations(t) - }, time.Second, 10*time.Millisecond) + }, 3*time.Second, 10*time.Millisecond) err = readerSyncer.Close() assert.NoError(t, err, "closing a started syncer") }) diff --git a/internal/reader/home_chain.go b/internal/reader/home_chain.go index fd91bc9a8..c679bdece 100644 --- a/internal/reader/home_chain.go +++ b/internal/reader/home_chain.go @@ -63,6 +63,10 @@ type homeChainPoller struct { mutex *sync.RWMutex state state failedPolls uint + // TODO: currently unused but will be passed into GetLatestValue + // once the chainlink-common breaking change comes in + // (https://github.com/smartcontractkit/chainlink-common/pull/603). + ccipConfigBoundContract types.BoundContract // How frequently the poller fetches the chain configs pollingDuration time.Duration } @@ -73,15 +77,17 @@ func NewHomeChainConfigPoller( homeChainReader types.ContractReader, lggr logger.Logger, pollingInterval time.Duration, + ccipConfigBoundContract types.BoundContract, ) HomeChain { return &homeChainPoller{ - stopCh: make(chan struct{}), - homeChainReader: homeChainReader, - state: state{}, - mutex: &sync.RWMutex{}, - failedPolls: 0, - lggr: lggr, - pollingDuration: pollingInterval, + stopCh: make(chan struct{}), + homeChainReader: homeChainReader, + state: state{}, + mutex: &sync.RWMutex{}, + failedPolls: 0, + lggr: lggr, + pollingDuration: pollingInterval, + ccipConfigBoundContract: ccipConfigBoundContract, } } diff --git a/internal/reader/home_chain_test.go b/internal/reader/home_chain_test.go index 160b2b2ef..3466154b0 100644 --- a/internal/reader/home_chain_test.go +++ b/internal/reader/home_chain_test.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" "github.com/stretchr/testify/mock" @@ -53,6 +54,10 @@ func TestHomeChainConfigPoller_HealthReport(t *testing.T) { homeChainReader, logger.Test(t), tickTime, + types.BoundContract{ + Address: "0xCCIPConfigFakeAddress", + Name: consts.ContractNameCCIPConfig, + }, ) require.NoError(t, configPoller.Start(context.Background())) // Initially it's healthy @@ -157,6 +162,10 @@ func Test_PollingWorking(t *testing.T) { homeChainReader, logger.Test(t), tickTime, + types.BoundContract{ + Address: "0xCCIPConfigFakeAddress", + Name: consts.ContractNameCCIPConfig, + }, ) require.NoError(t, configPoller.Start(context.Background())) @@ -212,6 +221,10 @@ func Test_HomeChainPoller_GetOCRConfig(t *testing.T) { homeChainReader, logger.Test(t), 10*time.Millisecond, + types.BoundContract{ + Address: "0xCCIPConfigFakeAddress", + Name: consts.ContractNameCCIPConfig, + }, ) configs, err := configPoller.GetOCRConfigs(context.Background(), donID, pluginType) diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 788dcd9cc..bde09b1c1 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -75,5 +75,5 @@ const ( const ( EventAttributeSequenceNumber = "SequenceNumber" EventAttributeSourceChain = "SourceChain" - EventAttributeDestChain = "DestChain" + EventAttributeDestChain = "destChain" ) diff --git a/pkg/reader/home_chain.go b/pkg/reader/home_chain.go index e2d38e78d..8bc07d88c 100644 --- a/pkg/reader/home_chain.go +++ b/pkg/reader/home_chain.go @@ -23,6 +23,7 @@ func NewHomeChainReader( homeChainReader types.ContractReader, lggr logger.Logger, pollingInterval time.Duration, + ccipConfigBoundContract types.BoundContract, ) HomeChain { - return reader_internal.NewHomeChainConfigPoller(homeChainReader, lggr, pollingInterval) + return reader_internal.NewHomeChainConfigPoller(homeChainReader, lggr, pollingInterval, ccipConfigBoundContract) } diff --git a/pluginconfig/commit.go b/pluginconfig/commit.go index c535a9aba..db99daf29 100644 --- a/pluginconfig/commit.go +++ b/pluginconfig/commit.go @@ -22,6 +22,9 @@ type CommitPluginConfig struct { // NewMsgScanBatchSize is the number of max new messages to scan, typically set to 256. NewMsgScanBatchSize int `json:"newMsgScanBatchSize"` + // The maximum number of times to check if the previous report has been transmitted + MaxReportTransmissionCheckAttempts uint + // SyncTimeout is the timeout for syncing the commit plugin reader. SyncTimeout time.Duration `json:"syncTimeout"` diff --git a/plugintypes/commit.go b/plugintypes/commit.go index ea38dee6f..7b4659faa 100644 --- a/plugintypes/commit.go +++ b/plugintypes/commit.go @@ -110,3 +110,8 @@ func NewSeqNumChain(chainSel cciptypes.ChainSelector, seqNum cciptypes.SeqNum) S SeqNum: seqNum, } } + +type ChainRange struct { + ChainSel cciptypes.ChainSelector `json:"chain"` + SeqNumRange cciptypes.SeqNumRange `json:"seqNumRange"` +}