diff --git a/snow/snowtest/context.go b/snow/snowtest/context.go index 734c120e556d..edeefe89c8ec 100644 --- a/snow/snowtest/context.go +++ b/snow/snowtest/context.go @@ -15,7 +15,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/validators/validatorstest" - "github.com/ava-labs/avalanchego/upgrade" + "github.com/ava-labs/avalanchego/upgrade/upgradetest" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/logging" @@ -84,7 +84,7 @@ func Context(tb testing.TB, chainID ids.ID) *snow.Context { ChainID: chainID, NodeID: ids.EmptyNodeID, PublicKey: publicKey, - NetworkUpgrades: upgrade.Default, + NetworkUpgrades: upgradetest.GetConfig(upgradetest.Latest), XChainID: XChainID, CChainID: CChainID, diff --git a/tests/upgrade/upgrade_test.go b/tests/upgrade/upgrade_test.go index 48bcc8b57c9d..4c0d8fbe79fd 100644 --- a/tests/upgrade/upgrade_test.go +++ b/tests/upgrade/upgrade_test.go @@ -4,8 +4,6 @@ package upgrade import ( - "encoding/base64" - "encoding/json" "flag" "fmt" "testing" @@ -13,10 +11,8 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/stretchr/testify/require" - "github.com/ava-labs/avalanchego/config" "github.com/ava-labs/avalanchego/tests/fixture/e2e" "github.com/ava-labs/avalanchego/tests/fixture/tmpnet" - "github.com/ava-labs/avalanchego/upgrade/upgradetest" ) func TestUpgrade(t *testing.T) { @@ -55,16 +51,6 @@ var _ = ginkgo.Describe("[Upgrade]", func() { require.NoError(err) network.Genesis = genesis - // Configure network upgrade flag - latestUnscheduled := upgradetest.GetConfig(upgradetest.Latest - 1) - upgradeJSON, err := json.Marshal(latestUnscheduled) - require.NoError(err) - upgradeBase64 := base64.StdEncoding.EncodeToString(upgradeJSON) - if network.DefaultFlags == nil { - network.DefaultFlags = tmpnet.FlagsMap{} - } - network.DefaultFlags[config.UpgradeFileContentKey] = upgradeBase64 - e2e.StartNetwork(tc, network, avalancheGoExecPath, "" /* pluginDir */, 0 /* shutdownDelay */, false /* reuseNetwork */) tc.By(fmt.Sprintf("restarting all nodes with %q binary", avalancheGoExecPathToUpgradeTo)) diff --git a/upgrade/upgrade.go b/upgrade/upgrade.go index fc1f55359a48..21e404ffb3f4 100644 --- a/upgrade/upgrade.go +++ b/upgrade/upgrade.go @@ -70,7 +70,9 @@ var ( CortinaTime: InitiallyActiveTime, CortinaXChainStopVertexID: ids.Empty, DurangoTime: InitiallyActiveTime, - EtnaTime: InitiallyActiveTime, + // Etna is left unactivated by default on local networks. It can be configured to + // activate by overriding the activation time in the upgrade file. + EtnaTime: UnscheduledActivationTime, } ErrInvalidUpgradeTimes = errors.New("invalid upgrade configuration") diff --git a/utils/iterator/filter.go b/utils/iterator/filter.go index e8a11464457d..f26b082aeab2 100644 --- a/utils/iterator/filter.go +++ b/utils/iterator/filter.go @@ -3,6 +3,8 @@ package iterator +import "github.com/ava-labs/avalanchego/utils/set" + var _ Iterator[any] = (*filtered[any])(nil) type filtered[T any] struct { @@ -19,6 +21,19 @@ func Filter[T any](it Iterator[T], filter func(T) bool) Iterator[T] { } } +// Deduplicate returns an iterator that skips the elements that have already +// been returned from [it]. +func Deduplicate[T comparable](it Iterator[T]) Iterator[T] { + var seen set.Set[T] + return Filter(it, func(e T) bool { + if seen.Contains(e) { + return true + } + seen.Add(e) + return false + }) +} + func (i *filtered[_]) Next() bool { for i.it.Next() { element := i.it.Value() diff --git a/utils/iterator/filter_test.go b/utils/iterator/filter_test.go index bf523017fe92..56c47892095e 100644 --- a/utils/iterator/filter_test.go +++ b/utils/iterator/filter_test.go @@ -10,8 +10,9 @@ import ( "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/vms/platformvm/state" + + . "github.com/ava-labs/avalanchego/utils/iterator" ) func TestFilter(t *testing.T) { @@ -40,8 +41,8 @@ func TestFilter(t *testing.T) { stakers[3].TxID: stakers[3], } - it := iterator.Filter( - iterator.FromSlice(stakers[:3]...), + it := Filter( + FromSlice(stakers[:3]...), func(staker *state.Staker) bool { _, ok := maskedStakers[staker.TxID] return ok @@ -55,3 +56,11 @@ func TestFilter(t *testing.T) { it.Release() require.False(it.Next()) } + +func TestDeduplicate(t *testing.T) { + require.Equal( + t, + []int{0, 1, 2, 3}, + ToSlice(Deduplicate(FromSlice(0, 1, 2, 1, 2, 0, 3))), + ) +} diff --git a/utils/math/continuous_averager.go b/utils/math/continuous_averager.go index 7bc892576b9f..ff9f0949f993 100644 --- a/utils/math/continuous_averager.go +++ b/utils/math/continuous_averager.go @@ -8,8 +8,6 @@ import ( "time" ) -var convertEToBase2 = math.Log(2) - type continuousAverager struct { halflife float64 weightedSum float64 @@ -34,7 +32,7 @@ func NewAverager( currentTime time.Time, ) Averager { return &continuousAverager{ - halflife: float64(halflife) / convertEToBase2, + halflife: float64(halflife) / math.Ln2, weightedSum: initialPrediction, normalizer: 1, lastUpdated: currentTime, diff --git a/utils/math/meter/continuous_meter.go b/utils/math/meter/continuous_meter.go index 378248a15027..ea3b8680f7ae 100644 --- a/utils/math/meter/continuous_meter.go +++ b/utils/math/meter/continuous_meter.go @@ -9,8 +9,6 @@ import ( ) var ( - convertEToBase2 = math.Log(2) - _ Factory = (*ContinuousFactory)(nil) _ Meter = (*continuousMeter)(nil) ) @@ -34,7 +32,7 @@ type continuousMeter struct { // NewMeter returns a new Meter with the provided halflife func NewMeter(halflife time.Duration) Meter { return &continuousMeter{ - halflife: float64(halflife) / convertEToBase2, + halflife: float64(halflife) / math.Ln2, } } diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index aceecc47ad56..24bdabfa96da 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -42,6 +42,8 @@ type diff struct { // Subnet ID --> supply of native asset of the subnet currentSupply map[ids.ID]uint64 + expiryDiff *expiryDiff + currentStakerDiffs diffStakers // map of subnetID -> nodeID -> total accrued delegatee rewards modifiedDelegateeRewards map[ids.ID]map[ids.NodeID]uint64 @@ -79,6 +81,7 @@ func NewDiff( timestamp: parentState.GetTimestamp(), feeState: parentState.GetFeeState(), accruedFees: parentState.GetAccruedFees(), + expiryDiff: newExpiryDiff(), subnetOwners: make(map[ids.ID]fx.Owner), subnetManagers: make(map[ids.ID]chainIDAndAddr), }, nil @@ -146,6 +149,41 @@ func (d *diff) SetCurrentSupply(subnetID ids.ID, currentSupply uint64) { } } +func (d *diff) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + parentIterator, err := parentState.GetExpiryIterator() + if err != nil { + return nil, err + } + + return d.expiryDiff.getExpiryIterator(parentIterator), nil +} + +func (d *diff) HasExpiry(entry ExpiryEntry) (bool, error) { + if has, modified := d.expiryDiff.modified[entry]; modified { + return has, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return false, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + return parentState.HasExpiry(entry) +} + +func (d *diff) PutExpiry(entry ExpiryEntry) { + d.expiryDiff.PutExpiry(entry) +} + +func (d *diff) DeleteExpiry(entry ExpiryEntry) { + d.expiryDiff.DeleteExpiry(entry) +} + func (d *diff) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { // If the validator was modified in this diff, return the modified // validator. @@ -451,6 +489,13 @@ func (d *diff) Apply(baseState Chain) error { for subnetID, supply := range d.currentSupply { baseState.SetCurrentSupply(subnetID, supply) } + for entry, isAdded := range d.expiryDiff.modified { + if isAdded { + baseState.PutExpiry(entry) + } else { + baseState.DeleteExpiry(entry) + } + } for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs { for _, validatorDiff := range subnetValidatorDiffs { switch validatorDiff.validatorStatus { diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index c4246b8aaeec..b71599928e23 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -17,6 +17,7 @@ import ( "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/iterator/iteratormock" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/gas" "github.com/ava-labs/avalanchego/vms/platformvm/fx/fxmock" @@ -112,6 +113,156 @@ func TestDiffCurrentSupply(t *testing.T) { assertChainsEqual(t, state, d) } +func TestDiffExpiry(t *testing.T) { + type op struct { + put bool + entry ExpiryEntry + } + tests := []struct { + name string + initialExpiries []ExpiryEntry + ops []op + }{ + { + name: "empty noop", + }, + { + name: "insert", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "remove", + initialExpiries: []ExpiryEntry{ + {Timestamp: 1}, + }, + ops: []op{ + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "add and immediately remove", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "add + remove + add", + ops: []op{ + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + { + name: "everything", + initialExpiries: []ExpiryEntry{ + {Timestamp: 1}, + {Timestamp: 2}, + {Timestamp: 3}, + }, + ops: []op{ + { + put: false, + entry: ExpiryEntry{Timestamp: 1}, + }, + { + put: false, + entry: ExpiryEntry{Timestamp: 2}, + }, + { + put: true, + entry: ExpiryEntry{Timestamp: 1}, + }, + }, + }, + } + + for _, test := range tests { + require := require.New(t) + + state := newTestState(t, memdb.New()) + for _, expiry := range test.initialExpiries { + state.PutExpiry(expiry) + } + + d, err := NewDiffOn(state) + require.NoError(err) + + var ( + expectedExpiries = set.Of(test.initialExpiries...) + unexpectedExpiries set.Set[ExpiryEntry] + ) + for _, op := range test.ops { + if op.put { + d.PutExpiry(op.entry) + expectedExpiries.Add(op.entry) + unexpectedExpiries.Remove(op.entry) + } else { + d.DeleteExpiry(op.entry) + expectedExpiries.Remove(op.entry) + unexpectedExpiries.Add(op.entry) + } + } + + // If expectedExpiries is empty, we want expectedExpiriesSlice to be + // nil. + var expectedExpiriesSlice []ExpiryEntry + if expectedExpiries.Len() > 0 { + expectedExpiriesSlice = expectedExpiries.List() + utils.Sort(expectedExpiriesSlice) + } + + verifyChain := func(chain Chain) { + expiryIterator, err := chain.GetExpiryIterator() + require.NoError(err) + require.Equal( + expectedExpiriesSlice, + iterator.ToSlice(expiryIterator), + ) + + for expiry := range expectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.True(has) + } + for expiry := range unexpectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.False(has) + } + } + + verifyChain(d) + require.NoError(d.Apply(state)) + verifyChain(state) + assertChainsEqual(t, d, state) + } +} + func TestDiffCurrentValidator(t *testing.T) { require := require.New(t) ctrl := gomock.NewController(t) @@ -527,6 +678,16 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { t.Helper() + expectedExpiryIterator, expectedErr := expected.GetExpiryIterator() + actualExpiryIterator, actualErr := actual.GetExpiryIterator() + require.Equal(expectedErr, actualErr) + if expectedErr == nil { + require.Equal( + iterator.ToSlice(expectedExpiryIterator), + iterator.ToSlice(actualExpiryIterator), + ) + } + expectedCurrentStakerIterator, expectedErr := expected.GetCurrentStakerIterator() actualCurrentStakerIterator, actualErr := actual.GetCurrentStakerIterator() require.Equal(expectedErr, actualErr) diff --git a/vms/platformvm/state/expiry.go b/vms/platformvm/state/expiry.go new file mode 100644 index 000000000000..b50439ddf20c --- /dev/null +++ b/vms/platformvm/state/expiry.go @@ -0,0 +1,114 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "encoding/binary" + "fmt" + + "github.com/google/btree" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/iterator" +) + +// expiryEntry = [timestamp] + [validationID] +const expiryEntryLength = database.Uint64Size + ids.IDLen + +var ( + errUnexpectedExpiryEntryLength = fmt.Errorf("expected expiry entry length %d", expiryEntryLength) + + _ btree.LessFunc[ExpiryEntry] = ExpiryEntry.Less + _ utils.Sortable[ExpiryEntry] = ExpiryEntry{} +) + +type Expiry interface { + // GetExpiryIterator returns an iterator of all the expiry entries in order + // of lowest to highest timestamp. + GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) + + // HasExpiry returns true if the database has the specified entry. + HasExpiry(ExpiryEntry) (bool, error) + + // PutExpiry adds the entry to the database. If the entry already exists, it + // is a noop. + PutExpiry(ExpiryEntry) + + // DeleteExpiry removes the entry from the database. If the entry doesn't + // exist, it is a noop. + DeleteExpiry(ExpiryEntry) +} + +type ExpiryEntry struct { + Timestamp uint64 + ValidationID ids.ID +} + +func (e *ExpiryEntry) Marshal() []byte { + data := make([]byte, expiryEntryLength) + binary.BigEndian.PutUint64(data, e.Timestamp) + copy(data[database.Uint64Size:], e.ValidationID[:]) + return data +} + +func (e *ExpiryEntry) Unmarshal(data []byte) error { + if len(data) != expiryEntryLength { + return errUnexpectedExpiryEntryLength + } + + e.Timestamp = binary.BigEndian.Uint64(data) + copy(e.ValidationID[:], data[database.Uint64Size:]) + return nil +} + +func (e ExpiryEntry) Less(o ExpiryEntry) bool { + return e.Compare(o) == -1 +} + +// Invariant: Compare produces the same ordering as the marshalled bytes. +func (e ExpiryEntry) Compare(o ExpiryEntry) int { + switch { + case e.Timestamp < o.Timestamp: + return -1 + case e.Timestamp > o.Timestamp: + return 1 + default: + return e.ValidationID.Compare(o.ValidationID) + } +} + +type expiryDiff struct { + modified map[ExpiryEntry]bool // bool represents isAdded + added *btree.BTreeG[ExpiryEntry] +} + +func newExpiryDiff() *expiryDiff { + return &expiryDiff{ + modified: make(map[ExpiryEntry]bool), + added: btree.NewG(defaultTreeDegree, ExpiryEntry.Less), + } +} + +func (e *expiryDiff) PutExpiry(entry ExpiryEntry) { + e.modified[entry] = true + e.added.ReplaceOrInsert(entry) +} + +func (e *expiryDiff) DeleteExpiry(entry ExpiryEntry) { + e.modified[entry] = false + e.added.Delete(entry) +} + +func (e *expiryDiff) getExpiryIterator(parentIterator iterator.Iterator[ExpiryEntry]) iterator.Iterator[ExpiryEntry] { + return iterator.Merge( + ExpiryEntry.Less, + iterator.Filter(parentIterator, func(entry ExpiryEntry) bool { + _, ok := e.modified[entry] + return ok + }), + iterator.FromTree(e.added), + ) +} diff --git a/vms/platformvm/state/expiry_test.go b/vms/platformvm/state/expiry_test.go new file mode 100644 index 000000000000..38a0d07f9ab0 --- /dev/null +++ b/vms/platformvm/state/expiry_test.go @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + "github.com/thepudds/fzgen/fuzzer" +) + +func FuzzExpiryEntryMarshal(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + require := require.New(t) + + var entry ExpiryEntry + fz := fuzzer.NewFuzzer(data) + fz.Fill(&entry) + + marshalledData := entry.Marshal() + + var parsedEntry ExpiryEntry + require.NoError(parsedEntry.Unmarshal(marshalledData)) + require.Equal(entry, parsedEntry) + }) +} + +func FuzzExpiryEntryLessAndMarshalOrdering(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + var ( + entry0 ExpiryEntry + entry1 ExpiryEntry + ) + fz := fuzzer.NewFuzzer(data) + fz.Fill(&entry0, &entry1) + + key0 := entry0.Marshal() + key1 := entry1.Marshal() + require.Equal( + t, + entry0.Less(entry1), + bytes.Compare(key0, key1) == -1, + ) + }) +} + +func FuzzExpiryEntryUnmarshal(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + require := require.New(t) + + var entry ExpiryEntry + if err := entry.Unmarshal(data); err != nil { + require.ErrorIs(err, errUnexpectedExpiryEntryLength) + return + } + + marshalledData := entry.Marshal() + require.Equal(data, marshalledData) + }) +} diff --git a/vms/platformvm/state/mock_chain.go b/vms/platformvm/state/mock_chain.go index 1891b74099d8..3b380a87a8b8 100644 --- a/vms/platformvm/state/mock_chain.go +++ b/vms/platformvm/state/mock_chain.go @@ -142,6 +142,18 @@ func (mr *MockChainMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockChain)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockChain) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockChainMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockChain)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockChain) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -267,6 +279,21 @@ func (mr *MockChainMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockChain)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockChain) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockChainMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockChain)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockChain) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -417,6 +444,21 @@ func (mr *MockChainMockRecorder) GetUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUTXO", reflect.TypeOf((*MockChain)(nil).GetUTXO), utxoID) } +// HasExpiry mocks base method. +func (m *MockChain) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockChainMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockChain)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockChain) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -443,6 +485,18 @@ func (mr *MockChainMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockChain)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockChain) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockChainMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockChain)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockChain) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_diff.go b/vms/platformvm/state/mock_diff.go index 557c59faf58c..77edfde92aaf 100644 --- a/vms/platformvm/state/mock_diff.go +++ b/vms/platformvm/state/mock_diff.go @@ -156,6 +156,18 @@ func (mr *MockDiffMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockDiff)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockDiff) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockDiffMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockDiff)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockDiff) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -281,6 +293,21 @@ func (mr *MockDiffMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockDiff)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockDiff) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockDiffMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockDiff)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockDiff) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -431,6 +458,21 @@ func (mr *MockDiffMockRecorder) GetUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUTXO", reflect.TypeOf((*MockDiff)(nil).GetUTXO), utxoID) } +// HasExpiry mocks base method. +func (m *MockDiff) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockDiffMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockDiff)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockDiff) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -457,6 +499,18 @@ func (mr *MockDiffMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockDiff)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockDiff) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockDiffMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockDiff)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockDiff) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index b6f0c7bd6b8a..2f8ddaa4bc85 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -257,6 +257,18 @@ func (mr *MockStateMockRecorder) DeleteCurrentValidator(staker any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCurrentValidator", reflect.TypeOf((*MockState)(nil).DeleteCurrentValidator), staker) } +// DeleteExpiry mocks base method. +func (m *MockState) DeleteExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteExpiry", arg0) +} + +// DeleteExpiry indicates an expected call of DeleteExpiry. +func (mr *MockStateMockRecorder) DeleteExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiry", reflect.TypeOf((*MockState)(nil).DeleteExpiry), arg0) +} + // DeletePendingDelegator mocks base method. func (m *MockState) DeletePendingDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -412,6 +424,21 @@ func (mr *MockStateMockRecorder) GetDelegateeReward(subnetID, nodeID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegateeReward", reflect.TypeOf((*MockState)(nil).GetDelegateeReward), subnetID, nodeID) } +// GetExpiryIterator mocks base method. +func (m *MockState) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiryIterator") + ret0, _ := ret[0].(iterator.Iterator[ExpiryEntry]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiryIterator indicates an expected call of GetExpiryIterator. +func (mr *MockStateMockRecorder) GetExpiryIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiryIterator", reflect.TypeOf((*MockState)(nil).GetExpiryIterator)) +} + // GetFeeState mocks base method. func (m *MockState) GetFeeState() gas.State { m.ctrl.T.Helper() @@ -652,6 +679,21 @@ func (mr *MockStateMockRecorder) GetUptime(nodeID, subnetID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUptime", reflect.TypeOf((*MockState)(nil).GetUptime), nodeID, subnetID) } +// HasExpiry mocks base method. +func (m *MockState) HasExpiry(arg0 ExpiryEntry) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasExpiry", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasExpiry indicates an expected call of HasExpiry. +func (mr *MockStateMockRecorder) HasExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockState)(nil).HasExpiry), arg0) +} + // PutCurrentDelegator mocks base method. func (m *MockState) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -678,6 +720,18 @@ func (mr *MockStateMockRecorder) PutCurrentValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutCurrentValidator", reflect.TypeOf((*MockState)(nil).PutCurrentValidator), staker) } +// PutExpiry mocks base method. +func (m *MockState) PutExpiry(arg0 ExpiryEntry) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "PutExpiry", arg0) +} + +// PutExpiry indicates an expected call of PutExpiry. +func (mr *MockStateMockRecorder) PutExpiry(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutExpiry", reflect.TypeOf((*MockState)(nil).PutExpiry), arg0) +} + // PutPendingDelegator mocks base method. func (m *MockState) PutPendingDelegator(staker *Staker) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 849a1bb1f110..50f00e1fe8f7 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -82,6 +82,7 @@ var ( TransformedSubnetPrefix = []byte("transformedSubnet") SupplyPrefix = []byte("supply") ChainPrefix = []byte("chain") + ExpiryReplayProtectionPrefix = []byte("expiryReplayProtection") SingletonPrefix = []byte("singleton") TimestampKey = []byte("timestamp") @@ -97,6 +98,7 @@ var ( // Chain collects all methods to manage the state of the chain for block // execution. type Chain interface { + Expiry Stakers avax.UTXOAdder avax.UTXOGetter @@ -278,6 +280,8 @@ type stateBlk struct { * | '-. subnetID * | '-. list * | '-- txID -> nil + * |-. expiryReplayProtection + * | '-- timestamp + validationID -> nil * '-. singletons * |-- initializedKey -> nil * |-- blocksReindexedKey -> nil @@ -299,6 +303,10 @@ type state struct { baseDB *versiondb.Database + expiry *btree.BTreeG[ExpiryEntry] + expiryDiff *expiryDiff + expiryDB database.Database + currentStakers *baseStakers pendingStakers *baseStakers @@ -610,6 +618,10 @@ func New( blockCache: blockCache, blockDB: prefixdb.New(BlockPrefix, baseDB), + expiry: btree.NewG(defaultTreeDegree, ExpiryEntry.Less), + expiryDiff: newExpiryDiff(), + expiryDB: prefixdb.New(ExpiryReplayProtectionPrefix, baseDB), + currentStakers: newBaseStakers(), pendingStakers: newBaseStakers(), @@ -684,6 +696,27 @@ func New( return s, nil } +func (s *state) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) { + return s.expiryDiff.getExpiryIterator( + iterator.FromTree(s.expiry), + ), nil +} + +func (s *state) HasExpiry(entry ExpiryEntry) (bool, error) { + if has, modified := s.expiryDiff.modified[entry]; modified { + return has, nil + } + return s.expiry.Has(entry), nil +} + +func (s *state) PutExpiry(entry ExpiryEntry) { + s.expiryDiff.PutExpiry(entry) +} + +func (s *state) DeleteExpiry(entry ExpiryEntry) { + s.expiryDiff.DeleteExpiry(entry) +} + func (s *state) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { return s.currentStakers.GetValidator(subnetID, nodeID) } @@ -1336,6 +1369,7 @@ func (s *state) syncGenesis(genesisBlk block.Block, genesis *genesis.Genesis) er func (s *state) load() error { return errors.Join( s.loadMetadata(), + s.loadExpiry(), s.loadCurrentValidators(), s.loadPendingValidators(), s.initValidatorSets(), @@ -1407,6 +1441,23 @@ func (s *state) loadMetadata() error { return nil } +func (s *state) loadExpiry() error { + it := s.expiryDB.NewIterator() + defer it.Release() + + for it.Next() { + key := it.Key() + + var entry ExpiryEntry + if err := entry.Unmarshal(key); err != nil { + return fmt.Errorf("failed to unmarshal ExpiryEntry during load: %w", err) + } + s.expiry.ReplaceOrInsert(entry) + } + + return nil +} + func (s *state) loadCurrentValidators() error { s.currentStakers = newBaseStakers() @@ -1705,6 +1756,7 @@ func (s *state) write(updateValidators bool, height uint64) error { return errors.Join( s.writeBlocks(), + s.writeExpiry(), s.writeCurrentStakers(updateValidators, height, codecVersion), s.writePendingStakers(), s.WriteValidatorMetadata(s.currentValidatorList, s.currentSubnetValidatorList, codecVersion), // Must be called after writeCurrentStakers @@ -1723,6 +1775,7 @@ func (s *state) write(updateValidators bool, height uint64) error { func (s *state) Close() error { return errors.Join( + s.expiryDB.Close(), s.pendingSubnetValidatorBaseDB.Close(), s.pendingSubnetDelegatorBaseDB.Close(), s.pendingDelegatorBaseDB.Close(), @@ -1929,6 +1982,28 @@ func (s *state) GetBlockIDAtHeight(height uint64) (ids.ID, error) { return blkID, nil } +func (s *state) writeExpiry() error { + for entry, isAdded := range s.expiryDiff.modified { + var ( + key = entry.Marshal() + err error + ) + if isAdded { + s.expiry.ReplaceOrInsert(entry) + err = s.expiryDB.Put(key, nil) + } else { + s.expiry.Delete(entry) + err = s.expiryDB.Delete(key) + } + if err != nil { + return err + } + } + + s.expiryDiff = newExpiryDiff() + return nil +} + func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecVersion uint16) error { for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { delete(s.currentStakers.validatorDiffs, subnetID) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 515794364d71..a0dcd0cb7dcc 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -1577,3 +1577,35 @@ func TestGetFeeStateErrors(t *testing.T) { }) } } + +// Verify that committing the state writes the expiry changes to the database +// and that loading the state fetches the expiry from the database. +func TestStateExpiryCommitAndLoad(t *testing.T) { + require := require.New(t) + + db := memdb.New() + s := newTestState(t, db) + + // Populate an entry. + expiry := ExpiryEntry{ + Timestamp: 1, + } + s.PutExpiry(expiry) + require.NoError(s.Commit()) + + // Verify that the entry was written and loaded correctly. + s = newTestState(t, db) + has, err := s.HasExpiry(expiry) + require.NoError(err) + require.True(has) + + // Delete an entry. + s.DeleteExpiry(expiry) + require.NoError(s.Commit()) + + // Verify that the entry was deleted correctly. + s = newTestState(t, db) + has, err = s.HasExpiry(expiry) + require.NoError(err) + require.False(has) +} diff --git a/vms/platformvm/validators/fee/fee.go b/vms/platformvm/validators/fee/fee.go new file mode 100644 index 000000000000..be0eae544ea2 --- /dev/null +++ b/vms/platformvm/validators/fee/fee.go @@ -0,0 +1,167 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package fee + +import ( + "math" + + "github.com/ava-labs/avalanchego/vms/components/gas" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +// Config contains all the static parameters of the dynamic fee mechanism. +type Config struct { + Target gas.Gas `json:"target"` + MinPrice gas.Price `json:"minPrice"` + ExcessConversionConstant gas.Gas `json:"excessConversionConstant"` +} + +// State represents the current on-chain values used in the dynamic fee +// mechanism. +type State struct { + Current gas.Gas `json:"current"` + Excess gas.Gas `json:"excess"` +} + +// AdvanceTime adds (s.Current - target) * seconds to Excess. +// +// If Excess would underflow, it is set to 0. +// If Excess would overflow, it is set to MaxUint64. +func (s State) AdvanceTime(target gas.Gas, seconds uint64) State { + excess := s.Excess + if s.Current < target { + excess = excess.SubPerSecond(target-s.Current, seconds) + } else if s.Current > target { + excess = excess.AddPerSecond(s.Current-target, seconds) + } + return State{ + Current: s.Current, + Excess: excess, + } +} + +// CostOf calculates how much to charge based on the dynamic fee mechanism for +// [seconds]. +// +// This implements the ACP-77 cost over time formula: +func (s State) CostOf(c Config, seconds uint64) uint64 { + // If the current and target are the same, the price is constant. + if s.Current == c.Target { + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + cost, err := safemath.Mul(seconds, uint64(price)) + if err != nil { + return math.MaxUint64 + } + return cost + } + + var ( + cost uint64 + err error + ) + for i := uint64(0); i < seconds; i++ { + s = s.AdvanceTime(c.Target, 1) + + // Advancing the time is going to either hold excess constant, + // monotonically increase it, or monotonically decrease it. If it is + // equal to 0 after performing one of these operations, it is guaranteed + // to always remain 0. + if s.Excess == 0 { + secondsWithZeroExcess := seconds - i + zeroExcessCost, err := safemath.Mul(uint64(c.MinPrice), secondsWithZeroExcess) + if err != nil { + return math.MaxUint64 + } + + cost, err = safemath.Add(cost, zeroExcessCost) + if err != nil { + return math.MaxUint64 + } + return cost + } + + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + cost, err = safemath.Add(cost, uint64(price)) + if err != nil { + return math.MaxUint64 + } + } + return cost +} + +// SecondsUntil calculates the number of seconds that it would take to charge at +// least [targetCost] based on the dynamic fee mechanism. The result is capped +// at [maxSeconds]. +func (s State) SecondsUntil(c Config, maxSeconds uint64, targetCost uint64) uint64 { + // Because this function can divide by prices, we need to sanity check the + // parameters to avoid division by 0. + if c.MinPrice == 0 { + if targetCost == 0 { + return 0 + } + return maxSeconds + } + + // If the current and target are the same, the price is constant. + if s.Current == c.Target { + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + return secondsUntil( + uint64(price), + maxSeconds, + targetCost, + ) + } + + var ( + cost uint64 + seconds uint64 + err error + ) + for cost < targetCost && seconds < maxSeconds { + s = s.AdvanceTime(c.Target, 1) + + // Advancing the time is going to either hold excess constant, + // monotonically increase it, or monotonically decrease it. If it is + // equal to 0 after performing one of these operations, it is guaranteed + // to always remain 0. + if s.Excess == 0 { + zeroExcessCost := targetCost - cost + secondsWithZeroExcess := secondsUntil( + uint64(c.MinPrice), + maxSeconds, + zeroExcessCost, + ) + + totalSeconds, err := safemath.Add(seconds, secondsWithZeroExcess) + if err != nil || totalSeconds >= maxSeconds { + return maxSeconds + } + return totalSeconds + } + + seconds++ + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + cost, err = safemath.Add(cost, uint64(price)) + if err != nil { + return seconds + } + } + return seconds +} + +// Calculate the number of seconds that it would take to charge at least [cost] +// at [price] every second. The result is capped at [maxSeconds]. +func secondsUntil(price uint64, maxSeconds uint64, cost uint64) uint64 { + // Directly rounding up could cause an overflow. Instead we round down and + // then check if we should have rounded up. + secondsRoundedDown := cost / price + if secondsRoundedDown >= maxSeconds { + return maxSeconds + } + if cost%price == 0 { + return secondsRoundedDown + } + return secondsRoundedDown + 1 +} diff --git a/vms/platformvm/validators/fee/fee_test.go b/vms/platformvm/validators/fee/fee_test.go new file mode 100644 index 000000000000..059fbad61c51 --- /dev/null +++ b/vms/platformvm/validators/fee/fee_test.go @@ -0,0 +1,460 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package fee + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/vms/components/gas" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +const ( + second = 1 + minute = 60 * second + hour = 60 * minute + day = 24 * hour + week = 7 * day + year = 365 * day + + minPrice = 2_048 + + capacity = 20_000 + target = 10_000 + maxExcessIncreasePerSecond = capacity - target + doubleEvery = day + excessIncreasePerDoubling = maxExcessIncreasePerSecond * doubleEvery + excessConversionConstantFloat = excessIncreasePerDoubling / math.Ln2 +) + +var ( + excessConversionConstant = floatToGas(excessConversionConstantFloat) + + tests = []struct { + name string + state State + config Config + + // expectedSeconds and expectedCost are used as inputs for some tests + // and outputs for other tests. + expectedSeconds uint64 + expectedCost uint64 + expectedExcess gas.Gas + }{ + { + name: "excess=0, currenttarget, minute", + state: State{ + Current: 15_000, + Excess: 0, + }, + config: Config{ + Target: 10_000, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: minute, + expectedCost: 122_880, + expectedExcess: 5_000 * minute, + }, + { + name: "excess hits 0 during, currenttarget, day", + state: State{ + Current: 15_000, + Excess: 0, + }, + config: Config{ + Target: 10_000, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: day, + expectedCost: 211_438_809, + expectedExcess: 5_000 * day, + }, + { + name: "excess=0, current=target, week", + state: State{ + Current: 10_000, + Excess: 0, + }, + config: Config{ + Target: 10_000, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: week, + expectedCost: 1_238_630_400, + expectedExcess: 0, + }, + { + name: "excess=0, current>target, week", + state: State{ + Current: 15_000, + Excess: 0, + }, + config: Config{ + Target: 10_000, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: week, + expectedCost: 5_265_492_669, + expectedExcess: 5_000 * week, + }, + { + name: "excess=1, current>>target, second", + state: State{ + Current: math.MaxUint64, + Excess: 1, + }, + config: Config{ + Target: 0, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: 1, + expectedCost: math.MaxUint64, // Should not overflow + expectedExcess: math.MaxUint64, // Should not overflow + }, + { + name: "excess=0, current>>target, 11 seconds", + state: State{ + Current: math.MaxUint32, + Excess: 0, + }, + config: Config{ + Target: 0, + MinPrice: minPrice, + ExcessConversionConstant: excessConversionConstant, + }, + expectedSeconds: 11, + expectedCost: math.MaxUint64, // Should not overflow + expectedExcess: math.MaxUint32 * 11, + }, + } +) + +func TestStateAdvanceTime(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal( + t, + State{ + Current: test.state.Current, + Excess: test.expectedExcess, + }, + test.state.AdvanceTime(test.config.Target, test.expectedSeconds), + ) + }) + } +} + +func TestStateCostOf(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal( + t, + test.expectedCost, + test.state.CostOf(test.config, test.expectedSeconds), + ) + }) + } +} + +func TestStateSecondsUntil(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal( + t, + test.expectedSeconds, + test.state.SecondsUntil(test.config, year, test.expectedCost), + ) + }) + } +} + +func BenchmarkStateCostOf(b *testing.B) { + benchmarks := []struct { + name string + costOf func( + s State, + c Config, + seconds uint64, + ) uint64 + }{ + { + name: "unoptimized", + costOf: State.unoptimizedCostOf, + }, + { + name: "optimized", + costOf: State.CostOf, + }, + } + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + for _, benchmark := range benchmarks { + b.Run(benchmark.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + benchmark.costOf(test.state, test.config, test.expectedSeconds) + } + }) + } + }) + } +} + +func BenchmarkStateSecondsUntil(b *testing.B) { + benchmarks := []struct { + name string + secondsUntil func( + s State, + c Config, + maxSeconds uint64, + targetCost uint64, + ) uint64 + }{ + { + name: "unoptimized", + secondsUntil: State.unoptimizedSecondsUntil, + }, + { + name: "optimized", + secondsUntil: State.SecondsUntil, + }, + } + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + for _, benchmark := range benchmarks { + b.Run(benchmark.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + benchmark.secondsUntil(test.state, test.config, year, test.expectedCost) + } + }) + } + }) + } +} + +func FuzzStateCostOf(f *testing.F) { + for _, test := range tests { + f.Add( + uint64(test.state.Current), + uint64(test.state.Excess), + uint64(test.config.Target), + uint64(test.config.MinPrice), + uint64(test.config.ExcessConversionConstant), + test.expectedSeconds, + ) + } + f.Fuzz( + func( + t *testing.T, + current uint64, + excess uint64, + target uint64, + minPrice uint64, + excessConversionConstant uint64, + seconds uint64, + ) { + s := State{ + Current: gas.Gas(current), + Excess: gas.Gas(excess), + } + c := Config{ + Target: gas.Gas(target), + MinPrice: gas.Price(minPrice), + ExcessConversionConstant: gas.Gas(max(excessConversionConstant, 1)), + } + seconds = min(seconds, year) + require.Equal( + t, + s.unoptimizedCostOf(c, seconds), + s.CostOf(c, seconds), + ) + }, + ) +} + +func FuzzStateSecondsUntil(f *testing.F) { + for _, test := range tests { + f.Add( + uint64(test.state.Current), + uint64(test.state.Excess), + uint64(test.config.Target), + uint64(test.config.MinPrice), + uint64(test.config.ExcessConversionConstant), + uint64(year), + test.expectedCost, + ) + } + f.Fuzz( + func( + t *testing.T, + current uint64, + excess uint64, + target uint64, + minPrice uint64, + excessConversionConstant uint64, + maxSeconds uint64, + targetCost uint64, + ) { + s := State{ + Current: gas.Gas(current), + Excess: gas.Gas(excess), + } + c := Config{ + Target: gas.Gas(target), + MinPrice: gas.Price(minPrice), + ExcessConversionConstant: gas.Gas(max(excessConversionConstant, 1)), + } + maxSeconds = min(maxSeconds, year) + require.Equal( + t, + s.unoptimizedSecondsUntil(c, maxSeconds, targetCost), + s.SecondsUntil(c, maxSeconds, targetCost), + ) + }, + ) +} + +// unoptimizedCalculateCost is a naive implementation of CostOf that is used for +// differential fuzzing. +func (s State) unoptimizedCostOf(c Config, seconds uint64) uint64 { + var ( + cost uint64 + err error + ) + for i := uint64(0); i < seconds; i++ { + s = s.AdvanceTime(c.Target, 1) + + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + cost, err = safemath.Add(cost, uint64(price)) + if err != nil { + return math.MaxUint64 + } + } + return cost +} + +// unoptimizedSecondsUntil is a naive implementation of SecondsUntil that is +// used for differential fuzzing. +func (s State) unoptimizedSecondsUntil(c Config, maxSeconds uint64, targetCost uint64) uint64 { + var ( + cost uint64 + seconds uint64 + err error + ) + for cost < targetCost && seconds < maxSeconds { + s = s.AdvanceTime(c.Target, 1) + seconds++ + + price := gas.CalculatePrice(c.MinPrice, s.Excess, c.ExcessConversionConstant) + cost, err = safemath.Add(cost, uint64(price)) + if err != nil { + return seconds + } + } + return seconds +} + +// floatToGas converts f to gas.Gas by truncation. `gas.Gas(f)` is preferred and +// floatToGas should only be used if its argument is a `const`. +func floatToGas(f float64) gas.Gas { + return gas.Gas(f) +}