Skip to content

Commit

Permalink
Merge branch 'main' into solana/mcm-clear-root
Browse files Browse the repository at this point in the history
  • Loading branch information
jadepark-dev authored Dec 20, 2024
2 parents bdbe7d5 + bf224dc commit 880d236
Show file tree
Hide file tree
Showing 13 changed files with 720 additions and 559 deletions.
16 changes: 16 additions & 0 deletions execute/exectypes/observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"

"github.com/smartcontractkit/chainlink-ccip/execute/internal"
dt "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon/discovery/discoverytypes"
cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)
Expand All @@ -17,6 +18,8 @@ type MessageObservations map[cciptypes.ChainSelector]map[cciptypes.SeqNum]ccipty

type MessageHashes map[cciptypes.ChainSelector]map[cciptypes.SeqNum]cciptypes.Bytes32

type EncodedMsgAndTokenDataSizes map[cciptypes.ChainSelector]map[cciptypes.SeqNum]int

// Flatten nested maps into a slice of messages.
func (mo MessageObservations) Flatten() []cciptypes.Message {
var results []cciptypes.Message
Expand All @@ -43,6 +46,19 @@ func GetHashes(ctx context.Context, mo MessageObservations, hasher cciptypes.Mes
return hashes, nil
}

// GetEncodedMsgAndTokenDataSizes calculates the encoded sizes of messages and their token data counterpart.
func GetEncodedMsgAndTokenDataSizes(mo MessageObservations, tds TokenDataObservations) EncodedMsgAndTokenDataSizes {
sizes := make(EncodedMsgAndTokenDataSizes)
for chain, msgs := range mo {
sizes[chain] = make(map[cciptypes.SeqNum]int)
for seq, msg := range msgs {
td := tds[chain][seq]
sizes[chain][seq] = internal.EncodedSize(msg) + internal.EncodedSize(td)
}
}
return sizes
}

// NonceObservations contain the latest nonce for senders in the previously observed messages.
// Nonces are organized by source chain selector and the string encoded sender address. The address
// must be encoding according to the destination chain requirements with typeconv.AddressBytesToString.
Expand Down
4 changes: 3 additions & 1 deletion execute/internal/gas/gas_estimate_provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gas

import "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
import (
"github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)

// EstimateProvider is used to estimate the gas cost of a message or a merkle tree.
// TODO: Move to pkg/types/ccipocr3 or remove.
Expand Down
20 changes: 20 additions & 0 deletions execute/internal/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package internal

import "encoding/json"

func EncodedSize[T any](obj T) int {
enc, err := json.Marshal(obj)
if err != nil {
return 0
}
return len(enc)
}

func RemoveIthElement[T any](slice []T, i int) []T {
if i < 0 || i >= len(slice) {
return slice // Return the original slice if index is out of bounds
}
newSlice := make([]T, 0, len(slice)-1)
newSlice = append(newSlice, slice[:i]...)
return append(newSlice, slice[i+1:]...)
}
31 changes: 31 additions & 0 deletions execute/internal/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package internal

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestRemoveIthElement(t *testing.T) {
tests := []struct {
name string
slice []int
index int
expected []int
}{
{"Remove middle element", []int{1, 2, 3, 4, 5}, 2, []int{1, 2, 4, 5}},
{"Remove first element", []int{1, 2, 3, 4, 5}, 0, []int{2, 3, 4, 5}},
{"Remove last element", []int{1, 2, 3, 4, 5}, 4, []int{1, 2, 3, 4}},
{"Index out of bounds (negative)", []int{1, 2, 3, 4, 5}, -1, []int{1, 2, 3, 4, 5}},
{"Index out of bounds (too large)", []int{1, 2, 3, 4, 5}, 5, []int{1, 2, 3, 4, 5}},
{"Single element slice", []int{1}, 0, []int{}},
{"Empty slice", []int{}, 0, []int{}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RemoveIthElement(tt.slice, tt.index)
assert.Equal(t, tt.expected, result)
})
}
}
7 changes: 6 additions & 1 deletion execute/observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ func (p *Plugin) getMessagesObservation(
if err1 != nil {
return exectypes.Observation{}, fmt.Errorf("unable to process token data %w", err1)
}
if validateTokenDataObservations(messageObs, tkData) != nil {
return exectypes.Observation{}, fmt.Errorf("invalid token data observations")
}

costlyMessages, err := p.costlyMessageObserver.Observe(ctx, messageObs.Flatten(), messageTimestamps)
if err != nil {
Expand All @@ -252,9 +255,11 @@ func (p *Plugin) getMessagesObservation(
observation.Hashes = hashes
observation.CostlyMessages = costlyMessages
observation.TokenData = tkData
//observation.MessageAndTokenDataEncodedSizes = exectypes.GetEncodedMsgAndTokenDataSizes(messageObs, tkData)

// Make sure encoded observation fits within the maximum observation size.
observation, err = truncateObservation(observation, maxObservationLength)
//observation, err = truncateObservation(observation, maxObservationLength, p.emptyEncodedSizes)
observation, err = p.observationOptimizer.TruncateObservation(observation)
if err != nil {
return exectypes.Observation{}, fmt.Errorf("unable to truncate observation: %w", err)
}
Expand Down
220 changes: 220 additions & 0 deletions execute/optimizers/type_optimizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
package optimizers

import (
"fmt"
"sort"

"github.com/smartcontractkit/chainlink-ccip/execute/exectypes"
"github.com/smartcontractkit/chainlink-ccip/execute/internal"

"golang.org/x/exp/maps"

cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)

type ObservationOptimizer struct {
maxEncodedSize int
emptyEncodedSizes EmptyEncodeSizes
}

func NewObservationOptimizer(maxEncodedSize int) ObservationOptimizer {
return ObservationOptimizer{
maxEncodedSize: maxEncodedSize,
emptyEncodedSizes: NewEmptyEncodeSizes(),
}
}

type EmptyEncodeSizes struct {
MessageAndTokenData int
CommitData int
SeqNumMap int
}

func NewEmptyEncodeSizes() EmptyEncodeSizes {
emptyMsg := cciptypes.Message{}
emptyTokenData := exectypes.MessageTokenData{}
emptyCommitData := exectypes.CommitData{}
emptySeqNrSize := internal.EncodedSize(make(map[cciptypes.SeqNum]cciptypes.Message))

return EmptyEncodeSizes{
MessageAndTokenData: internal.EncodedSize(emptyMsg) + internal.EncodedSize(emptyTokenData),
CommitData: internal.EncodedSize(emptyCommitData), // 305
SeqNumMap: emptySeqNrSize, // 2
}
}

// TruncateObservation truncates the observation to fit within the given op.maxEncodedSize after encoding.
// It removes data from the observation in the following order:
// For each chain, pick last report and start removing messages one at a time.
// If removed all messages from the report, remove the report.
// If removed last report in the chain, remove the chain.
// After removing full report from a chain, move to the next chain and repeat. This ensures that we don't
// exclude messages from one chain only.
// Keep repeating this process until the encoded observation fits within the op.maxEncodedSize
// Important Note: We can't delete messages completely from single reports as we need them to create merkle proofs.
//
//nolint:gocyclo
func (op ObservationOptimizer) TruncateObservation(observation exectypes.Observation) (exectypes.Observation, error) {
obs := observation
encodedObs, err := obs.Encode()
if err != nil {
return exectypes.Observation{}, err
}
encodedObsSize := len(encodedObs)
if encodedObsSize <= op.maxEncodedSize {
return obs, nil
}

chains := maps.Keys(obs.CommitReports)
sort.Slice(chains, func(i, j int) bool {
return chains[i] < chains[j]
})

messageAndTokenDataEncodedSizes := exectypes.GetEncodedMsgAndTokenDataSizes(obs.Messages, obs.TokenData)
// While the encoded obs is too large, continue filtering data.
for encodedObsSize > op.maxEncodedSize {
// go through each chain and truncate observations for the final commit report.
for _, chain := range chains {
commits := obs.CommitReports[chain]
if len(commits) == 0 {
continue
}
lastCommit := &commits[len(commits)-1]
seqNum := lastCommit.SequenceNumberRange.Start()
// Remove messages one by one starting from the last message of the last commit report.
for seqNum <= lastCommit.SequenceNumberRange.End() {
if _, ok := obs.Messages[chain][seqNum]; !ok {
return exectypes.Observation{}, fmt.Errorf("missing message with seqNr %d from chain %d", seqNum, chain)
}
obs.Messages[chain][seqNum] = cciptypes.Message{}
obs.TokenData[chain][seqNum] = exectypes.NewMessageTokenData()
// Subtract the removed message and token size
encodedObsSize -= messageAndTokenDataEncodedSizes[chain][seqNum]
// Add empty message and token encoded size
encodedObsSize += op.emptyEncodedSizes.MessageAndTokenData
seqNum++
// Once we assert the estimation is less than the max size we double-check with actual encoding size.
// Otherwise, we short circuit after checking the estimation only
if encodedObsSize <= op.maxEncodedSize && fitsWithinSize(obs, op.maxEncodedSize) {
return obs, nil
}
}

var bytesTruncated int
// Reaching here means that all messages in the report are truncated, truncate the last commit
obs, bytesTruncated = op.truncateLastCommit(obs, chain)

encodedObsSize -= bytesTruncated

if len(obs.CommitReports[chain]) == 0 {
// If the last commit report was truncated, truncate the chain
obs = op.truncateChain(obs, chain)
}

// Once we assert the estimation is less than the max size we double-check with actual encoding size.
// Otherwise, we short circuit after checking the estimation only
if encodedObsSize <= op.maxEncodedSize && fitsWithinSize(obs, op.maxEncodedSize) {
return obs, nil
}
}
// Truncated all chains. Return obs as is. (it has other data like contract discovery)
if len(obs.CommitReports) == 0 {
return obs, nil
}
// Encoding again after doing a full iteration on all chains and removing messages/commits.
// That is because using encoded sizes is not 100% accurate and there are some missing bytes in the calculation.
encodedObs, err = obs.Encode()
if err != nil {
return exectypes.Observation{}, err
}
encodedObsSize = len(encodedObs)
}

return obs, nil
}

func fitsWithinSize(obs exectypes.Observation, maxEncodedSize int) bool {
encodedObs, err := obs.Encode()
if err != nil {
return false
}
return len(encodedObs) <= maxEncodedSize
}

// truncateLastCommit removes the last commit from the observation.
// returns observation and the number of bytes truncated.
func (op ObservationOptimizer) truncateLastCommit(
obs exectypes.Observation,
chain cciptypes.ChainSelector,
) (exectypes.Observation, int) {
observation := obs
bytesTruncated := 0
commits := observation.CommitReports[chain]
if len(commits) == 0 {
return observation, bytesTruncated
}
lastCommit := commits[len(commits)-1]
// Remove the last commit from the list.
commits = commits[:len(commits)-1]
observation.CommitReports[chain] = commits
// Remove from the encoded size.
bytesTruncated = bytesTruncated + op.emptyEncodedSizes.CommitData + 4 // brackets, and commas
for seqNum, msg := range observation.Messages[chain] {
if lastCommit.SequenceNumberRange.Contains(seqNum) {
// Remove the message from the observation.
delete(observation.Messages[chain], seqNum)
// Remove the token data from the observation.
delete(observation.TokenData[chain], seqNum)
//delete(observation.Hashes[chain], seqNum)
// Remove the encoded size of the message and token data.
bytesTruncated += op.emptyEncodedSizes.MessageAndTokenData
bytesTruncated = bytesTruncated + 2*op.emptyEncodedSizes.SeqNumMap
bytesTruncated += 4 // for brackets and commas
// Remove costly messages
for i, costlyMessage := range observation.CostlyMessages {
if costlyMessage == msg.Header.MessageID {
observation.CostlyMessages = internal.RemoveIthElement(observation.CostlyMessages, i)
break
}
}
// Leaving Nonces untouched
}
}

return observation, bytesTruncated
}

// truncateChain removes all data related to the given chain from the observation.
// returns true if the chain was found and truncated, false otherwise.
func (op ObservationOptimizer) truncateChain(
obs exectypes.Observation,
chain cciptypes.ChainSelector,
) exectypes.Observation {
observation := obs
if _, ok := observation.CommitReports[chain]; !ok {
return observation
}
messageIDs := make(map[cciptypes.Bytes32]struct{})
// To remove costly message IDs we need to iterate over all messages and find the ones that belong to the chain.
for _, seqNumMap := range observation.Messages {
for _, message := range seqNumMap {
messageIDs[message.Header.MessageID] = struct{}{}
}
}

deleteCostlyMessages := func() {
for i, costlyMessage := range observation.CostlyMessages {
if _, ok := messageIDs[costlyMessage]; ok {
observation.CostlyMessages = append(observation.CostlyMessages[:i], observation.CostlyMessages[i+1:]...)
}
}
}

delete(observation.CommitReports, chain)
delete(observation.Messages, chain)
delete(observation.TokenData, chain)
delete(observation.Nonces, chain)
deleteCostlyMessages()

return observation
}
Loading

0 comments on commit 880d236

Please sign in to comment.