diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader.go index 5a22b76809..f183c32950 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader.go @@ -1,7 +1,6 @@ package ccipdata import ( - "bytes" "context" "fmt" @@ -39,7 +38,7 @@ func (d lbtcPayload) Validate() error { } type LBTCReader interface { - GetLBTCMessageInTx(ctx context.Context, payloadHash []byte, txHash string) ([]byte, error) + GetLBTCMessageInTx(ctx context.Context, payloadHash [32]byte, txHash string) ([]byte, error) Close() error } @@ -83,7 +82,7 @@ func NewLBTCReaderWithCache(lggr logger.Logger, jobID string, transmitter common return r, nil } -func (r *LBTCReaderImpl) GetLBTCMessageInTx(ctx context.Context, payloadHash []byte, txHash string) ([]byte, error) { +func (r *LBTCReaderImpl) GetLBTCMessageInTx(ctx context.Context, payloadHash [32]byte, txHash string) ([]byte, error) { var lpLogs []logpoller.Log // fetch all the lbtc logs for the provided tx hash @@ -113,11 +112,11 @@ func (r *LBTCReaderImpl) GetLBTCMessageInTx(ctx context.Context, payloadHash []b } for _, log := range lpLogs { topics := log.GetTopics() - if currentPayloadHash := topics[3]; bytes.Equal(currentPayloadHash[:], payloadHash) { + if currentPayloadHash := topics[3]; currentPayloadHash == payloadHash { return parseLBTCDepositPayload(log.Data) } } - return nil, fmt.Errorf("payload with hash=%s not found in logs", hexutil.Encode(payloadHash)) + return nil, fmt.Errorf("payload with hash=%s not found in logs", hexutil.Encode(payloadHash[:])) } func parseLBTCDepositPayload(logData []byte) ([]byte, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader_test.go index f102ae31fb..d1135e9b0a 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/lbtc_reader_test.go @@ -51,7 +51,7 @@ func Test_MockLogPoller(t *testing.T) { LogWithPayload(t, 20, payload), }, nil) - data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash[:], "0x0001") + data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash, "0x0001") assert.NoError(t, err) assert.Equal(t, payload, data) }) @@ -67,7 +67,7 @@ func Test_MockLogPoller(t *testing.T) { LogWithPayload(t, 30, []byte("0x2222")), }, nil) - data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash[:], "0x0001") + data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash, "0x0001") assert.NoError(t, err) assert.Equal(t, payload, data) }) @@ -82,7 +82,7 @@ func Test_MockLogPoller(t *testing.T) { LogWithPayload(t, 30, []byte("0x2222")), }, nil) - data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash[:], "0x0001") + data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash, "0x0001") assert.Nil(t, data) assert.Errorf(t, err, "payload with hash=%s not found in logs", payloadHash) }) @@ -94,7 +94,7 @@ func Test_MockLogPoller(t *testing.T) { lp.On("IndexedLogsByTxHash", mock.Anything, reader.eventID, reader.transmitterAddress, mock.Anything). Return([]logpoller.Log{}, nil) - data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash[:], "0x0001") + data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash, "0x0001") assert.Nil(t, data) assert.Errorf(t, err, "payload with hash=%s not found in logs", payloadHash) }) @@ -105,7 +105,7 @@ func Test_MockLogPoller(t *testing.T) { require.NoError(t, err) r, err := NewLBTCReaderWithCache(lggr, "job_1", utils.RandomAddress(), nil, rCache, false) require.NoError(t, err) - data, err := r.GetLBTCMessageInTx(context.Background(), payloadHash[:], "0x0001") + data, err := r.GetLBTCMessageInTx(context.Background(), payloadHash, "0x0001") assert.NoError(t, err) assert.Equal(t, payload, data) }) @@ -146,7 +146,7 @@ func Test_SimulatedLogPoller_FoundMultiple(t *testing.T) { reader, err := NewLBTCReader(lggr, "job_1", transmitter, lp, true) require.NoError(t, err) - data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash[:], common.Hash{}.Hex()) + data, err := reader.GetLBTCMessageInTx(context.Background(), payloadHash, common.Hash{}.Hex()) assert.NoError(t, err) assert.Equal(t, payload, data) } diff --git a/core/services/ocr2/plugins/ccip/tokendata/lbtc/lbtc.go b/core/services/ocr2/plugins/ccip/tokendata/lbtc/lbtc.go index 74c61fae82..01dbb3bc6c 100644 --- a/core/services/ocr2/plugins/ccip/tokendata/lbtc/lbtc.go +++ b/core/services/ocr2/plugins/ccip/tokendata/lbtc/lbtc.go @@ -1,7 +1,6 @@ package lbtc import ( - "bytes" "context" "crypto/sha256" "fmt" @@ -262,17 +261,23 @@ func (s *TokenDataReader) getLBTCPayloadAndHash(ctx context.Context, msg cciptyp if err != nil { return nil, [32]byte{}, err } - payloadHash := decodedSourceTokenData.ExtraData - if len(payloadHash) != 32 { - s.lggr.Warnw("SourceTokenData.extraData is not 32 bytes. LBTC Attestation probably disabled onchain", "payloadHash", payloadHash) + destTokenData := decodedSourceTokenData.ExtraData + var payloadHash [32]byte + if len(destTokenData) != 32 { + payloadHash = sha256.Sum256(destTokenData) + s.lggr.Warnw("SourceTokenData.extraData size is not 32. It could be a LBTC payload, not LBTC payload sha256. "+ + "Probably this message is sent when LBTC attestation was disabled onchain. Will use sha256 from this value", + "destTokenData", destTokenData, "newPayloadHash", payloadHash) + } else { + payloadHash = [32]byte(destTokenData) } - payload, err := s.lbtcReader.GetLBTCMessageInTx(ctx, payloadHash, msg.TxHash) + actualPayload, err := s.lbtcReader.GetLBTCMessageInTx(ctx, payloadHash, msg.TxHash) if err != nil { return nil, [32]byte{}, err } - actualPayloadHash := sha256.Sum256(payload) - if bytes.Equal(actualPayloadHash[:], payloadHash) { - return payload, [32]byte(payloadHash), nil + actualPayloadHash := sha256.Sum256(actualPayload) + if actualPayloadHash == payloadHash { + return actualPayload, payloadHash, nil } return nil, [32]byte{}, fmt.Errorf("payload hash mismatch: expected %x, got %x", payloadHash, actualPayloadHash) }