Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wasmstore #2292

Merged
merged 17 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions arbos/programs/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/log"
Expand Down Expand Up @@ -53,6 +54,24 @@ func activateProgram(
debug bool,
burner burn.Burner,
) (*activationInfo, error) {
info, asm, module, err := activateProgramInternal(db, program, codehash, wasm, page_limit, version, debug, burner.GasLeft())
if err != nil {
return nil, err
}
db.ActivateWasm(info.moduleHash, asm, module)
return info, nil
}

func activateProgramInternal(
db vm.StateDB,
program common.Address,
codehash common.Hash,
wasm []byte,
page_limit uint16,
version uint16,
debug bool,
gasLeft *uint64,
) (*activationInfo, []byte, []byte, error) {
output := &rustBytes{}
asmLen := usize(0)
moduleHash := &bytes32{}
Expand All @@ -69,7 +88,7 @@ func activateProgram(
&codeHash,
moduleHash,
stylusData,
(*u64)(burner.GasLeft()),
(*u64)(gasLeft),
))

data, msg, err := status.toResult(output.intoBytes(), debug)
Expand All @@ -78,9 +97,9 @@ func activateProgram(
log.Warn("activation failed", "err", err, "msg", msg, "program", program)
}
if errors.Is(err, vm.ErrExecutionReverted) {
return nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg)
}
return nil, err
return nil, nil, nil, err
}

hash := moduleHash.toHash()
Expand All @@ -95,13 +114,57 @@ func activateProgram(
asmEstimate: uint32(stylusData.asm_estimate),
footprint: uint16(stylusData.footprint),
}
db.ActivateWasm(hash, asm, module)
return info, err
return info, asm, module, err
}

func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) {
localAsm, err := statedb.TryGetActivatedAsm(moduleHash)
if err == nil && len(localAsm) > 0 {
return localAsm, nil
}

codeHash := statedb.GetCodeHash(address)

wasm, err := getWasm(statedb, address)
if err != nil {
log.Error("Failed to reactivate program: getWasm", "address", address, "expected moduleHash", moduleHash, "err", err)
return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err)
}

unlimitedGas := uint64(0xffffffffffff)
// we know program is activated, so it must be in correct version and not use too much memory
info, asm, module, err := activateProgramInternal(statedb, address, codeHash, wasm, pagelimit, program.version, debugMode, &unlimitedGas)
if err != nil {
log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "err", err)
return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err)
}

if info.moduleHash != moduleHash {
log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "got", info.moduleHash)
return nil, fmt.Errorf("failed to reactivate program. address: %v, expected ModuleHash: %v", address, moduleHash)
}

currentHoursSince := hoursSinceArbitrum(time)
if currentHoursSince > program.activatedAt {
// stylus program is active on-chain, and was activated in the past
PlasmaPower marked this conversation as resolved.
Show resolved Hide resolved
// so we store it directly to database
batch := statedb.Database().WasmStore().NewBatch()
rawdb.WriteActivation(batch, moduleHash, asm, module)
if err := batch.Write(); err != nil {
log.Error("failed writing re-activation to state", "address", address, "err", err)
}
} else {
// program activated recently, possibly in this eth_call
// store it to statedb. It will be stored to database if statedb is commited
statedb.ActivateWasm(info.moduleHash, asm, module)
}
return asm, nil
}

