Skip to content

Commit

Permalink
Merge branch 'main' into solana/specify-writable-pool-accounts
Browse files Browse the repository at this point in the history
  • Loading branch information
aalu1418 authored Jan 9, 2025
2 parents a591676 + ff9d86b commit 3a04b12
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
34 changes: 28 additions & 6 deletions pkg/contractreader/extended.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync"
"time"

"github.com/smartcontractkit/chainlink-ccip/pkg/consts"

"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/types"
clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
Expand Down Expand Up @@ -95,7 +97,9 @@ type ExtendedBoundContract struct {
type extendedContractReader struct {
reader ContractReaderFacade
contractBindingsByName map[string][]ExtendedBoundContract
mu *sync.RWMutex
// contract names that allow multiple bindings
multiBindAllowed map[string]bool
mu *sync.RWMutex
}

func NewExtendedContractReader(baseContractReader ContractReaderFacade) Extended {
Expand All @@ -106,7 +110,10 @@ func NewExtendedContractReader(baseContractReader ContractReaderFacade) Extended
return &extendedContractReader{
reader: baseContractReader,
contractBindingsByName: make(map[string][]ExtendedBoundContract),
mu: &sync.RWMutex{},
// so far this is the only contract that allows multiple bindings
// if more contracts are added, this should be moved to a config
multiBindAllowed: map[string]bool{consts.ContractNamePriceAggregator: true},
mu: &sync.RWMutex{},
}
}

Expand Down Expand Up @@ -261,10 +268,25 @@ func (e *extendedContractReader) Bind(ctx context.Context, allBindings []types.B
e.mu.Lock()
defer e.mu.Unlock()
for _, binding := range validBindings {
e.contractBindingsByName[binding.Name] = append(e.contractBindingsByName[binding.Name], ExtendedBoundContract{
BoundAt: time.Now(),
Binding: binding,
})
if e.multiBindAllowed[binding.Name] {
e.contractBindingsByName[binding.Name] = append(e.contractBindingsByName[binding.Name], ExtendedBoundContract{
BoundAt: time.Now(),
Binding: binding,
})
} else {
if len(e.contractBindingsByName[binding.Name]) > 0 {
// Unbind the previous binding
err := e.reader.Unbind(ctx, []types.BoundContract{e.contractBindingsByName[binding.Name][0].Binding})
if err != nil {
return fmt.Errorf("failed to unbind previous binding: %w", err)
}
}
// Override the previous binding
e.contractBindingsByName[binding.Name] = []ExtendedBoundContract{{
BoundAt: time.Now(),
Binding: binding,
}}
}
}

return nil
Expand Down
38 changes: 38 additions & 0 deletions pkg/contractreader/extended_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"testing"

"github.com/smartcontractkit/chainlink-ccip/pkg/consts"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand All @@ -26,6 +28,42 @@ func TestExtendedContractReader(t *testing.T) {
bindings := extCr.GetBindings(contractName)
assert.Len(t, bindings, 0)

cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Unbind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x124"}}).Return(nil)
cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x125"}}).Return(fmt.Errorf("some err"))

err := extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x123"}})
assert.NoError(t, err)

// ignored since 0x123 already exists
err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x123"}})
assert.NoError(t, err)

err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x124"}})
assert.NoError(t, err)

// Bind fails
err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x125"}})
assert.Error(t, err)

bindings = extCr.GetBindings(contractName)
assert.Len(t, bindings, 1)
assert.Equal(t, "0x124", bindings[0].Binding.Address)
}

func TestExtendedContractReader_AllowMultiBindingForAggregator(t *testing.T) {
const contractName = consts.ContractNamePriceAggregator
cr := chainreadermocks.NewMockContractReaderFacade(t)
extCr := contractreader.NewExtendedContractReader(cr)

bindings := extCr.GetBindings(contractName)
assert.Len(t, bindings, 0)

cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Bind", context.Background(),
Expand Down

0 comments on commit 3a04b12

Please sign in to comment.