diff --git a/RELEASES.md b/RELEASES.md index 8256a2666c3e..f0ab8320c33e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,12 @@ # Release Notes +## Pending Release + +### APIs + +- Added `validationID` to `platform.getL1Validator` outputs +- Added L1 validators support to `platform.getCurrentValidators` + ## [v1.11.13](https://github.com/ava-labs/avalanchego/releases/tag/v1.11.13) This version is backwards compatible to [v1.11.0](https://github.com/ava-labs/avalanchego/releases/tag/v1.11.0). It is optional, but encouraged. diff --git a/vms/platformvm/api/static_service.go b/vms/platformvm/api/static_service.go index 326e5fb5e8a8..9f667ceb7991 100644 --- a/vms/platformvm/api/static_service.go +++ b/vms/platformvm/api/static_service.go @@ -141,14 +141,6 @@ type GenesisPermissionlessValidator struct { Signer *signer.ProofOfPossession `json:"signer,omitempty"` } -// PermissionedValidator is the repr. of a permissioned validator sent over APIs. -type PermissionedValidator struct { - Staker - // The owner the staking reward, if applicable, will go to - Connected *bool `json:"connected,omitempty"` - Uptime *json.Float32 `json:"uptime,omitempty"` -} - // PrimaryDelegator is the repr. of a primary network delegator sent over APIs. type PrimaryDelegator struct { Staker diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index c5a9df6c21a5..bb845d9e16ad 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -787,26 +787,105 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato reply.Validators = []interface{}{} - // Validator's node ID as string --> Delegators to them - vdrToDelegators := map[ids.NodeID][]platformapi.PrimaryDelegator{} - // Create set of nodeIDs nodeIDs := set.Of(args.NodeIDs...) s.vm.ctx.Lock.Lock() defer s.vm.ctx.Lock.Unlock() + // if the subnetID is the primary network, return the primary validators + if args.SubnetID == constants.PrimaryNetworkID { + primaryValidators, err := s.getPrimaryOrSubnetValidators(constants.PrimaryNetworkID, nodeIDs) + if err != nil { + return err + } + + reply.Validators = primaryValidators + return nil + } + + // Check if subnet is L1 + _, err := s.vm.state.GetSubnetToL1Conversion(args.SubnetID) + if err == database.ErrNotFound { + // Subnet is not L1, get validators for the subnet + subnetValidators, err := s.getPrimaryOrSubnetValidators(args.SubnetID, nodeIDs) + if err != nil { + return err + } + reply.Validators = subnetValidators + return nil + } + if err != nil { + return err + } + // Subnet is L1, get validators for L1 + l1Validators, err := s.getL1Validators(args.SubnetID, nodeIDs) + if err != nil { + return err + } + reply.Validators = l1Validators + return nil +} + +func (s *Service) getL1Validators(subnetID ids.ID, nodeIDs set.Set[ids.NodeID]) ([]interface{}, error) { + validators := []interface{}{} + baseStakers, l1Validators, _, err := s.vm.state.GetCurrentValidators(subnetID) + if err != nil { + return nil, err + } + + for _, staker := range baseStakers { + nodeID := staker.NodeID + if nodeIDs.Len() != 0 && !nodeIDs.Contains(nodeID) { + continue + } + weight := avajson.Uint64(staker.Weight) + apiStaker := platformapi.Staker{ + TxID: staker.TxID, + StartTime: avajson.Uint64(staker.StartTime.Unix()), + EndTime: avajson.Uint64(staker.EndTime.Unix()), + Weight: weight, + NodeID: nodeID, + } + validators = append(validators, apiStaker) + } + + for _, l1Validator := range l1Validators { + nodeID := l1Validator.NodeID + if nodeIDs.Len() != 0 && !nodeIDs.Contains(nodeID) { + continue + } + + apiL1Vdr, err := s.convertL1ValidatorToAPI(l1Validator) + if err != nil { + return nil, err + } + + validators = append(validators, apiL1Vdr) + } + + return validators, nil +} + +func (s *Service) getPrimaryOrSubnetValidators(subnetID ids.ID, nodeIDs set.Set[ids.NodeID]) ([]interface{}, error) { numNodeIDs := nodeIDs.Len() + targetStakers := make([]*state.Staker, 0, numNodeIDs) + + // Validator's node ID as string --> Delegators to them + vdrToDelegators := map[ids.NodeID][]platformapi.PrimaryDelegator{} + + validators := []interface{}{} + if numNodeIDs == 0 { // Include all nodes currentStakerIterator, err := s.vm.state.GetCurrentStakerIterator() if err != nil { - return err + return nil, err } // TODO: avoid iterating over delegators here. for currentStakerIterator.Next() { staker := currentStakerIterator.Value() - if args.SubnetID != staker.SubnetID { + if subnetID != staker.SubnetID { continue } targetStakers = append(targetStakers, staker) @@ -814,21 +893,21 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato currentStakerIterator.Release() } else { for nodeID := range nodeIDs { - staker, err := s.vm.state.GetCurrentValidator(args.SubnetID, nodeID) + staker, err := s.vm.state.GetCurrentValidator(subnetID, nodeID) switch err { case nil: case database.ErrNotFound: // nothing to do, continue continue default: - return err + return nil, err } targetStakers = append(targetStakers, staker) // TODO: avoid iterating over delegators when numNodeIDs > 1. - delegatorsIt, err := s.vm.state.GetCurrentDelegatorIterator(args.SubnetID, nodeID) + delegatorsIt, err := s.vm.state.GetCurrentDelegatorIterator(subnetID, nodeID) if err != nil { - return err + return nil, err } for delegatorsIt.Next() { staker := delegatorsIt.Value() @@ -853,7 +932,7 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato delegateeReward, err := s.vm.state.GetDelegateeReward(currentStaker.SubnetID, currentStaker.NodeID) if err != nil { - return err + return nil, err } jsonDelegateeReward := avajson.Uint64(delegateeReward) @@ -861,7 +940,7 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato case txs.PrimaryNetworkValidatorCurrentPriority, txs.SubnetPermissionlessValidatorCurrentPriority: attr, err := s.loadStakerTxAttributes(currentStaker.TxID) if err != nil { - return err + return nil, err } shares := attr.shares @@ -870,16 +949,16 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato uptime *avajson.Float32 connected *bool ) - if args.SubnetID == constants.PrimaryNetworkID { + if subnetID == constants.PrimaryNetworkID { rawUptime, err := s.vm.uptimeManager.CalculateUptimePercentFrom(currentStaker.NodeID, currentStaker.StartTime) if err != nil { - return err + return nil, err } // Transform this to a percentage (0-100) to make it consistent // with observedUptime in info.peers API currentUptime := avajson.Float32(rawUptime * 100) if err != nil { - return err + return nil, err } isConnected := s.vm.uptimeManager.IsConnected(currentStaker.NodeID) connected = &isConnected @@ -894,14 +973,14 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato if ok { validationRewardOwner, err = s.getAPIOwner(validationOwner) if err != nil { - return err + return nil, err } } delegationOwner, ok := attr.delegationRewardsOwner.(*secp256k1fx.OutputOwners) if ok { delegationRewardOwner, err = s.getAPIOwner(delegationOwner) if err != nil { - return err + return nil, err } } @@ -917,7 +996,7 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato DelegationFee: delegationFee, Signer: attr.proofOfPossession, } - reply.Validators = append(reply.Validators, vdr) + validators = append(validators, vdr) case txs.PrimaryNetworkDelegatorCurrentPriority, txs.SubnetPermissionlessDelegatorCurrentPriority: var rewardOwner *platformapi.Owner @@ -926,13 +1005,13 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato if numNodeIDs == 1 { attr, err := s.loadStakerTxAttributes(currentStaker.TxID) if err != nil { - return err + return nil, err } owner, ok := attr.rewardsOwner.(*secp256k1fx.OutputOwners) if ok { rewardOwner, err = s.getAPIOwner(owner) if err != nil { - return err + return nil, err } } } @@ -945,17 +1024,15 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato vdrToDelegators[delegator.NodeID] = append(vdrToDelegators[delegator.NodeID], delegator) case txs.SubnetPermissionedValidatorCurrentPriority: - reply.Validators = append(reply.Validators, platformapi.PermissionedValidator{ - Staker: apiStaker, - }) + validators = append(validators, apiStaker) default: - return fmt.Errorf("unexpected staker priority %d", currentStaker.Priority) + return nil, fmt.Errorf("unexpected staker priority %d", currentStaker.Priority) } } // handle delegators' information - for i, vdrIntf := range reply.Validators { + for i, vdrIntf := range validators { vdr, ok := vdrIntf.(platformapi.PermissionlessValidator) if !ok { continue @@ -979,19 +1056,19 @@ func (s *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidato // queried a specific validator, load all of its delegators vdr.Delegators = &delegators } - reply.Validators[i] = vdr + validators[i] = vdr } - return nil + return validators, nil } type GetL1ValidatorArgs struct { ValidationID ids.ID `json:"validationID"` } -type GetL1ValidatorReply struct { - SubnetID ids.ID `json:"subnetID"` - NodeID ids.NodeID `json:"nodeID"` +type APIL1Validator struct { + ValidationID ids.ID `json:"validationID"` + NodeID ids.NodeID `json:"nodeID"` // PublicKey is the compressed BLS public key of the validator PublicKey types.JSONByteSlice `json:"publicKey"` RemainingBalanceOwner platformapi.Owner `json:"remainingBalanceOwner"` @@ -1003,6 +1080,11 @@ type GetL1ValidatorReply struct { // the continuous fee, according to the last accepted state. If the // validator is inactive, the balance will be 0. Balance avajson.Uint64 `json:"balance"` +} + +type GetL1ValidatorReply struct { + SubnetID ids.ID `json:"subnetID"` + *APIL1Validator // Height is the height of the last accepted block Height avajson.Uint64 `json:"height"` } @@ -1023,53 +1105,66 @@ func (s *Service) GetL1Validator(r *http.Request, args *GetL1ValidatorArgs, repl return fmt.Errorf("fetching L1 validator %q failed: %w", args.ValidationID, err) } + ctx := r.Context() + height, err := s.vm.GetCurrentHeight(ctx) + if err != nil { + return fmt.Errorf("failed getting current height: %w", err) + } + apiVdr, err := s.convertL1ValidatorToAPI(l1Validator) + if err != nil { + return fmt.Errorf("failed converting L1 validator to API: %w", err) + } + + reply.SubnetID = l1Validator.SubnetID + reply.APIL1Validator = &apiVdr + reply.Height = avajson.Uint64(height) + + return nil +} + +func (s *Service) convertL1ValidatorToAPI(vdr state.L1Validator) (APIL1Validator, error) { var remainingBalanceOwner message.PChainOwner - if _, err := txs.Codec.Unmarshal(l1Validator.RemainingBalanceOwner, &remainingBalanceOwner); err != nil { - return fmt.Errorf("failed unmarshalling remaining balance owner: %w", err) + if _, err := txs.Codec.Unmarshal(vdr.RemainingBalanceOwner, &remainingBalanceOwner); err != nil { + return APIL1Validator{}, fmt.Errorf("failed unmarshalling remaining balance owner: %w", err) } remainingBalanceAPIOwner, err := s.getAPIOwner(&secp256k1fx.OutputOwners{ Threshold: remainingBalanceOwner.Threshold, Addrs: remainingBalanceOwner.Addresses, }) if err != nil { - return fmt.Errorf("failed formatting remaining balance owner: %w", err) + return APIL1Validator{}, fmt.Errorf("failed formatting remaining balance owner: %w", err) } var deactivationOwner message.PChainOwner - if _, err := txs.Codec.Unmarshal(l1Validator.DeactivationOwner, &deactivationOwner); err != nil { - return fmt.Errorf("failed unmarshalling deactivation owner: %w", err) + if _, err := txs.Codec.Unmarshal(vdr.DeactivationOwner, &deactivationOwner); err != nil { + return APIL1Validator{}, fmt.Errorf("failed unmarshalling deactivation owner: %w", err) } deactivationAPIOwner, err := s.getAPIOwner(&secp256k1fx.OutputOwners{ Threshold: deactivationOwner.Threshold, Addrs: deactivationOwner.Addresses, }) if err != nil { - return fmt.Errorf("failed formatting deactivation owner: %w", err) + return APIL1Validator{}, fmt.Errorf("failed formatting deactivation owner: %w", err) } - ctx := r.Context() - height, err := s.vm.GetCurrentHeight(ctx) - if err != nil { - return fmt.Errorf("failed getting current height: %w", err) - } + var apiVdr APIL1Validator - reply.SubnetID = l1Validator.SubnetID - reply.NodeID = l1Validator.NodeID - reply.PublicKey = bls.PublicKeyToCompressedBytes( - bls.PublicKeyFromValidUncompressedBytes(l1Validator.PublicKey), + apiVdr.ValidationID = vdr.ValidationID + apiVdr.NodeID = vdr.NodeID + apiVdr.PublicKey = bls.PublicKeyToCompressedBytes( + bls.PublicKeyFromValidUncompressedBytes(vdr.PublicKey), ) - reply.RemainingBalanceOwner = *remainingBalanceAPIOwner - reply.DeactivationOwner = *deactivationAPIOwner - reply.StartTime = avajson.Uint64(l1Validator.StartTime) - reply.Weight = avajson.Uint64(l1Validator.Weight) - reply.MinNonce = avajson.Uint64(l1Validator.MinNonce) - if l1Validator.EndAccumulatedFee != 0 { + apiVdr.RemainingBalanceOwner = *remainingBalanceAPIOwner + apiVdr.DeactivationOwner = *deactivationAPIOwner + apiVdr.StartTime = avajson.Uint64(vdr.StartTime) + apiVdr.Weight = avajson.Uint64(vdr.Weight) + apiVdr.MinNonce = avajson.Uint64(vdr.MinNonce) + if vdr.EndAccumulatedFee != 0 { accruedFees := s.vm.state.GetAccruedFees() - reply.Balance = avajson.Uint64(l1Validator.EndAccumulatedFee - accruedFees) + apiVdr.Balance = avajson.Uint64(vdr.EndAccumulatedFee - accruedFees) } - reply.Height = avajson.Uint64(height) - return nil + return apiVdr, nil } // GetCurrentSupplyArgs are the arguments for calling GetCurrentSupply diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 361623e27e78..d5185bce134d 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "golang.org/x/exp/maps" "github.com/ava-labs/avalanchego/api" "github.com/ava-labs/avalanchego/api/keystore" @@ -46,7 +47,9 @@ import ( "github.com/ava-labs/avalanchego/vms/platformvm/state" "github.com/ava-labs/avalanchego/vms/platformvm/status" "github.com/ava-labs/avalanchego/vms/platformvm/txs" + "github.com/ava-labs/avalanchego/vms/platformvm/warp/message" "github.com/ava-labs/avalanchego/vms/secp256k1fx" + "github.com/ava-labs/avalanchego/vms/types" "github.com/ava-labs/avalanchego/wallet/subnet/primary/common" avajson "github.com/ava-labs/avalanchego/utils/json" @@ -1398,3 +1401,263 @@ func FuzzGetFeeState(f *testing.F) { require.Equal(expectedReply, reply) }) } + +func TestGetCurrentValidatorsForL1(t *testing.T) { + subnetID := ids.GenerateTestID() + + sk, err := bls.NewSecretKey() + require.NoError(t, err) + pk := bls.PublicFromSecretKey(sk) + pkBytes := bls.PublicKeyToUncompressedBytes(pk) + + otherSK, err := bls.NewSecretKey() + require.NoError(t, err) + otherPK := bls.PublicFromSecretKey(otherSK) + otherPKBytes := bls.PublicKeyToUncompressedBytes(otherPK) + now := time.Now() + + tests := []struct { + name string + initial []*state.Staker + l1Validators []state.L1Validator + }{ + { + name: "empty noop", + }, + { + name: "initial stakers", + initial: []*state.Staker{ + { + TxID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: pk, + Weight: 1, + StartTime: now, + }, + { + TxID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: otherPK, + Weight: 1, + StartTime: now.Add(1 * time.Second), + }, + }, + }, + { + name: "L1 validators", + l1Validators: []state.L1Validator{ + { + ValidationID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + StartTime: uint64(now.Unix()), + PublicKey: pkBytes, + Weight: 1, + }, + { + ValidationID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: otherPKBytes, + StartTime: uint64(now.Unix()) + 1, + Weight: 1, + }, + }, + }, + { + name: "initial stakers and L1 validators mixed", + initial: []*state.Staker{ + { + TxID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: pk, + Weight: 123123, + StartTime: now, + }, + { + TxID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: otherPK, + Weight: 0, + StartTime: now.Add(2 * time.Second), + }, + }, + l1Validators: []state.L1Validator{ + { + ValidationID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + StartTime: uint64(now.Unix()), + PublicKey: pkBytes, + Weight: 1, + EndAccumulatedFee: 1, + MinNonce: 2, + }, + { + ValidationID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: pkBytes, + StartTime: uint64(now.Unix()) + 2, + Weight: 1, + EndAccumulatedFee: 0, + }, + { + ValidationID: ids.GenerateTestID(), + SubnetID: subnetID, + NodeID: ids.GenerateTestNodeID(), + PublicKey: otherPKBytes, + StartTime: uint64(now.Unix()) + 3, + Weight: 0, + EndAccumulatedFee: 1, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + service, _ := defaultService(t, upgradetest.Latest) + service.vm.ctx.Lock.Lock() + stakersByTxID := make(map[ids.ID]*state.Staker) + for _, staker := range test.initial { + primaryStaker := &state.Staker{ + TxID: ids.GenerateTestID(), + SubnetID: constants.PrimaryNetworkID, + NodeID: staker.NodeID, + PublicKey: staker.PublicKey, + Weight: 5, + // start primary network staker 1 second before the subnet staker + StartTime: staker.StartTime.Add(-1 * time.Second), + Priority: txs.PrimaryNetworkValidatorCurrentPriority, + } + require.NoError(service.vm.state.PutCurrentValidator(primaryStaker)) + staker.Priority = txs.SubnetPermissionedValidatorCurrentPriority + require.NoError(service.vm.state.PutCurrentValidator(staker)) + + stakersByTxID[staker.TxID] = staker + } + + l1ValidatorsByVID := make(map[ids.ID]state.L1Validator) + if len(test.l1Validators) != 0 { + service.vm.state.SetSubnetToL1Conversion(subnetID, + state.SubnetToL1Conversion{ + ConversionID: ids.GenerateTestID(), + ChainID: ids.GenerateTestID(), + Addr: []byte{'a', 'd', 'd', 'r'}, + }) + } + + for _, l1Validator := range test.l1Validators { + deactivationOwner := message.PChainOwner{ + Threshold: 1, + Addresses: []ids.ShortID{ + ids.GenerateTestShortID(), + }, + } + + remaningBalanceOwner := message.PChainOwner{ + Threshold: 1, + Addresses: []ids.ShortID{ + ids.GenerateTestShortID(), + }, + } + + remainingBalanceOwnerBytes, err := txs.Codec.Marshal(txs.CodecVersion, remaningBalanceOwner) + require.NoError(err) + deactivationOwnerBytes, err := txs.Codec.Marshal(txs.CodecVersion, deactivationOwner) + require.NoError(err) + l1Validator.RemainingBalanceOwner = remainingBalanceOwnerBytes + l1Validator.DeactivationOwner = deactivationOwnerBytes + + require.NoError(service.vm.state.PutL1Validator(l1Validator)) + + if l1Validator.Weight == 0 { + continue + } + l1ValidatorsByVID[l1Validator.ValidationID] = l1Validator + } + + service.vm.state.SetHeight(0) + require.NoError(service.vm.state.Commit()) + service.vm.ctx.Lock.Unlock() + + testValidator := func(vdr interface{}) ids.NodeID { + switch v := vdr.(type) { + case pchainapi.Staker: + staker, exists := stakersByTxID[v.TxID] + require.True(exists, "unexpected validator: %s", vdr) + require.Equal(staker.NodeID, v.NodeID) + require.Equal(avajson.Uint64(staker.Weight), v.Weight) + require.Equal(staker.StartTime.Unix(), int64(v.StartTime)) + return v.NodeID + case APIL1Validator: + validator, exists := l1ValidatorsByVID[v.ValidationID] + require.True(exists, "unexpected validator: %s", vdr) + require.Equal(validator.NodeID, v.NodeID) + require.Equal(avajson.Uint64(validator.Weight), v.Weight) + require.Equal(validator.StartTime, uint64(v.StartTime)) + accruedFees := service.vm.state.GetAccruedFees() + require.Equal(avajson.Uint64(validator.EndAccumulatedFee-accruedFees), v.Balance) + require.Equal(avajson.Uint64(validator.MinNonce), v.MinNonce) + require.Equal( + types.JSONByteSlice(bls.PublicKeyToCompressedBytes(bls.PublicKeyFromValidUncompressedBytes(validator.PublicKey))), + v.PublicKey) + var expectedRemainingBalanceOwner message.PChainOwner + _, err := txs.Codec.Unmarshal(validator.RemainingBalanceOwner, &expectedRemainingBalanceOwner) + require.NoError(err) + formattedRemainingBalanceOwner, err := service.addrManager.FormatLocalAddress(expectedRemainingBalanceOwner.Addresses[0]) + require.NoError(err) + require.Equal(formattedRemainingBalanceOwner, v.RemainingBalanceOwner.Addresses[0]) + require.Equal(avajson.Uint32(expectedRemainingBalanceOwner.Threshold), v.RemainingBalanceOwner.Threshold) + var expectedDeactivationOwner message.PChainOwner + _, err = txs.Codec.Unmarshal(validator.DeactivationOwner, &expectedDeactivationOwner) + require.NoError(err) + formattedDeactivationOwner, err := service.addrManager.FormatLocalAddress(expectedDeactivationOwner.Addresses[0]) + require.NoError(err) + require.Equal(formattedDeactivationOwner, v.DeactivationOwner.Addresses[0]) + require.Equal(avajson.Uint32(expectedDeactivationOwner.Threshold), v.DeactivationOwner.Threshold) + return v.NodeID + default: + require.Fail("unexpected validator type: %T", vdr) + return ids.NodeID{} + } + } + + args := GetCurrentValidatorsArgs{ + SubnetID: subnetID, + } + reply := GetCurrentValidatorsReply{} + require.NoError(service.GetCurrentValidators(nil, &args, &reply)) + require.Len(reply.Validators, len(stakersByTxID)+len(l1ValidatorsByVID)) + for _, vdr := range reply.Validators { + testValidator(vdr) + } + + // Test with a specific node ID + var nodeIDs []ids.NodeID + if len(stakersByTxID) > 0 { + // pick the first staker + nodeIDs = append(nodeIDs, maps.Values(stakersByTxID)[0].NodeID) + } + if len(l1ValidatorsByVID) > 0 { + nodeIDs = append(nodeIDs, maps.Values(l1ValidatorsByVID)[0].NodeID) + } + + args.NodeIDs = nodeIDs + reply = GetCurrentValidatorsReply{} + require.NoError(service.GetCurrentValidators(nil, &args, &reply)) + require.Len(reply.Validators, len(nodeIDs)) + for i, vdr := range reply.Validators { + nodeID := testValidator(vdr) + require.Equal(args.NodeIDs[i], nodeID) + } + }) + } +} diff --git a/vms/platformvm/state/l1_validator.go b/vms/platformvm/state/l1_validator.go index 3f728f993ce4..1871e7cf12ef 100644 --- a/vms/platformvm/state/l1_validator.go +++ b/vms/platformvm/state/l1_validator.go @@ -161,33 +161,33 @@ func (v L1Validator) isDeleted() bool { return v.Weight == 0 } -func (v L1Validator) isActive() bool { +func (v L1Validator) IsActive() bool { return v.Weight != 0 && v.EndAccumulatedFee != 0 } func (v L1Validator) effectiveValidationID() ids.ID { - if v.isActive() { + if v.IsActive() { return v.ValidationID } return ids.Empty } func (v L1Validator) effectiveNodeID() ids.NodeID { - if v.isActive() { + if v.IsActive() { return v.NodeID } return ids.EmptyNodeID } func (v L1Validator) effectivePublicKey() *bls.PublicKey { - if v.isActive() { + if v.IsActive() { return bls.PublicKeyFromValidUncompressedBytes(v.PublicKey) } return nil } func (v L1Validator) effectivePublicKeyBytes() []byte { - if v.isActive() { + if v.IsActive() { return v.PublicKey } return nil @@ -298,7 +298,7 @@ func (d *l1ValidatorsDiff) putL1Validator(state Chain, l1Validator L1Validator) var ( prevWeight uint64 prevActive bool - newActive = l1Validator.isActive() + newActive = l1Validator.IsActive() ) switch priorL1Validator, err := state.GetL1Validator(l1Validator.ValidationID); err { case nil: @@ -307,7 +307,7 @@ func (d *l1ValidatorsDiff) putL1Validator(state Chain, l1Validator L1Validator) } prevWeight = priorL1Validator.Weight - prevActive = priorL1Validator.isActive() + prevActive = priorL1Validator.IsActive() case database.ErrNotFound: // Verify that there is not a legacy subnet validator with the same // subnetID+nodeID as this L1 validator. @@ -365,7 +365,7 @@ func (d *l1ValidatorsDiff) putL1Validator(state Chain, l1Validator L1Validator) nodeID: l1Validator.NodeID, } d.modifiedHasNodeIDs[subnetIDNodeID] = !l1Validator.isDeleted() - if l1Validator.isActive() { + if l1Validator.IsActive() { d.active.ReplaceOrInsert(l1Validator) } return nil diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index 184b2c9563f3..d3d9544fb755 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -424,20 +424,21 @@ func (mr *MockStateMockRecorder) GetCurrentValidator(subnetID, nodeID any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentValidator", reflect.TypeOf((*MockState)(nil).GetCurrentValidator), subnetID, nodeID) } -// GetCurrentValidatorSet mocks base method. -func (m *MockState) GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) { +// GetCurrentValidators mocks base method. +func (m *MockState) GetCurrentValidators(subnetID ids.ID) ([]*Staker, []L1Validator, uint64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCurrentValidatorSet", ctx, subnetID) - ret0, _ := ret[0].(map[ids.ID]*validators.GetCurrentValidatorOutput) - ret1, _ := ret[1].(uint64) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret := m.ctrl.Call(m, "GetCurrentValidators", subnetID) + ret0, _ := ret[0].([]*Staker) + ret1, _ := ret[1].([]L1Validator) + ret2, _ := ret[2].(uint64) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 } -// GetCurrentValidatorSet indicates an expected call of GetCurrentValidatorSet. -func (mr *MockStateMockRecorder) GetCurrentValidatorSet(ctx, subnetID any) *gomock.Call { +// GetCurrentValidators indicates an expected call of GetCurrentValidators. +func (mr *MockStateMockRecorder) GetCurrentValidators(subnetID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentValidatorSet", reflect.TypeOf((*MockState)(nil).GetCurrentValidatorSet), ctx, subnetID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentValidators", reflect.TypeOf((*MockState)(nil).GetCurrentValidators), subnetID) } // GetDelegateeReward mocks base method. diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 95d973b2e4ed..f9598b69812f 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -214,7 +214,10 @@ type State interface { SetHeight(height uint64) - GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) + // GetCurrentValidators returns legacy stakers, L1 validators for the given subnetID along with current P-chain height. + // Note: This is most convenient to fetch validators of an L1. Even though it can still return + // stakers of a subnet, it is not recommended to use this method to fetch stakers of a subnet. + GetCurrentValidators(subnetID ids.ID) ([]*Staker, []L1Validator, uint64, error) // Discard uncommitted changes to the database. Abort() @@ -834,62 +837,34 @@ func (s *state) DeleteExpiry(entry ExpiryEntry) { s.expiryDiff.DeleteExpiry(entry) } -func (s *state) GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) { - result := make(map[ids.ID]*validators.GetCurrentValidatorOutput) +func (s *state) GetCurrentValidators(subnetID ids.ID) ([]*Staker, []L1Validator, uint64, error) { // First add the current validators (non-L1) - for _, staker := range s.currentStakers.validators[subnetID] { - if err := ctx.Err(); err != nil { - return nil, 0, err - } - validator := staker.validator - result[validator.TxID] = &validators.GetCurrentValidatorOutput{ - ValidationID: validator.TxID, - NodeID: validator.NodeID, - PublicKey: validator.PublicKey, - Weight: validator.Weight, - StartTime: uint64(validator.StartTime.Unix()), - MinNonce: 0, - IsActive: true, - IsL1Validator: false, + var legacyStakers []*Staker + if legacyBaseStakers, ok := s.currentStakers.validators[subnetID]; ok { + for _, staker := range legacyBaseStakers { + legacyStakers = append(legacyStakers, staker.validator) } } - // Then iterate over subnetIDNodeID DB and add the L1 validators (if any) - // TODO: consider optimizing this to avoid hitting the subnetIDNodeIDDB and read from actives lookup - // if all validators are active (inactive weight is 0) + // Then iterate over subnetIDNodeID DB and add the L1 validators + var l1Validators []L1Validator validationIDIter := s.subnetIDNodeIDDB.NewIteratorWithPrefix( subnetID[:], ) defer validationIDIter.Release() - for validationIDIter.Next() { - if err := ctx.Err(); err != nil { - return nil, 0, err - } - validationID, err := ids.ToID(validationIDIter.Value()) if err != nil { - return nil, 0, fmt.Errorf("failed to parse validation ID: %w", err) + return nil, nil, 0, fmt.Errorf("failed to parse validation ID: %w", err) } - vdr, err := s.GetL1Validator(validationID) if err != nil { - return nil, 0, fmt.Errorf("failed to get validator: %w", err) - } - - result[validationID] = &validators.GetCurrentValidatorOutput{ - ValidationID: validationID, - NodeID: vdr.NodeID, - PublicKey: bls.PublicKeyFromValidUncompressedBytes(vdr.PublicKey), - Weight: vdr.Weight, - StartTime: vdr.StartTime, - IsActive: vdr.isActive(), - MinNonce: vdr.MinNonce, - IsL1Validator: true, + return nil, nil, 0, fmt.Errorf("failed to get validator: %w", err) } + l1Validators = append(l1Validators, vdr) } - return result, s.currentHeight, nil + return legacyStakers, l1Validators, s.currentHeight, nil } func (s *state) GetActiveL1ValidatorsIterator() (iterator.Iterator[L1Validator], error) { @@ -2476,7 +2451,7 @@ func (s *state) updateValidatorManager(updateValidators bool) error { switch err { case nil: // Modifying an existing validator - if priorL1Validator.isActive() == l1Validator.isActive() { + if priorL1Validator.IsActive() == l1Validator.IsActive() { // This validator's active status isn't changing. This means // the effectiveNodeIDs are equal. nodeID := l1Validator.effectiveNodeID() @@ -2872,7 +2847,7 @@ func (s *state) writeL1Validators() error { // Add the new validator var err error - if l1Validator.isActive() { + if l1Validator.IsActive() { s.activeL1Validators.put(l1Validator) err = putL1Validator(s.activeDB, emptyL1ValidatorCache, l1Validator) } else { diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 3636046cabdc..185191df1cd7 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -1866,7 +1866,7 @@ func TestL1Validators(t *testing.T) { require.True(has) weights[l1Validator.SubnetID] += l1Validator.Weight - if expectedL1Validator.isActive() { + if expectedL1Validator.IsActive() { expectedActive = append(expectedActive, expectedL1Validator) } } @@ -2130,7 +2130,7 @@ func TestGetCurrentValidators(t *testing.T) { NodeID: ids.GenerateTestNodeID(), PublicKey: otherPK, Weight: 1, - StartTime: now, + StartTime: now.Add(1 * time.Second), }, }, }, @@ -2151,7 +2151,7 @@ func TestGetCurrentValidators(t *testing.T) { NodeID: ids.GenerateTestNodeID(), PublicKey: otherPK, Weight: 1, - StartTime: now, + StartTime: now.Add(1 * time.Second), }, }, }, @@ -2171,7 +2171,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID1, NodeID: ids.GenerateTestNodeID(), PublicKey: otherPKBytes, - StartTime: uint64(now.Unix()), + StartTime: uint64(now.Unix()) + 1, Weight: 1, }, }, @@ -2192,7 +2192,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID2, NodeID: ids.GenerateTestNodeID(), PublicKey: otherPKBytes, - StartTime: uint64(now.Unix()), + StartTime: uint64(now.Unix()) + 1, Weight: 1, }, }, @@ -2205,7 +2205,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID1, NodeID: ids.GenerateTestNodeID(), PublicKey: pk, - Weight: uint64(now.Unix()), + Weight: 123123, StartTime: now, }, { @@ -2214,7 +2214,7 @@ func TestGetCurrentValidators(t *testing.T) { NodeID: ids.GenerateTestNodeID(), PublicKey: pk, Weight: 1, - StartTime: now, + StartTime: now.Add(1 * time.Second), }, { TxID: ids.GenerateTestID(), @@ -2222,7 +2222,7 @@ func TestGetCurrentValidators(t *testing.T) { NodeID: ids.GenerateTestNodeID(), PublicKey: otherPK, Weight: 0, - StartTime: now, + StartTime: now.Add(2 * time.Second), }, }, l1Validators: []L1Validator{ @@ -2241,7 +2241,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID2, NodeID: ids.GenerateTestNodeID(), PublicKey: otherPKBytes, - StartTime: uint64(now.Unix()), + StartTime: uint64(now.Unix()) + 1, Weight: 0, EndAccumulatedFee: 0, }, @@ -2250,7 +2250,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID1, NodeID: ids.GenerateTestNodeID(), PublicKey: pkBytes, - StartTime: uint64(now.Unix()), + StartTime: uint64(now.Unix()) + 2, Weight: 1, EndAccumulatedFee: 0, }, @@ -2259,7 +2259,7 @@ func TestGetCurrentValidators(t *testing.T) { SubnetID: subnetID1, NodeID: ids.GenerateTestNodeID(), PublicKey: otherPKBytes, - StartTime: uint64(now.Unix()), + StartTime: uint64(now.Unix()) + 3, Weight: 0, EndAccumulatedFee: 1, }, @@ -2274,6 +2274,8 @@ func TestGetCurrentValidators(t *testing.T) { db := memdb.New() state := newTestState(t, db) + stakersLenBySubnetID := make(map[ids.ID]int) + stakersByTxID := make(map[ids.ID]*Staker) for _, staker := range test.initial { primaryStaker := &Staker{ TxID: ids.GenerateTestID(), @@ -2286,8 +2288,13 @@ func TestGetCurrentValidators(t *testing.T) { } require.NoError(state.PutCurrentValidator(primaryStaker)) require.NoError(state.PutCurrentValidator(staker)) + + stakersByTxID[staker.TxID] = staker + stakersLenBySubnetID[staker.SubnetID]++ } + l1ValidatorsLenBySubnetID := make(map[ids.ID]int) + l1ValidatorsByVID := make(map[ids.ID]L1Validator) for _, l1Validator := range test.l1Validators { // The codec creates zero length slices rather than leaving them // as nil, so we need to populate the slices for later reflect @@ -2296,57 +2303,30 @@ func TestGetCurrentValidators(t *testing.T) { l1Validator.DeactivationOwner = []byte{} require.NoError(state.PutL1Validator(l1Validator)) - } - state.SetHeight(0) - require.NoError(state.Commit()) - - stakersBySubnetID := make(map[ids.ID][]*Staker) - for _, staker := range test.initial { - stakers := stakersBySubnetID[staker.SubnetID] - stakersBySubnetID[staker.SubnetID] = append(stakers, staker) - } - - l1ValidatorsBySubnetID := make(map[ids.ID][]L1Validator) - for _, l1Validator := range test.l1Validators { if l1Validator.Weight == 0 { continue } - l1Validators := l1ValidatorsBySubnetID[l1Validator.SubnetID] - l1ValidatorsBySubnetID[l1Validator.SubnetID] = append(l1Validators, l1Validator) + l1ValidatorsByVID[l1Validator.ValidationID] = l1Validator + l1ValidatorsLenBySubnetID[l1Validator.SubnetID]++ } + state.SetHeight(0) + require.NoError(state.Commit()) + for _, subnetID := range subnetIDs { - currentValidators, height, err := state.GetCurrentValidatorSet(context.Background(), subnetID) + baseStakers, currentValidators, height, err := state.GetCurrentValidators(subnetID) require.NoError(err) require.Equal(uint64(0), height) - totalLen := len(stakersBySubnetID[subnetID]) + len(l1ValidatorsBySubnetID[subnetID]) - require.Len(currentValidators, totalLen) - - for _, expectedStaker := range stakersBySubnetID[subnetID] { - currentValidator, ok := currentValidators[expectedStaker.TxID] - require.True(ok) - require.Equal(expectedStaker.TxID, currentValidator.ValidationID) - require.Equal(expectedStaker.NodeID, currentValidator.NodeID) - require.Equal(expectedStaker.PublicKey, currentValidator.PublicKey) - require.Equal(expectedStaker.Weight, currentValidator.Weight) - require.Equal(uint64(expectedStaker.StartTime.Unix()), currentValidator.StartTime) - require.Equal(uint64(0), currentValidator.MinNonce) - require.True(currentValidator.IsActive) - require.False(currentValidator.IsL1Validator) + require.Len(baseStakers, stakersLenBySubnetID[subnetID]) + require.Len(currentValidators, l1ValidatorsLenBySubnetID[subnetID]) + + for i, currentStaker := range baseStakers { + require.Equal(stakersByTxID[currentStaker.TxID], currentStaker, "index %d", i) } - for _, expectedL1Validator := range l1ValidatorsBySubnetID[subnetID] { - currentValidator, ok := currentValidators[expectedL1Validator.ValidationID] - require.True(ok) - require.Equal(expectedL1Validator.ValidationID, currentValidator.ValidationID) - require.Equal(expectedL1Validator.NodeID, currentValidator.NodeID) - require.Equal(expectedL1Validator.PublicKey, currentValidator.PublicKey.Serialize()) - require.Equal(expectedL1Validator.Weight, currentValidator.Weight) - require.Equal(expectedL1Validator.StartTime, currentValidator.StartTime) - require.Equal(expectedL1Validator.MinNonce, currentValidator.MinNonce) - require.Equal(expectedL1Validator.isActive(), currentValidator.IsActive) - require.True(currentValidator.IsL1Validator) + for i, currentValidator := range currentValidators { + require.Equal(l1ValidatorsByVID[currentValidator.ValidationID], currentValidator, "index %d", i) } } }) diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index cb1cb51c8280..9f02cccd5f93 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -15,12 +15,14 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/window" "github.com/ava-labs/avalanchego/vms/platformvm/block" "github.com/ava-labs/avalanchego/vms/platformvm/config" "github.com/ava-labs/avalanchego/vms/platformvm/metrics" + "github.com/ava-labs/avalanchego/vms/platformvm/state" "github.com/ava-labs/avalanchego/vms/platformvm/status" "github.com/ava-labs/avalanchego/vms/platformvm/txs" ) @@ -101,7 +103,7 @@ type State interface { subnetID ids.ID, ) error - GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) + GetCurrentValidators(subnetID ids.ID) ([]*state.Staker, []state.L1Validator, uint64, error) } func NewManager( @@ -413,5 +415,42 @@ func (m *manager) OnAcceptedBlockID(blkID ids.ID) { } func (m *manager) GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*validators.GetCurrentValidatorOutput, uint64, error) { - return m.state.GetCurrentValidatorSet(ctx, subnetID) + result := make(map[ids.ID]*validators.GetCurrentValidatorOutput) + baseStakers, l1Validators, height, err := m.state.GetCurrentValidators(subnetID) + if err != nil { + return nil, 0, err + } + + for _, validator := range baseStakers { + if err := ctx.Err(); err != nil { + return nil, 0, err + } + result[validator.TxID] = &validators.GetCurrentValidatorOutput{ + ValidationID: validator.TxID, + NodeID: validator.NodeID, + PublicKey: validator.PublicKey, + Weight: validator.Weight, + StartTime: uint64(validator.StartTime.Unix()), + MinNonce: 0, + IsActive: true, + IsL1Validator: false, + } + } + + for _, validator := range l1Validators { + if err := ctx.Err(); err != nil { + return nil, 0, err + } + result[validator.ValidationID] = &validators.GetCurrentValidatorOutput{ + ValidationID: validator.ValidationID, + NodeID: validator.NodeID, + PublicKey: bls.PublicKeyFromValidUncompressedBytes(validator.PublicKey), + Weight: validator.Weight, + StartTime: validator.StartTime, + IsActive: validator.IsActive(), + MinNonce: validator.MinNonce, + IsL1Validator: true, + } + } + return result, height, nil }