func callProgram(
address common.Address,
moduleHash common.Hash,
localAsm []byte,
scope *vm.ScopeContext,
interpreter *vm.EVMInterpreter,
tracingInfo *util.TracingInfo,
Expand All @@ -111,10 +174,9 @@ func callProgram(
memoryModel *MemoryModel,
) ([]byte, error) {
db := interpreter.Evm().StateDB
asm := db.GetActivatedAsm(moduleHash)
debug := stylusParams.DebugMode

if len(asm) == 0 {
if len(localAsm) == 0 {
log.Error("missing asm", "program", address, "module", moduleHash)
panic("missing asm")
}
Expand All @@ -128,7 +190,7 @@ func callProgram(

output := &rustBytes{}
status := userStatus(C.stylus_call(
goSlice(asm),
goSlice(localAsm),
goSlice(calldata),
stylusParams.encode(),
evmApi.cNative,
Expand Down Expand Up @@ -159,11 +221,15 @@ func handleReqImpl(apiId usize, req_type u32, data *rustSlice, costPtr *u64, out

// Caches a program in Rust. We write a record so that we can undo on revert.
// For gas estimation and eth_call, we ignore permanent updates and rely on Rust's LRU.
func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, runMode core.MessageRunMode) {
func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) {
if runMode == core.MessageCommitMode {
asm := db.GetActivatedAsm(module)
state.CacheWasmRust(asm, module, version, debug)
db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: version, Debug: debug})
// address is only used for logging
asm, err := getLocalAsm(db, module, common.Address{}, params.PageLimit, time, debug, program)
if err != nil {
panic("unable to recreate wasm")
}
state.CacheWasmRust(asm, module, program.version, debug)
db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: program.version, Debug: debug})
}
}

Expand Down
18 changes: 14 additions & 4 deletions arbos/programs/programs.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,13 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c
return 0, codeHash, common.Hash{}, nil, true, err
}

// replace the cached asm
// remove prev asm
if cached {
oldModuleHash, err := p.moduleHashes.Get(codeHash)
if err != nil {
return 0, codeHash, common.Hash{}, nil, true, err
}
evictProgram(statedb, oldModuleHash, currentVersion, debugMode, runMode, expired)
cacheProgram(statedb, info.moduleHash, stylusVersion, debugMode, runMode)
}
if err := p.moduleHashes.Set(codeHash, info.moduleHash); err != nil {
return 0, codeHash, common.Hash{}, nil, true, err
Expand All @@ -152,6 +151,11 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c
activatedAt: hoursSinceArbitrum(time),
cached: cached,
}
// replace the cached asm
if cached {
cacheProgram(statedb, info.moduleHash, programData, params, debugMode, time, runMode)
}

return stylusVersion, codeHash, info.moduleHash, dataFee, false, p.setProgram(codeHash, programData)
}

Expand Down Expand Up @@ -205,6 +209,12 @@ func (p Programs) CallProgram(
statedb.AddStylusPages(program.footprint)
defer statedb.SetStylusPagesOpen(open)

localAsm, err := getLocalAsm(statedb, moduleHash, contract.Address(), params.PageLimit, evm.Context.Time, debugMode, program)
if err != nil {
log.Crit("failed to get local wasm for activated program", "program", contract.Address())
return nil, err
}

evmData := &EvmData{
blockBasefee: common.BigToHash(evm.Context.BaseFee),
chainId: evm.ChainConfig().ChainID.Uint64(),
Expand All @@ -227,7 +237,7 @@ func (p Programs) CallProgram(
if contract.CodeAddr != nil {
address = *contract.CodeAddr
}
return callProgram(address, moduleHash, scope, interpreter, tracingInfo, calldata, evmData, goParams, model)
return callProgram(address, moduleHash, localAsm, scope, interpreter, tracingInfo, calldata, evmData, goParams, model)
}

func getWasm(statedb vm.StateDB, program common.Address) ([]byte, error) {
Expand Down Expand Up @@ -380,7 +390,7 @@ func (p Programs) SetProgramCached(
return err
}
if cache {
cacheProgram(db, moduleHash, program.version, debug, runMode)
cacheProgram(db, moduleHash, program, params, debug, time, runMode)
} else {
evictProgram(db, moduleHash, program.version, debug, runMode, expired)
}
Expand Down
7 changes: 6 additions & 1 deletion arbos/programs/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func activateProgram(
}

// stub any non-consensus, Rust-side caching updates
func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode) {
func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) {
}
func evictProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode, forever bool) {
}
Expand Down Expand Up @@ -128,9 +128,14 @@ func startProgram(module uint32) uint32
//go:wasmimport programs send_response
func sendResponse(req_id uint32) uint32

func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) {
return nil, nil
}

