From 9af7e76a0d5d4b04e4e938823cd6a731bda8b625 Mon Sep 17 00:00:00 2001 From: Daisuke Kanda Date: Wed, 3 Jul 2024 09:25:14 +0000 Subject: [PATCH] binary search messages in multicall Signed-off-by: Daisuke Kanda --- pkg/relay/ethereum/findItems_test.go | 75 +++++++ pkg/relay/ethereum/tx.go | 318 +++++++++++++++++++++------ 2 files changed, 327 insertions(+), 66 deletions(-) create mode 100644 pkg/relay/ethereum/findItems_test.go diff --git a/pkg/relay/ethereum/findItems_test.go b/pkg/relay/ethereum/findItems_test.go new file mode 100644 index 0000000..b8c8325 --- /dev/null +++ b/pkg/relay/ethereum/findItems_test.go @@ -0,0 +1,75 @@ +package ethereum + +import ( + "testing" + "fmt" + "slices" +) + +func TestFindItems(t *testing.T) { + cases := []struct{ + size int + expect int + expectLog []int + }{ + { + size: 0, + expect: 0, + expectLog: []int{ }, + }, + { + size: 10, + expect: 10, + expectLog: []int{ 10 }, + }, + { + size: 10, + expect: 9, + expectLog: []int{ 10, 5, 7, 8, 9 }, + }, + { + size: 10, + expect: 0, + expectLog: []int{ 10, 5, 2, 1 }, + }, + { + size: 10, + expect: 0, + expectLog: []int{ 10, 5, 2, 1 }, + }, + { + size: 10, + expect: 2, + expectLog: []int{ 10, 5, 2, 3 }, + }, + } + + for _, c := range cases { + type Data struct{ + expect int + log []int + } + data := Data{ expect: c.expect, log: make([]int, 0, c.size) } + result, err := findItems(c.size, &data, func(count int, d *Data) (error) { + d.log = append(d.log, count) + if count <= data.expect { + return nil + } else { + return fmt.Errorf("fail at count=%d", count) + } + }) + if c.expect == 0 { + if err == nil { + t.Errorf("findItems(%d,%d) unexpectedly returned %v", c.size, c.expect, result) + } + } else if err != nil { + t.Errorf("findItems(%d,%d) unexpectedly failed: %v", c.size, c.expect, err) + } else if result != c.expect { + t.Errorf("findItems(%d,%d) has been mistakenly resulted in %v", c.size, c.expect, result) + } + if slices.Compare(data.log, c.expectLog) != 0 { + t.Errorf("findItems(%d,%d) log unpexpected %v, expected=%v", c.size, c.expect, data.log, c.expectLog) + } + } +} + diff --git a/pkg/relay/ethereum/tx.go b/pkg/relay/ethereum/tx.go index e504a55..963ae73 100644 --- a/pkg/relay/ethereum/tx.go +++ b/pkg/relay/ethereum/tx.go @@ -416,8 +416,8 @@ func (c *Chain) getRevertReasonFromEstimateGas(err error) (string, []byte, error type CallIter struct { msgs []sdk.Msg + txs []gethtypes.Transaction cursor int - opts *bind.TransactOpts skipUpdateClientCommitment bool } func NewCallIter(msgs []sdk.Msg, skipUpdateClientCommitment bool) CallIter { @@ -483,19 +483,10 @@ func (iter *CallIter) sendSingleTx(ctx context.Context, c *Chain) (*gethtypes.Tr return nil, err } - if rawTxData, err := tx.MarshalBinary(); err != nil { - logger.Error("failed to encode tx", err) - } else { - logger.Logger = logger.With(logAttrRawTxData, hex.EncodeToString(rawTxData)) - } - - txGasLimit, isBreak, err := estimateGas(ctx, c, tx, false, logger) + txGasLimit, err := estimateGas(ctx, c, tx, true, logger) if err != nil { return nil, err } - if isBreak { - return nil, fmt.Errorf("isBraek should not be false") - } opts.GasLimit = txGasLimit } @@ -509,20 +500,16 @@ func (iter *CallIter) sendSingleTx(ctx context.Context, c *Chain) (*gethtypes.Tr return tx, nil } +/* func (iter *CallIter) sendMultiTx(ctx context.Context, c *Chain) (*gethtypes.Transaction, error) { from := iter.Cursor() if (iter.End()) { return nil, nil } + // now iter.cursor < len(iter.msgs) logger := c.GetChainLogger() - var ( - saveTxGasLimit uint64 - saveRawTxData string - ) - calls := make([]multicall3.Multicall3Call, 0, len(iter.msgs)) - opts, err := c.TxOpts(ctx, true); if err != nil { if err != nil { @@ -530,98 +517,296 @@ func (iter *CallIter) sendMultiTx(ctx context.Context, c *Chain) (*gethtypes.Tra } } - for !iter.End() && saveTxGasLimit < c.Config().MaxGasLimit { - logger := &log.RelayLogger{Logger: logger.With( - logAttrMsgIndexFrom, from, - logAttrMsgIndexTo, iter.Cursor(), - logAttrMsgType, fmt.Sprintf("%T", *iter.Current()), - )} - + if iter.txs == nil { // create txs at first multicall call opts.GasLimit = math.MaxUint64 opts.NoSend = true - tx, err := c.SendTx(opts, *iter.Current(), iter.skipUpdateClientCommitment) - if err != nil { - if len(calls) > 0 { - break + txs := make([]gethtypes.Transaction, 0, len(iter.msgs)) + for i := 0; i < len(iter.msgs); i++ { + tx, err := c.SendTx(opts, iter.msgs[i], iter.skipUpdateClientCommitment) + if err != nil { + logger := &log.RelayLogger{Logger: logger.With( + logAttrMsgIndexFrom, i, + logAttrMsgIndexTo, i, + logAttrMsgType, fmt.Sprintf("%T", iter.msgs[i]), + )} + logger.Error("failed to build tx for gas estimation", err) + return nil, err } - logger.Error("failed to build tx for gas estimation", err) - return nil, err + if tx.To() == nil { + err2 := fmt.Errorf("no target address") + logger.Error("failed to construct Multicall3Call", err2) + return nil, err2 + } + txs = append(txs, *tx) } + iter.txs = txs + } - to := tx.To(); if to == nil { - if len(calls) > 0 { + var ( + lastOkRawTxData string + lastOkCalls []multicall3.Multicall3Call + lastOkGasLimit uint64 + ) + lastOkIndex := len(iter.msgs) // it means undefined + lastNgIndex := len(iter.msgs) // it means undefined +L1: + for true { + var index int + fmt.Printf("for: msgs=%d, lastOk=%d, lastNg=%d\n", len(iter.msgs), lastOkIndex,lastNgIndex) + if lastNgIndex == len(iter.msgs) { + // note that 0 < len(iter.msgs) + index = len(iter.msgs) - 1 + if lastOkIndex == index { break } - err2 := fmt.Errorf("no target address") - logger.Error("failed to construct Multicall3Call", err2) - return nil, err2 + } else if lastOkIndex == len(iter.msgs) { + // note that lastNgIndex != from + index = (lastNgIndex - from) / 2 + } else if lastOkIndex + 1 == lastNgIndex { + break + } else { + index = (lastNgIndex - lastOkIndex) / 2 } + fmt.Printf("for: index=%d\n", index) - newCalls := append(calls, multicall3.Multicall3Call{ - Target: *tx.To(), - CallData: tx.Data(), - }) - multiTx, err := c.multicall3.Aggregate(opts, newCalls) + logger := &log.RelayLogger{Logger: logger.With( + logAttrMsgIndexFrom, from, + logAttrMsgIndexTo, index, + logAttrMsgType, fmt.Sprintf("%T", iter.msgs[index]), + )} + + opts.GasLimit = math.MaxUint64 + opts.NoSend = true + calls := make([]multicall3.Multicall3Call, 0, len(iter.msgs)) + for i := from; i <= index; i++ { + calls = append(calls, multicall3.Multicall3Call{ + Target: *iter.txs[i].To(), + CallData: iter.txs[i].Data(), + }) + } + + multiTx, err := c.multicall3.Aggregate(opts, calls) if err != nil { - if len(calls) > 0 { - break + if index == from { + logger.Error("failed to call Multicall3.Aggregate", err) + return nil, err + } else { + lastNgIndex = index + continue L1 } - logger.Error("failed to call Multicall3.Aggregate", err) - return nil, err } if rawTxData, err := multiTx.MarshalBinary(); err != nil { logger.Error("failed to encode multiTx", err) } else { - saveRawTxData = hex.EncodeToString(rawTxData) - logger.Logger = logger.With(logAttrRawTxData, saveRawTxData) + lastOkRawTxData = hex.EncodeToString(rawTxData) + logger.Logger = logger.With(logAttrRawTxData, lastOkRawTxData) } - txGasLimit, isBreak, err := estimateGas(ctx, c, multiTx, 0 < len(calls), logger) + txGasLimit, isBreak, err := estimateGas(ctx, c, multiTx, from < index, logger) if err != nil { return nil, err } if isBreak { // tx is fail or gas overs limit and 0 < len(calls) - break // send txs with last calls + lastNgIndex = index + break } // this calls is ok. save it and try to include next call - saveTxGasLimit = txGasLimit - calls = newCalls + lastOkIndex = index + lastOkCalls = calls + lastOkGasLimit = txGasLimit iter.Next() } - // now len(calls) > 0 logger = &log.RelayLogger{Logger: logger.With( logAttrMsgIndexFrom, from, - logAttrMsgIndexTo, iter.Cursor(), - logAttrRawTxData, saveRawTxData, + logAttrMsgIndexTo, lastOkIndex, + logAttrRawTxData, lastOkRawTxData, )} - opts.GasLimit = min(saveTxGasLimit, c.Config().MaxGasLimit) + opts.GasLimit = min(lastOkGasLimit, c.Config().MaxGasLimit) opts.NoSend = false - tx, err := c.multicall3.Aggregate(opts, calls) + tx, err := c.multicall3.Aggregate(opts, lastOkCalls) if err != nil { logger.Error("failed to send msg", err) return nil, err } + iter.cursor = lastOkIndex + 1 return tx, nil } +*/ + +func (iter *CallIter) sendMultiTx(ctx context.Context, c *Chain) (*gethtypes.Transaction, error) { + if (iter.End()) { + return nil, nil + } + // now iter.cursor < len(iter.msgs) + + logger := c.GetChainLogger() + + opts, err := c.TxOpts(ctx, true); + if err != nil { + if err != nil { + return nil, err + } + } + + if iter.txs == nil { // create txs at first multicall call + opts.GasLimit = math.MaxUint64 + opts.NoSend = true + txs := make([]gethtypes.Transaction, 0, len(iter.msgs)) + for i := 0; i < len(iter.msgs); i++ { + tx, err := c.SendTx(opts, iter.msgs[i], iter.skipUpdateClientCommitment) + if err != nil { + logger := &log.RelayLogger{Logger: logger.With( + logAttrMsgIndexFrom, i, + logAttrMsgIndexTo, i, + logAttrMsgType, fmt.Sprintf("%T", iter.msgs[i]), + )} + logger.Error("failed to build tx for gas estimation", err) + return nil, err + } + if tx.To() == nil { + err2 := fmt.Errorf("no target address") + logger.Error("failed to construct Multicall3Call", err2) + return nil, err2 + } + txs = append(txs, *tx) + } + iter.txs = txs + } + + type Data struct { + ctx context.Context + c *Chain + iter *CallIter + opts *bind.TransactOpts + lastOkCalls []multicall3.Multicall3Call + lastOkGasLimit uint64 + } + + data := Data { ctx, c, iter, opts, nil, 0 } + count, err := findItems( + len(iter.msgs) - iter.Cursor(), + &data, + func(count int, d *Data) (error) { + from := d.iter.Cursor() + to := from + count + + logger := &log.RelayLogger{Logger: logger.With( + logAttrMsgIndexFrom, from, + logAttrMsgIndexTo, from + count, + logAttrMsgType, fmt.Sprintf("%T", d.iter.msgs[from + count - 1]), + )} + + calls := make([]multicall3.Multicall3Call, 0, count) + for i := from; i < to; i++ { + calls = append(calls, multicall3.Multicall3Call{ + Target: *d.iter.txs[i].To(), + CallData: d.iter.txs[i].Data(), + }) + } + + d.opts.GasLimit = math.MaxUint64 + d.opts.NoSend = true + multiTx, err := c.multicall3.Aggregate(d.opts, calls) + if err != nil { + return err + } + + txGasLimit, err := estimateGas(ctx, c, multiTx, 1 == count, logger) + if err != nil { + return err + } + + d.lastOkGasLimit = txGasLimit + d.lastOkCalls = calls + return nil + }) + + logger = &log.RelayLogger{Logger: logger.With( + logAttrMsgIndexFrom, iter.Cursor(), + logAttrMsgIndexTo, iter.Cursor() + count, + logAttrMsgType, fmt.Sprintf("%T", iter.msgs[iter.Cursor() + count - 1]), + )} + + if err != nil { + logger.Error("failed to multicall", err) + return nil, err + } + + opts.GasLimit = min(data.lastOkGasLimit, c.Config().MaxGasLimit) + opts.NoSend = false + tx, err := c.multicall3.Aggregate(opts, data.lastOkCalls) + if err != nil { + logger.Error("failed to send multicall tx", err) + return nil, err + } + iter.cursor += count + return tx, nil +} + +func findItems[D any]( + size int, + userdata *D, + fnTest func(int, *D) (error), +) (int, error) { + if (size <= 0) { + return 0, fmt.Errorf("empty items") + } + + lastOkCount := 0 + lastNgCount := 0 + + for true { + var count int + + if lastNgCount == 0 { + count = size + if lastOkCount == count { + return lastOkCount, nil + } + } else if lastOkCount == 0 { + if lastNgCount == 1 { + return 0, fmt.Errorf("not found") + } + count = lastNgCount / 2 // note that lastNgCount >= 2 + } else if lastOkCount + 1 == lastNgCount { + return lastOkCount, nil + } else { + count = (lastNgCount + lastOkCount) / 2 + } + + err := fnTest(count, userdata) + if err != nil { + if count == 1 { + return 0, err + } + lastNgCount = count + } else { + lastOkCount = count + } + } + return lastOkCount, nil // not reached +} func estimateGas( ctx context.Context, c *Chain, tx *gethtypes.Transaction, - doBreak bool, // return 0,nil if error or gas is over + doRound bool, // return rounded gas limit when gas limit is over logger *log.RelayLogger, -) (uint64, bool, error) { +) (uint64, error) { + if rawTxData, err := tx.MarshalBinary(); err != nil { + logger.Error("failed to encode tx", err) + } else { + logger.Logger = logger.With(logAttrRawTxData, hex.EncodeToString(rawTxData)) + } + estimatedGas, err := c.client.EstimateGasFromTx(ctx, tx) if err != nil { - if doBreak { - return 0, true, nil - } - if revertReason, rawErrorData, err := c.getRevertReasonFromEstimateGas(err); err != nil { // Raw error data may be available even if revert reason isn't available. logger.Logger = logger.With(logAttrRawErrorData, hex.EncodeToString(rawErrorData)) @@ -634,21 +819,22 @@ func estimateGas( } logger.Error("failed to estimate gas", err) - return 0, false, err + return 0, err } txGasLimit := estimatedGas * c.Config().GasEstimateRate.Numerator / c.Config().GasEstimateRate.Denominator if txGasLimit > c.Config().MaxGasLimit { - if doBreak { - return 0, true, nil + if !doRound { + return 0, fmt.Errorf("estimated gas exceeds max gas limit") } + logger.Warn("estimated gas exceeds max gas limit", logAttrEstimatedGas, txGasLimit, logAttrMaxGasLimit, c.Config().MaxGasLimit, ) - return c.Config().MaxGasLimit, false, nil + return c.Config().MaxGasLimit, nil } - return txGasLimit, false, nil + return txGasLimit, nil }