Skip to content

Commit

Permalink
Commit and Execute - Custom transmission schedule (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
dimkouv authored Nov 7, 2024
1 parent 4c9ee21 commit f49fe84
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 20 deletions.
29 changes: 20 additions & 9 deletions commit/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"time"

"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
"golang.org/x/exp/maps"

"github.com/smartcontractkit/chainlink-ccip/commit/merkleroot"
"github.com/smartcontractkit/chainlink-ccip/internal/plugincommon"
"github.com/smartcontractkit/chainlink-ccip/internal/plugincommon/consensus"
"github.com/smartcontractkit/chainlink-ccip/pkg/consts"
cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)

const (
// transmissionDelayMultiplier is used to calculate the transmission delay for each oracle.
transmissionDelayMultiplier = 3 * time.Second
)

// ReportInfo is the info data that will be sent with the along with the report
// It will be used to determine if the report should be accepted or not
type ReportInfo struct {
Expand Down Expand Up @@ -84,12 +92,24 @@ func (p *Plugin) Reports(
return nil, fmt.Errorf("encode report info: %w", err)
}

transmissionSchedule, err := plugincommon.GetTransmissionSchedule(
p.chainSupport,
maps.Keys(p.oracleIDToP2PID),
transmissionDelayMultiplier,
)
if err != nil {
return nil, fmt.Errorf("get transmission schedule: %w", err)
}
p.lggr.Debugw("transmission schedule override",
"transmissionSchedule", transmissionSchedule, "oracleIDToP2PID", p.oracleIDToP2PID)

return []ocr3types.ReportPlus[[]byte]{
{
ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{
Report: encodedReport,
Info: infoBytes,
},
TransmissionScheduleOverride: transmissionSchedule,
},
}, nil
}
Expand Down Expand Up @@ -127,15 +147,6 @@ func (p *Plugin) ShouldAcceptAttestedReport(
func (p *Plugin) ShouldTransmitAcceptedReport(
ctx context.Context, u uint64, r ocr3types.ReportWithInfo[[]byte],
) (bool, error) {
isWriter, err := p.chainSupport.SupportsDestChain(p.oracleID)
if err != nil {
return false, fmt.Errorf("can't know if it's a writer: %w", err)
}
if !isWriter {
p.lggr.Infow("not a writer, skipping report transmission")
return false, nil
}

// we only transmit reports if we are the "active" instance.
// we can check this by reading the OCR configs from the home chain.
isCandidate, err := p.isCandidateInstance(ctx)
Expand Down
11 changes: 9 additions & 2 deletions commit/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"fmt"
"testing"

"github.com/smartcontractkit/libocr/commontypes"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/smartcontractkit/chainlink-ccip/internal/libs/testhelpers/rand"
"github.com/smartcontractkit/chainlink-ccip/mocks/internal_/plugincommon"
reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader"
"github.com/smartcontractkit/chainlink-ccip/pkg/consts"

Expand Down Expand Up @@ -130,10 +133,14 @@ func TestPluginReports(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cs := plugincommon.NewMockChainSupport(t)
p := Plugin{
lggr: lggr,
reportCodec: reportCodec,
lggr: lggr,
reportCodec: reportCodec,
oracleIDToP2PID: map[commontypes.OracleID]libocrtypes.PeerID{1: {1}},
chainSupport: cs,
}
cs.EXPECT().SupportsDestChain(commontypes.OracleID(1)).Return(true, nil).Maybe()

outcomeBytes, err := tc.outc.Encode()
require.NoError(t, err)
Expand Down
48 changes: 39 additions & 9 deletions execute/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

mapset "github.com/deckarep/golang-set/v2"
"golang.org/x/exp/maps"

"github.com/smartcontractkit/libocr/commontypes"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
Expand All @@ -15,6 +16,8 @@ import (

"github.com/smartcontractkit/chainlink-common/pkg/logger"

"github.com/smartcontractkit/chainlink-ccip/internal/plugincommon"

"github.com/smartcontractkit/chainlink-ccip/execute/exectypes"
"github.com/smartcontractkit/chainlink-ccip/execute/internal/gas"
"github.com/smartcontractkit/chainlink-ccip/execute/report"
Expand All @@ -31,6 +34,11 @@ import (
// maxReportSizeBytes that should be returned as an execution report payload.
const maxReportSizeBytes = 250_000

const (
// transmissionDelayMultiplier is used to calculate the transmission delay for each oracle.
transmissionDelayMultiplier = 3 * time.Second
)

// Plugin implements the main ocr3 plugin logic.
type Plugin struct {
donID plugintypes.DonID
Expand All @@ -39,11 +47,12 @@ type Plugin struct {
destChain cciptypes.ChainSelector

// providers
ccipReader readerpkg.CCIPReader
reportCodec cciptypes.ExecutePluginCodec
msgHasher cciptypes.MessageHasher
homeChain reader.HomeChain
discovery *discovery.ContractDiscoveryProcessor
ccipReader readerpkg.CCIPReader
reportCodec cciptypes.ExecutePluginCodec
msgHasher cciptypes.MessageHasher
homeChain reader.HomeChain
discovery *discovery.ContractDiscoveryProcessor
chainSupport plugincommon.ChainSupport

oracleIDToP2pID map[commontypes.OracleID]libocrtypes.PeerID
tokenDataObserver tokendata.TokenDataObserver
Expand Down Expand Up @@ -96,6 +105,13 @@ func NewPlugin(
reportingCfg.F,
oracleIDToP2pID,
),
chainSupport: plugincommon.NewCCIPChainSupport(
lggr,
homeChain,
oracleIDToP2pID,
reportingCfg.OracleID,
destChain,
),
}
}

Expand Down Expand Up @@ -254,11 +270,25 @@ func (p *Plugin) Reports(
return nil, fmt.Errorf("unable to encode report: %w", err)
}

transmissionSchedule, err := plugincommon.GetTransmissionSchedule(
p.chainSupport,
maps.Keys(p.oracleIDToP2pID),
transmissionDelayMultiplier,
)
if err != nil {
return nil, fmt.Errorf("get transmission schedule: %w", err)
}
p.lggr.Debugw("transmission schedule override",
"transmissionSchedule", transmissionSchedule, "oracleIDToP2PID", p.oracleIDToP2pID)

report := []ocr3types.ReportPlus[[]byte]{
{ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{
Report: encoded,
Info: nil,
}},
{
ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{
Report: encoded,
Info: nil,
},
TransmissionScheduleOverride: transmissionSchedule,
},
}

return report, nil
Expand Down
54 changes: 54 additions & 0 deletions internal/plugincommon/transmitters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package plugincommon

import (
"fmt"
"time"

"github.com/smartcontractkit/libocr/commontypes"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
)

// GetTransmissionSchedule returns a TransmissionSchedule for the provided oracles.
// It uses the ChainSupport service to query which oracles support the destination chain.
// It returns an error if no oracles support the destination chain.
// Read more about TransmissionDelay at ocr3types.TransmissionSchedule
// The transmissionDelayMultiplier is used in the following way:
//
// Assume that we have transmitters: [1, 3, 5]
// And transmissionDelayMultiplier = 5s
// Then the transmission delays will be: [5s, 10s, 15s]
func GetTransmissionSchedule(
chainSupport ChainSupport,
allTheOracles []commontypes.OracleID,
transmissionDelayMultiplier time.Duration,
) (*ocr3types.TransmissionSchedule, error) {
transmitters := make([]commontypes.OracleID, 0, len(allTheOracles))
for _, oracleID := range allTheOracles {
supportsDestChain, err := chainSupport.SupportsDestChain(oracleID)
if err != nil {
return nil, fmt.Errorf("supports dest chain %d: %w", oracleID, err)
}
if supportsDestChain {
transmitters = append(transmitters, oracleID)
}
}

transmissionDelays := make([]time.Duration, len(transmitters))

for i := range transmissionDelays {
transmissionDelays[i] = (transmissionDelayMultiplier) * time.Duration(i+1)
}

if len(transmitters) == 0 {
return nil, fmt.Errorf("no transmitters")
}

if len(transmitters) != len(transmissionDelays) {
return nil, fmt.Errorf("critical issue mismatched transmitters and transmission delays")
}

return &ocr3types.TransmissionSchedule{
Transmitters: transmitters,
TransmissionDelays: transmissionDelays,
}, nil
}
83 changes: 83 additions & 0 deletions internal/plugincommon/transmitters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package plugincommon_test

import (
"errors"
"testing"
"time"

mapset "github.com/deckarep/golang-set/v2"
"github.com/smartcontractkit/libocr/commontypes"
"github.com/stretchr/testify/require"

plugincommon2 "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon"
"github.com/smartcontractkit/chainlink-ccip/mocks/internal_/plugincommon"
)

func TestGetTransmissionSchedule(t *testing.T) {
testCases := []struct {
name string
allTheOracles []commontypes.OracleID
oraclesSupportingDest []commontypes.OracleID
transmissionDelayMultiplier time.Duration
expectedTransmitters []commontypes.OracleID
expectedTransmissionDelays []time.Duration
expectedError bool
chainSupportReturnsError bool
}{
{
name: "no oracles supporting dest leads to error",
allTheOracles: []commontypes.OracleID{1, 2, 3},
oraclesSupportingDest: []commontypes.OracleID{},
transmissionDelayMultiplier: 5 * time.Second,
expectedTransmitters: []commontypes.OracleID{},
expectedTransmissionDelays: []time.Duration{},
expectedError: true,
},
{
name: "some transmitters supporting dest",
allTheOracles: []commontypes.OracleID{1, 2, 3},
oraclesSupportingDest: []commontypes.OracleID{1, 3},
transmissionDelayMultiplier: 5 * time.Second,
expectedTransmitters: []commontypes.OracleID{1, 3},
expectedTransmissionDelays: []time.Duration{5 * time.Second, 10 * time.Second},
expectedError: false,
},
{
name: "chainsupport returns error",
allTheOracles: []commontypes.OracleID{1, 2, 3},
oraclesSupportingDest: []commontypes.OracleID{1, 3},
transmissionDelayMultiplier: 5 * time.Second,
expectedError: true,
chainSupportReturnsError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cs := plugincommon.NewMockChainSupport(t)
destSupportingOraclesSet := mapset.NewSet(tc.oraclesSupportingDest...)
for _, oracleID := range tc.allTheOracles {
var err error
if tc.chainSupportReturnsError {
err = errors.New("some error")
}
cs.On("SupportsDestChain", oracleID).
Return(destSupportingOraclesSet.Contains(oracleID), err).Maybe()
}

transmissionSchedule, err := plugincommon2.GetTransmissionSchedule(
cs,
tc.allTheOracles,
tc.transmissionDelayMultiplier,
)
if tc.expectedError {
require.Error(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tc.expectedTransmitters, transmissionSchedule.Transmitters)
require.Equal(t, tc.expectedTransmissionDelays, transmissionSchedule.TransmissionDelays)
})
}
}

0 comments on commit f49fe84

Please sign in to comment.