func callProgram(
address common.Address,
moduleHash common.Hash,
_localAsm []byte,
scope *vm.ScopeContext,
interpreter *vm.EVMInterpreter,
tracingInfo *util.TracingInfo,
Expand Down
18 changes: 14 additions & 4 deletions cmd/nitro/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo
if !arbmath.BigEquals(chainConfig.ChainID, chainId) {
return nil, nil, fmt.Errorf("database has chain ID %v but config has chain ID %v (are you sure this database is for the right chain?)", chainConfig.ChainID, chainId)
}
chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
if err != nil {
return chainDb, nil, err
return nil, nil, err
}
wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false)
if err != nil {
return nil, nil, err
}
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb)
err = pruning.PruneChainDb(ctx, chainDb, stack, &config.Init, cacheConfig, l1Client, rollupAddrs, config.Node.ValidatorRequired())
if err != nil {
return chainDb, nil, fmt.Errorf("error pruning: %w", err)
Expand Down Expand Up @@ -230,10 +235,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo

var initDataReader statetransfer.InitDataReader = nil

chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
if err != nil {
return chainDb, nil, err
return nil, nil, err
}
wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false)
if err != nil {
return nil, nil, err
}
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb)

if config.Init.ImportFile != "" {
initDataReader, err = statetransfer.NewJsonInitDataReader(config.Init.ImportFile)
Expand Down
12 changes: 10 additions & 2 deletions system_tests/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -772,8 +773,11 @@ func createL2BlockChainWithStackConfig(
stack, err = node.New(stackConfig)
Require(t, err)

chainDb, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
chainData, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
Require(t, err)
wasmData, err := stack.OpenDatabase("wasm", 0, 0, "wasm/", false)
Require(t, err)
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmData)
arbDb, err := stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false)
Require(t, err)

Expand Down Expand Up @@ -976,8 +980,12 @@ func Create2ndNodeWithConfig(
l2stack, err := node.New(stackConfig)
Require(t, err)

l2chainDb, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
l2chainData, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
Require(t, err)
wasmData, err := l2stack.OpenDatabase("wasm", 0, 0, "wasm/", false)
Require(t, err)
l2chainDb := rawdb.WrapDatabaseWithWasm(l2chainData, wasmData)

l2arbDb, err := l2stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false)
Require(t, err)
initReader := statetransfer.NewMemoryInitDataReader(l2InitData)
Expand Down
78 changes: 78 additions & 0 deletions system_tests/program_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1458,3 +1458,81 @@ func formatTime(duration time.Duration) string {
}
return fmt.Sprintf("%.2f%s", span, units[unit])
}

func TestWasmRecreate(t *testing.T) {
builder, auth, cleanup := setupProgramTest(t, true)
ctx := builder.ctx
l2info := builder.L2Info
l2client := builder.L2.Client
defer cleanup()

storage := deployWasm(t, ctx, auth, l2client, rustFile("storage"))

zero := common.Hash{}
val := common.HexToHash("0x121233445566")

// do an onchain call - store value
storeTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageWrite(zero, val))
Require(t, l2client.SendTransaction(ctx, storeTx))
_, err := EnsureTxSucceeded(ctx, l2client, storeTx)
Require(t, err)

testDir := t.TempDir()
nodeBStack := createStackConfigForTest(testDir)
nodeB, cleanupB := builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack})

_, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx)
Require(t, err)

// make sure reading 2nd value succeeds from 2nd node
loadTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageRead(zero))
result, err := arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true)
Require(t, err)
if common.BytesToHash(result) != val {
Fatal(t, "got wrong value")
}
// close nodeB
cleanupB()

// delete wasm dir of nodeB

wasmPath := filepath.Join(testDir, "system_tests.test", "wasm")
dirContents, err := os.ReadDir(wasmPath)
Require(t, err)
if len(dirContents) == 0 {
Fatal(t, "not contents found before delete")
}
os.RemoveAll(wasmPath)

// recreate nodeB - using same source dir (wasm deleted)
nodeB, cleanupB = builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack})

// test nodeB - sees existing transaction
_, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx)
Require(t, err)

// test nodeB - answers eth_call (requires reloading wasm)
result, err = arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true)
Require(t, err)
if common.BytesToHash(result) != val {
Fatal(t, "got wrong value")
}

// send new tx (requires wasm) and check nodeB sees it as well
Require(t, l2client.SendTransaction(ctx, loadTx))

_, err = EnsureTxSucceeded(ctx, l2client, loadTx)
Require(t, err)

_, err = EnsureTxSucceeded(ctx, nodeB.Client, loadTx)
Require(t, err)

cleanupB()
dirContents, err = os.ReadDir(wasmPath)
Require(t, err)
if len(dirContents) == 0 {
Fatal(t, "not contents found before delete")
}
os.RemoveAll(wasmPath)

}
Loading