diff --git a/arbos/programs/native.go b/arbos/programs/native.go index c44f8f56cb..7a6c16d866 100644 --- a/arbos/programs/native.go +++ b/arbos/programs/native.go @@ -24,6 +24,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/log" @@ -53,6 +54,24 @@ func activateProgram( debug bool, burner burn.Burner, ) (*activationInfo, error) { + info, asm, module, err := activateProgramInternal(db, program, codehash, wasm, page_limit, version, debug, burner.GasLeft()) + if err != nil { + return nil, err + } + db.ActivateWasm(info.moduleHash, asm, module) + return info, nil +} + +func activateProgramInternal( + db vm.StateDB, + program common.Address, + codehash common.Hash, + wasm []byte, + page_limit uint16, + version uint16, + debug bool, + gasLeft *uint64, +) (*activationInfo, []byte, []byte, error) { output := &rustBytes{} asmLen := usize(0) moduleHash := &bytes32{} @@ -69,7 +88,7 @@ func activateProgram( &codeHash, moduleHash, stylusData, - (*u64)(burner.GasLeft()), + (*u64)(gasLeft), )) data, msg, err := status.toResult(output.intoBytes(), debug) @@ -78,9 +97,9 @@ func activateProgram( log.Warn("activation failed", "err", err, "msg", msg, "program", program) } if errors.Is(err, vm.ErrExecutionReverted) { - return nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg) + return nil, nil, nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg) } - return nil, err + return nil, nil, nil, err } hash := moduleHash.toHash() @@ -95,13 +114,57 @@ func activateProgram( asmEstimate: uint32(stylusData.asm_estimate), footprint: uint16(stylusData.footprint), } - db.ActivateWasm(hash, asm, module) - return info, err + return info, asm, module, err +} + +func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) { + localAsm, err := statedb.TryGetActivatedAsm(moduleHash) + if err == nil && len(localAsm) > 0 { + return localAsm, nil + } + + codeHash := statedb.GetCodeHash(address) + + wasm, err := getWasm(statedb, address) + if err != nil { + log.Error("Failed to reactivate program: getWasm", "address", address, "expected moduleHash", moduleHash, "err", err) + return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err) + } + + unlimitedGas := uint64(0xffffffffffff) + // we know program is activated, so it must be in correct version and not use too much memory + info, asm, module, err := activateProgramInternal(statedb, address, codeHash, wasm, pagelimit, program.version, debugMode, &unlimitedGas) + if err != nil { + log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "err", err) + return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err) + } + + if info.moduleHash != moduleHash { + log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "got", info.moduleHash) + return nil, fmt.Errorf("failed to reactivate program. address: %v, expected ModuleHash: %v", address, moduleHash) + } + + currentHoursSince := hoursSinceArbitrum(time) + if currentHoursSince > program.activatedAt { + // stylus program is active on-chain, and was activated in the past + // so we store it directly to database + batch := statedb.Database().WasmStore().NewBatch() + rawdb.WriteActivation(batch, moduleHash, asm, module) + if err := batch.Write(); err != nil { + log.Error("failed writing re-activation to state", "address", address, "err", err) + } + } else { + // program activated recently, possibly in this eth_call + // store it to statedb. It will be stored to database if statedb is commited + statedb.ActivateWasm(info.moduleHash, asm, module) + } + return asm, nil } func callProgram( address common.Address, moduleHash common.Hash, + localAsm []byte, scope *vm.ScopeContext, interpreter *vm.EVMInterpreter, tracingInfo *util.TracingInfo, @@ -111,10 +174,9 @@ func callProgram( memoryModel *MemoryModel, ) ([]byte, error) { db := interpreter.Evm().StateDB - asm := db.GetActivatedAsm(moduleHash) debug := stylusParams.DebugMode - if len(asm) == 0 { + if len(localAsm) == 0 { log.Error("missing asm", "program", address, "module", moduleHash) panic("missing asm") } @@ -128,7 +190,7 @@ func callProgram( output := &rustBytes{} status := userStatus(C.stylus_call( - goSlice(asm), + goSlice(localAsm), goSlice(calldata), stylusParams.encode(), evmApi.cNative, @@ -159,11 +221,15 @@ func handleReqImpl(apiId usize, req_type u32, data *rustSlice, costPtr *u64, out // Caches a program in Rust. We write a record so that we can undo on revert. // For gas estimation and eth_call, we ignore permanent updates and rely on Rust's LRU. -func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, runMode core.MessageRunMode) { +func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) { if runMode == core.MessageCommitMode { - asm := db.GetActivatedAsm(module) - state.CacheWasmRust(asm, module, version, debug) - db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: version, Debug: debug}) + // address is only used for logging + asm, err := getLocalAsm(db, module, common.Address{}, params.PageLimit, time, debug, program) + if err != nil { + panic("unable to recreate wasm") + } + state.CacheWasmRust(asm, module, program.version, debug) + db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: program.version, Debug: debug}) } } diff --git a/arbos/programs/programs.go b/arbos/programs/programs.go index 3f7bdc39ca..d3113ae98d 100644 --- a/arbos/programs/programs.go +++ b/arbos/programs/programs.go @@ -120,14 +120,13 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c return 0, codeHash, common.Hash{}, nil, true, err } - // replace the cached asm + // remove prev asm if cached { oldModuleHash, err := p.moduleHashes.Get(codeHash) if err != nil { return 0, codeHash, common.Hash{}, nil, true, err } evictProgram(statedb, oldModuleHash, currentVersion, debugMode, runMode, expired) - cacheProgram(statedb, info.moduleHash, stylusVersion, debugMode, runMode) } if err := p.moduleHashes.Set(codeHash, info.moduleHash); err != nil { return 0, codeHash, common.Hash{}, nil, true, err @@ -152,6 +151,11 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c activatedAt: hoursSinceArbitrum(time), cached: cached, } + // replace the cached asm + if cached { + cacheProgram(statedb, info.moduleHash, programData, params, debugMode, time, runMode) + } + return stylusVersion, codeHash, info.moduleHash, dataFee, false, p.setProgram(codeHash, programData) } @@ -205,6 +209,12 @@ func (p Programs) CallProgram( statedb.AddStylusPages(program.footprint) defer statedb.SetStylusPagesOpen(open) + localAsm, err := getLocalAsm(statedb, moduleHash, contract.Address(), params.PageLimit, evm.Context.Time, debugMode, program) + if err != nil { + log.Crit("failed to get local wasm for activated program", "program", contract.Address()) + return nil, err + } + evmData := &EvmData{ blockBasefee: common.BigToHash(evm.Context.BaseFee), chainId: evm.ChainConfig().ChainID.Uint64(), @@ -227,7 +237,7 @@ func (p Programs) CallProgram( if contract.CodeAddr != nil { address = *contract.CodeAddr } - return callProgram(address, moduleHash, scope, interpreter, tracingInfo, calldata, evmData, goParams, model) + return callProgram(address, moduleHash, localAsm, scope, interpreter, tracingInfo, calldata, evmData, goParams, model) } func getWasm(statedb vm.StateDB, program common.Address) ([]byte, error) { @@ -380,7 +390,7 @@ func (p Programs) SetProgramCached( return err } if cache { - cacheProgram(db, moduleHash, program.version, debug, runMode) + cacheProgram(db, moduleHash, program, params, debug, time, runMode) } else { evictProgram(db, moduleHash, program.version, debug, runMode, expired) } diff --git a/arbos/programs/wasm.go b/arbos/programs/wasm.go index 1e9b5e680b..95f30e83b6 100644 --- a/arbos/programs/wasm.go +++ b/arbos/programs/wasm.go @@ -95,7 +95,7 @@ func activateProgram( } // stub any non-consensus, Rust-side caching updates -func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode) { +func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) { } func evictProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode, forever bool) { } @@ -128,9 +128,14 @@ func startProgram(module uint32) uint32 //go:wasmimport programs send_response func sendResponse(req_id uint32) uint32 +func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) { + return nil, nil +} + func callProgram( address common.Address, moduleHash common.Hash, + _localAsm []byte, scope *vm.ScopeContext, interpreter *vm.EVMInterpreter, tracingInfo *util.TracingInfo, diff --git a/cmd/nitro/init.go b/cmd/nitro/init.go index a45ec054a1..c52c87732c 100644 --- a/cmd/nitro/init.go +++ b/cmd/nitro/init.go @@ -178,10 +178,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo if !arbmath.BigEquals(chainConfig.ChainID, chainId) { return nil, nil, fmt.Errorf("database has chain ID %v but config has chain ID %v (are you sure this database is for the right chain?)", chainConfig.ChainID, chainId) } - chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false) + chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false) if err != nil { - return chainDb, nil, err + return nil, nil, err + } + wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false) + if err != nil { + return nil, nil, err } + chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb) err = pruning.PruneChainDb(ctx, chainDb, stack, &config.Init, cacheConfig, l1Client, rollupAddrs, config.Node.ValidatorRequired()) if err != nil { return chainDb, nil, fmt.Errorf("error pruning: %w", err) @@ -230,10 +235,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo var initDataReader statetransfer.InitDataReader = nil - chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false) + chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false) if err != nil { - return chainDb, nil, err + return nil, nil, err + } + wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false) + if err != nil { + return nil, nil, err } + chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb) if config.Init.ImportFile != "" { initDataReader, err = statetransfer.NewJsonInitDataReader(config.Init.ImportFile) diff --git a/go-ethereum b/go-ethereum index 9874ec397a..8048ac4bed 160000 --- a/go-ethereum +++ b/go-ethereum @@ -1 +1 @@ -Subproject commit 9874ec397a5b499eefc98f7f9ae9632c3fc1e17f +Subproject commit 8048ac4bed2eda18284e3c022ea5ee4cce771134 diff --git a/system_tests/common_test.go b/system_tests/common_test.go index 1b2b7ca6d6..f6bfde2108 100644 --- a/system_tests/common_test.go +++ b/system_tests/common_test.go @@ -44,6 +44,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" @@ -772,8 +773,11 @@ func createL2BlockChainWithStackConfig( stack, err = node.New(stackConfig) Require(t, err) - chainDb, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false) + chainData, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false) Require(t, err) + wasmData, err := stack.OpenDatabase("wasm", 0, 0, "wasm/", false) + Require(t, err) + chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmData) arbDb, err := stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false) Require(t, err) @@ -976,8 +980,12 @@ func Create2ndNodeWithConfig( l2stack, err := node.New(stackConfig) Require(t, err) - l2chainDb, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false) + l2chainData, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false) + Require(t, err) + wasmData, err := l2stack.OpenDatabase("wasm", 0, 0, "wasm/", false) Require(t, err) + l2chainDb := rawdb.WrapDatabaseWithWasm(l2chainData, wasmData) + l2arbDb, err := l2stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false) Require(t, err) initReader := statetransfer.NewMemoryInitDataReader(l2InitData) diff --git a/system_tests/program_test.go b/system_tests/program_test.go index b20efe0740..079b6c0818 100644 --- a/system_tests/program_test.go +++ b/system_tests/program_test.go @@ -1458,3 +1458,81 @@ func formatTime(duration time.Duration) string { } return fmt.Sprintf("%.2f%s", span, units[unit]) } + +func TestWasmRecreate(t *testing.T) { + builder, auth, cleanup := setupProgramTest(t, true) + ctx := builder.ctx + l2info := builder.L2Info + l2client := builder.L2.Client + defer cleanup() + + storage := deployWasm(t, ctx, auth, l2client, rustFile("storage")) + + zero := common.Hash{} + val := common.HexToHash("0x121233445566") + + // do an onchain call - store value + storeTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageWrite(zero, val)) + Require(t, l2client.SendTransaction(ctx, storeTx)) + _, err := EnsureTxSucceeded(ctx, l2client, storeTx) + Require(t, err) + + testDir := t.TempDir() + nodeBStack := createStackConfigForTest(testDir) + nodeB, cleanupB := builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack}) + + _, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx) + Require(t, err) + + // make sure reading 2nd value succeeds from 2nd node + loadTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageRead(zero)) + result, err := arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true) + Require(t, err) + if common.BytesToHash(result) != val { + Fatal(t, "got wrong value") + } + // close nodeB + cleanupB() + + // delete wasm dir of nodeB + + wasmPath := filepath.Join(testDir, "system_tests.test", "wasm") + dirContents, err := os.ReadDir(wasmPath) + Require(t, err) + if len(dirContents) == 0 { + Fatal(t, "not contents found before delete") + } + os.RemoveAll(wasmPath) + + // recreate nodeB - using same source dir (wasm deleted) + nodeB, cleanupB = builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack}) + + // test nodeB - sees existing transaction + _, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx) + Require(t, err) + + // test nodeB - answers eth_call (requires reloading wasm) + result, err = arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true) + Require(t, err) + if common.BytesToHash(result) != val { + Fatal(t, "got wrong value") + } + + // send new tx (requires wasm) and check nodeB sees it as well + Require(t, l2client.SendTransaction(ctx, loadTx)) + + _, err = EnsureTxSucceeded(ctx, l2client, loadTx) + Require(t, err) + + _, err = EnsureTxSucceeded(ctx, nodeB.Client, loadTx) + Require(t, err) + + cleanupB() + dirContents, err = os.ReadDir(wasmPath) + Require(t, err) + if len(dirContents) == 0 { + Fatal(t, "not contents found before delete") + } + os.RemoveAll(wasmPath) + +}