From 4d594bd99f13d1189b28f416f220dcbbc228ed5f Mon Sep 17 00:00:00 2001 From: Michael Street <5597260+MStreet3@users.noreply.github.com> Date: Tue, 24 Dec 2024 08:19:47 -0500 Subject: [PATCH] refactor(keystone/changeset): move OCR3 selector func and add logs --- .../keystone/changeset/internal/deploy.go | 34 ++--------------- .../keystone/changeset/internal/state.go | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/deployment/keystone/changeset/internal/deploy.go b/deployment/keystone/changeset/internal/deploy.go index ed914eff8f5..acaabd22131 100644 --- a/deployment/keystone/changeset/internal/deploy.go +++ b/deployment/keystone/changeset/internal/deploy.go @@ -31,7 +31,6 @@ import ( capabilities_registry "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/keystone/generated/capabilities_registry_1_1_0" kf "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/keystone/generated/forwarder_1_0_0" - ocr3_capability "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/keystone/generated/ocr3_capability_1_0_0" "github.com/smartcontractkit/chainlink-common/pkg/logger" ) @@ -324,8 +323,9 @@ func ConfigureOCR3Contract(env *deployment.Environment, chainSel uint64, dons [] return fmt.Errorf("failed to get contract set for chain %d", chainSel) } - contract, err := getOCR3Contract(contracts.OCR3, nil) + contract, err := contracts.GetOCR3Contract(nil) if err != nil { + env.Logger.Errorf("failed to get OCR3 contract: %s", err) return fmt.Errorf("failed to get OCR3 contract: %w", err) } @@ -382,8 +382,9 @@ func ConfigureOCR3ContractFromJD(env *deployment.Environment, cfg ConfigureOCR3C return nil, fmt.Errorf("failed to get contract set for chain %d", cfg.ChainSel) } - contract, err := getOCR3Contract(contracts.OCR3, cfg.Address) + contract, err := contracts.GetOCR3Contract(cfg.Address) if err != nil { + env.Logger.Errorf("%sfailed to get OCR3 contract at %s : %s", prefix, cfg.Address, err) return nil, fmt.Errorf("failed to get OCR3 contract: %w", err) } @@ -969,30 +970,3 @@ func configureForwarder(lggr logger.Logger, chain deployment.Chain, contractSet } return opMap, nil } - -// getOCR3Contract returns the OCR3 contract from the contract set. By default, it returns the only -// contract in the set if there is no address specified. If an address is specified, it returns the -// contract with that address. If the address is specified but not found in the contract set, it returns -// an error. -func getOCR3Contract(contracts map[common.Address]*ocr3_capability.OCR3Capability, addr *common.Address) (*ocr3_capability.OCR3Capability, error) { - // Fail if the OCR3 contract address is unspecified and there are multiple OCR3 contracts - if addr == nil && len(contracts) > 1 { - return nil, errors.New("OCR contract address is unspecified") - } - - // Use the first OCR3 contract if the address is unspecified - if addr == nil && len(contracts) == 1 { - // use the first OCR3 contract - for _, c := range contracts { - return c, nil - } - } - - // Select the OCR3 contract by address - if contract, ok := contracts[*addr]; ok { - return contract, nil - } - - // Fail if the OCR3 contract address is specified but not found in the contract set - return nil, fmt.Errorf("OCR3 contract address %s not found in contract set", *addr) -} diff --git a/deployment/keystone/changeset/internal/state.go b/deployment/keystone/changeset/internal/state.go index 4f12c86cd94..3253acaf4e4 100644 --- a/deployment/keystone/changeset/internal/state.go +++ b/deployment/keystone/changeset/internal/state.go @@ -1,6 +1,7 @@ package internal import ( + "errors" "fmt" "github.com/ethereum/go-ethereum/common" @@ -65,6 +66,10 @@ func (cs ContractSet) View() (view.KeystoneChainView, error) { return out, nil } +func (cs ContractSet) GetOCR3Contract(addr *common.Address) (*ocr3_capability.OCR3Capability, error) { + return getOCR3Contract(cs.OCR3, addr) +} + func GetContractSets(lggr logger.Logger, req *GetContractSetsRequest) (*GetContractSetsResponse, error) { resp := &GetContractSetsResponse{ ContractSets: make(map[uint64]ContractSet), @@ -128,3 +133,35 @@ func loadContractSet(lggr logger.Logger, chain deployment.Chain, addresses map[s } return &out, nil } + +// getOCR3Contract returns the OCR3 contract from the contract set. By default, it returns the only +// contract in the set if there is no address specified. If an address is specified, it returns the +// contract with that address. If the address is specified but not found in the contract set, it returns +// an error. +func getOCR3Contract(contracts map[common.Address]*ocr3_capability.OCR3Capability, addr *common.Address) (*ocr3_capability.OCR3Capability, error) { + // Fail if the OCR3 contract address is unspecified and there are multiple OCR3 contracts + if addr == nil && len(contracts) > 1 { + return nil, errors.New("OCR contract address is unspecified") + } + + // Use the first OCR3 contract if the address is unspecified + if addr == nil && len(contracts) == 1 { + // use the first OCR3 contract + for _, c := range contracts { + return c, nil + } + } + + // Select the OCR3 contract by address + if contract, ok := contracts[*addr]; ok { + return contract, nil + } + + addrSet := make([]string, 0, len(contracts)) + for a := range contracts { + addrSet = append(addrSet, a.String()) + } + + // Fail if the OCR3 contract address is specified but not found in the contract set + return nil, fmt.Errorf("OCR3 contract address %s not found in contract set %v", *addr, addrSet) +}