Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
bsamuels453 committed Nov 30, 2024
1 parent 056d720 commit c9cac8c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 29 deletions.
31 changes: 17 additions & 14 deletions chain/fork/remote_state_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,12 @@ package fork

import (
"fmt"
"github.com/crytic/medusa/chain/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/holiman/uint256"
)

/*
type MedusaStateDB interface {
vm.StateDB
// add the extra methods that Medusa uses.
IntermediateRoot(bool) common.Hash
Finalise(bool)
GetLogs(common.Hash, uint64, common.Hash) []*types.Log
TxIndex() int
SetBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason)
SetTxContext(common.Hash, int)
Commit(uint64, bool) common.Hash
}*/

var _ state.RemoteStateProvider = (*RemoteStateProvider)(nil)
var _ state.RemoteStateProviderFactory = (*RemoteStateProviderFactory)(nil)

Expand Down Expand Up @@ -160,6 +147,22 @@ type RemoteStateProviderFactory struct {
RemoteStateCache
}

func NewRemoteStateProviderFactory(cache RemoteStateCache) *RemoteStateProviderFactory {
return &RemoteStateProviderFactory{cache}
}

func (r RemoteStateProviderFactory) New() state.RemoteStateProvider {
return newRemoteStateProvider(r.RemoteStateCache)
}

type MedusaStateFactory struct {
*RemoteStateProviderFactory
}

func NewMedusaStateFactory(remoteStateFactory *RemoteStateProviderFactory) *MedusaStateFactory {
return &MedusaStateFactory{remoteStateFactory}
}

func (f *MedusaStateFactory) New(root common.Hash, db state.Database) (types.MedusaStateDB, error) {
return state.NewProxyDB(root, db, f.RemoteStateProviderFactory)
}
29 changes: 20 additions & 9 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chain
import (
"errors"
"fmt"
"github.com/crytic/medusa/chain/fork"
"math/big"
"sort"

Expand Down Expand Up @@ -63,10 +64,10 @@ type TestChain struct {
// genesisDefinition represents the Genesis information used to generate the chain's initial state.
genesisDefinition *core.Genesis

// state represents the current Ethereum world state.StateDB. It tracks all state across the chain and dummyChain
// and is the subject of state changes when executing new transactions. This does not track the current block
// head or anything of that nature and simply tracks accounts, balances, code, storage, etc.
state *state.StateDB
// state represents the current Ethereum world (interface implementing state.StateDB). It tracks all state across
// the chain and dummyChain and is the subject of state changes when executing new transactions. This does not
// track the current block head or anything of that nature and simply tracks accounts, balances, code, storage, etc.
state chainTypes.MedusaStateDB

// stateDatabase refers to the database object which state uses to store data. It is constructed over db.
stateDatabase state.Database
Expand All @@ -85,6 +86,10 @@ type TestChain struct {

// Events defines the event system for the TestChain.
Events TestChainEvents

// stateDbFactory used to construct state databases from db/root. Abstracts away the backing RPC when running in
// fork mode.
stateDbFactory *fork.MedusaStateFactory
}

// NewTestChain creates a simulated Ethereum backend used for testing, or returns an error if one occurred.
Expand Down Expand Up @@ -179,6 +184,11 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
transactionTracerRouter := NewTestChainTracerRouter()
callTracerRouter := NewTestChainTracerRouter()

// Set up the state factory
remoteCache := fork.EmptyRemoteStateCache{}
rspf := fork.NewRemoteStateProviderFactory(remoteCache)
sf := fork.NewMedusaStateFactory(rspf)

// Create our instance
chain := &TestChain{
genesisDefinition: genesisDefinition,
Expand All @@ -193,6 +203,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
testChainConfig: testChainConfig,
chainConfig: genesisDefinition.Config,
vmConfigExtensions: vmConfigExtensions,
stateDbFactory: sf,
}

// Add our internal tracers to this chain.
Expand Down Expand Up @@ -297,7 +308,7 @@ func (t *TestChain) GenesisDefinition() *core.Genesis {
}

// State returns the current state.StateDB of the chain.
func (t *TestChain) State() *state.StateDB {
func (t *TestChain) State() chainTypes.MedusaStateDB {
return t.state
}

Expand Down Expand Up @@ -460,9 +471,9 @@ func (t *TestChain) BlockHashFromNumber(blockNumber uint64) (common.Hash, error)

// StateFromRoot obtains a state from a given state root hash.
// Returns the state, or an error if one occurred.
func (t *TestChain) StateFromRoot(root common.Hash) (*state.StateDB, error) {
func (t *TestChain) StateFromRoot(root common.Hash) (chainTypes.MedusaStateDB, error) {
// Load our state from the database
stateDB, err := state.New(root, t.stateDatabase, nil)
stateDB, err := t.stateDbFactory.New(root, t.stateDatabase)
if err != nil {
return nil, err
}
Expand All @@ -486,7 +497,7 @@ func (t *TestChain) StateRootAfterBlockNumber(blockNumber uint64) (common.Hash,

// StateAfterBlockNumber obtains the Ethereum world state after processing all transactions in the provided block
// number. Returns the state, or an error if one occurs.
func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (*state.StateDB, error) {
func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (chainTypes.MedusaStateDB, error) {
// Obtain our block's post-execution state root hash
root, err := t.StateRootAfterBlockNumber(blockNumber)
if err != nil {
Expand Down Expand Up @@ -558,7 +569,7 @@ func (t *TestChain) RevertToBlockNumber(blockNumber uint64) error {
// It takes an optional state argument, which is the state to execute the message over. If not provided, the
// current pending state (or committed state if none is pending) will be used instead.
// The state executed over may be a pending block state.
func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) {
func (t *TestChain) CallContract(msg *core.Message, state chainTypes.MedusaStateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) {
// If our provided state is nil, use our current chain state.
if state == nil {
state = t.state
Expand Down
8 changes: 4 additions & 4 deletions chain/test_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func TestChainBlockNumberJumping(t *testing.T) {
// verifies have no registered contract deployments.
func TestChainDynamicDeployments(t *testing.T) {
// Copy our testdata over to our testing directory
contractPath := testutils.CopyToTestDirectory(t, "testdata/contracts/deployment_with_inner.sol")
contractPath := testutils.CopyToTesDirectory(t, "testdata/contracts/deployment_with_inner.sol")

// Execute our tests in the given test path
testutils.ExecuteInDirectory(t, contractPath, func() {
Expand Down Expand Up @@ -318,7 +318,7 @@ func TestChainDynamicDeployments(t *testing.T) {
// have no registered contract deployments.
func TestChainDeploymentWithArgs(t *testing.T) {
// Copy our testdata over to our testing directory
contractPath := testutils.CopyToTestDirectory(t, "testdata/contracts/deployment_with_args.sol")
contractPath := testutils.CopyToTesDirectory(t, "testdata/contracts/deployment_with_args.sol")

// Execute our tests in the given test path
testutils.ExecuteInDirectory(t, contractPath, func() {
Expand Down Expand Up @@ -450,7 +450,7 @@ func TestChainDeploymentWithArgs(t *testing.T) {
// that the ending state is the same.
func TestChainCloning(t *testing.T) {
// Copy our testdata over to our testing directory
contractPath := testutils.CopyToTestDirectory(t, "testdata/contracts/deployment_single.sol")
contractPath := testutils.CopyToTesDirectory(t, "testdata/contracts/deployment_single.sol")

// Execute our tests in the given test path
testutils.ExecuteInDirectory(t, contractPath, func() {
Expand Down Expand Up @@ -546,7 +546,7 @@ func TestChainCloning(t *testing.T) {
// semantics to be the same whenever run with the same messages being sent for all the same blocks.
func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
// Copy our testdata over to our testing directory
contractPath := testutils.CopyToTestDirectory(t, "testdata/contracts/deployment_single.sol")
contractPath := testutils.CopyToTesDirectory(t, "testdata/contracts/deployment_single.sol")

// Execute our tests in the given test path
testutils.ExecuteInDirectory(t, contractPath, func() {
Expand Down
27 changes: 27 additions & 0 deletions chain/types/medusa_statedb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package types

import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/holiman/uint256"
)

var _ MedusaStateDB = (*state.StateDB)(nil)
var _ MedusaStateDB = (*state.ForkStateDb)(nil)

type MedusaStateDB interface {
vm.StateDB

// add the extra methods that Medusa uses.
IntermediateRoot(bool) common.Hash
Finalise(bool)
GetLogs(common.Hash, uint64, common.Hash) []*types.Log
TxIndex() int
SetBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason)
SetTxContext(common.Hash, int)
Commit(uint64, bool) (common.Hash, error)
SetLogger(*tracing.Hooks)
}
4 changes: 2 additions & 2 deletions chain/vendored/apply_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ package vendored

import (
"github.com/crytic/medusa/chain/config"
types2 "github.com/crytic/medusa/chain/types"
"github.com/ethereum/go-ethereum/common"
. "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand All @@ -36,7 +36,7 @@ import (
// This executes on an underlying EVM and returns a transaction receipt, or an error if one occurs.
// Additional changes:
// - Exposed core.ExecutionResult as a return value.
func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb *state.StateDB, blockNumber *big.Int, blockHash common.Hash, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *types.Receipt, result *ExecutionResult, err error) {
func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb types2.MedusaStateDB, blockNumber *big.Int, blockHash common.Hash, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *types.Receipt, result *ExecutionResult, err error) {
// Apply the OnTxStart and OnTxEnd hooks
if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, msg.From)
Expand Down

0 comments on commit c9cac8c

Please sign in to comment.