diff --git a/core/rawdb/accessors_state_arbitrum.go b/core/rawdb/accessors_state_arbitrum.go index 32ccb1c94d..455ba6f4f6 100644 --- a/core/rawdb/accessors_state_arbitrum.go +++ b/core/rawdb/accessors_state_arbitrum.go @@ -17,20 +17,99 @@ package rawdb import ( + "fmt" + "runtime" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) -// Stores the activated asm and module for a given codeHash -func WriteActivation(db ethdb.KeyValueWriter, moduleHash common.Hash, asm, module []byte) { - key := ActivatedAsmKey(moduleHash) +type Target string + +const ( + TargetWavm Target = "wavm" + TargetArm64 Target = "arm64" + TargetAmd64 Target = "amd64" + TargetHost Target = "host" +) + +func LocalTarget() Target { + if runtime.GOOS == "linux" { + switch runtime.GOARCH { + case "arm64": + return TargetArm64 + case "amd64": + return TargetAmd64 + } + } + return TargetHost +} + +func (t Target) keyPrefix() (WasmPrefix, error) { + var prefix WasmPrefix + switch t { + case TargetWavm: + prefix = activatedAsmWavmPrefix + case TargetArm64: + prefix = activatedAsmArmPrefix + case TargetAmd64: + prefix = activatedAsmX86Prefix + case TargetHost: + prefix = activatedAsmHostPrefix + default: + return WasmPrefix{}, fmt.Errorf("invalid target: %v", t) + } + return prefix, nil +} + +func (t Target) IsValid() bool { + _, err := t.keyPrefix() + return err == nil +} + +var Targets = []Target{TargetWavm, TargetArm64, TargetAmd64, TargetHost} + +func WriteActivation(db ethdb.KeyValueWriter, moduleHash common.Hash, asmMap map[Target][]byte) { + for target, asm := range asmMap { + WriteActivatedAsm(db, target, moduleHash, asm) + } +} + +// Stores the activated asm for a given moduleHash and target +func WriteActivatedAsm(db ethdb.KeyValueWriter, target Target, moduleHash common.Hash, asm []byte) { + prefix, err := target.keyPrefix() + if err != nil { + log.Crit("Failed to store activated wasm asm", "err", err) + } + key := activatedKey(prefix, moduleHash) if err := db.Put(key[:], asm); err != nil { log.Crit("Failed to store activated wasm asm", "err", err) } +} + +// Retrieves the activated asm for a given moduleHash and target +func ReadActivatedAsm(db ethdb.KeyValueReader, target Target, moduleHash common.Hash) []byte { + prefix, err := target.keyPrefix() + if err != nil { + log.Crit("Failed to read activated wasm asm", "err", err) + } + key := activatedKey(prefix, moduleHash) + asm, err := db.Get(key[:]) + if err != nil { + return nil + } + return asm +} - key = ActivatedModuleKey(moduleHash) - if err := db.Put(key[:], module); err != nil { - log.Crit("Failed to store activated wasm module", "err", err) +// Stores wasm schema version +func WriteWasmSchemaVersion(db ethdb.KeyValueWriter) { + if err := db.Put(wasmSchemaVersionKey, []byte{WasmSchemaVersion}); err != nil { + log.Crit("Failed to store wasm schema version", "err", err) } } + +// Retrieves wasm schema version +func ReadWasmSchemaVersion(db ethdb.KeyValueReader) ([]byte, error) { + return db.Get(wasmSchemaVersionKey) +} diff --git a/core/rawdb/schema_arbitrum.go b/core/rawdb/schema_arbitrum.go index 72a3f755bb..5f364892c7 100644 --- a/core/rawdb/schema_arbitrum.go +++ b/core/rawdb/schema_arbitrum.go @@ -19,48 +19,41 @@ package rawdb import ( - "bytes" - "github.com/ethereum/go-ethereum/common" ) -var ( - activatedAsmPrefix = []byte{0x00, 'w', 'a'} // (prefix, moduleHash) -> stylus asm - activatedModulePrefix = []byte{0x00, 'w', 'm'} // (prefix, moduleHash) -> stylus module -) +const WasmSchemaVersion byte = 0x01 + +const WasmPrefixLen = 3 // WasmKeyLen = CompiledWasmCodePrefix + moduleHash -const WasmKeyLen = 3 + 32 +const WasmKeyLen = WasmPrefixLen + common.HashLength +type WasmPrefix = [WasmPrefixLen]byte type WasmKey = [WasmKeyLen]byte -func ActivatedAsmKey(moduleHash common.Hash) WasmKey { - return newWasmKey(activatedAsmPrefix, moduleHash) -} +var ( + wasmSchemaVersionKey = []byte("WasmSchemaVersion") + + // 0x00 prefix to avoid conflicts when wasmdb is not separate database + activatedAsmWavmPrefix = WasmPrefix{0x00, 'w', 'w'} // (prefix, moduleHash) -> stylus module (wavm) + activatedAsmArmPrefix = WasmPrefix{0x00, 'w', 'r'} // (prefix, moduleHash) -> stylus asm for ARM system + activatedAsmX86Prefix = WasmPrefix{0x00, 'w', 'x'} // (prefix, moduleHash) -> stylus asm for x86 system + activatedAsmHostPrefix = WasmPrefix{0x00, 'w', 'h'} // (prefix, moduleHash) -> stylus asm for system other then ARM and x86 +) -func ActivatedModuleKey(moduleHash common.Hash) WasmKey { - return newWasmKey(activatedModulePrefix, moduleHash) +func DeprecatedPrefixesV0() (keyPrefixes [][]byte, keyLength int) { + return [][]byte{ + // deprecated prefixes, used in version 0x00, purged in version 0x01 + []byte{0x00, 'w', 'a'}, // ActivatedAsmPrefix + []byte{0x00, 'w', 'm'}, // ActivatedModulePrefix + }, 3 + 32 } // key = prefix + moduleHash -func newWasmKey(prefix []byte, moduleHash common.Hash) WasmKey { +func activatedKey(prefix WasmPrefix, moduleHash common.Hash) WasmKey { var key WasmKey - copy(key[:3], prefix) - copy(key[3:], moduleHash[:]) + copy(key[:WasmPrefixLen], prefix[:]) + copy(key[WasmPrefixLen:], moduleHash[:]) return key } - -func IsActivatedAsmKey(key []byte) (bool, common.Hash) { - return extractWasmKey(activatedAsmPrefix, key) -} - -func IsActivatedModuleKey(key []byte) (bool, common.Hash) { - return extractWasmKey(activatedModulePrefix, key) -} - -func extractWasmKey(prefix, key []byte) (bool, common.Hash) { - if !bytes.HasPrefix(key, prefix) || len(key) != WasmKeyLen { - return false, common.Hash{} - } - return true, common.BytesToHash(key[len(prefix):]) -} diff --git a/core/state/database.go b/core/state/database.go index 860c2b2146..6c834ccc8a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -53,8 +53,7 @@ const ( // Database wraps access to tries and contract code. type Database interface { // Arbitrum: Read activated Stylus contracts - ActivatedAsm(moduleHash common.Hash) (asm []byte, err error) - ActivatedModule(moduleHash common.Hash) (module []byte, err error) + ActivatedAsm(target rawdb.Target, moduleHash common.Hash) (asm []byte, err error) WasmStore() ethdb.KeyValueStore WasmCacheTag() uint32 @@ -164,9 +163,8 @@ func NewDatabaseWithConfig(db ethdb.Database, config *triedb.Config) Database { wasmdb, wasmTag := db.WasmDataBase() cdb := &cachingDB{ // Arbitrum only - activatedAsmCache: lru.NewSizeConstrainedCache[common.Hash, []byte](activatedWasmCacheSize), - activatedModuleCache: lru.NewSizeConstrainedCache[common.Hash, []byte](activatedWasmCacheSize), - wasmTag: wasmTag, + activatedAsmCache: lru.NewSizeConstrainedCache[activatedAsmCacheKey, []byte](activatedWasmCacheSize), + wasmTag: wasmTag, disk: db, wasmdb: wasmdb, @@ -182,9 +180,8 @@ func NewDatabaseWithNodeDB(db ethdb.Database, triedb *triedb.Database) Database wasmdb, wasmTag := db.WasmDataBase() cdb := &cachingDB{ // Arbitrum only - activatedAsmCache: lru.NewSizeConstrainedCache[common.Hash, []byte](activatedWasmCacheSize), - activatedModuleCache: lru.NewSizeConstrainedCache[common.Hash, []byte](activatedWasmCacheSize), - wasmTag: wasmTag, + activatedAsmCache: lru.NewSizeConstrainedCache[activatedAsmCacheKey, []byte](activatedWasmCacheSize), + wasmTag: wasmTag, disk: db, wasmdb: wasmdb, @@ -195,11 +192,15 @@ func NewDatabaseWithNodeDB(db ethdb.Database, triedb *triedb.Database) Database return cdb } +type activatedAsmCacheKey struct { + moduleHash common.Hash + target rawdb.Target +} + type cachingDB struct { // Arbitrum - activatedAsmCache *lru.SizeConstrainedCache[common.Hash, []byte] - activatedModuleCache *lru.SizeConstrainedCache[common.Hash, []byte] - wasmTag uint32 + activatedAsmCache *lru.SizeConstrainedCache[activatedAsmCacheKey, []byte] + wasmTag uint32 disk ethdb.KeyValueStore wasmdb ethdb.KeyValueStore diff --git a/core/state/database_arbitrum.go b/core/state/database_arbitrum.go index 6171b4cd10..b05733520b 100644 --- a/core/state/database_arbitrum.go +++ b/core/state/database_arbitrum.go @@ -7,34 +7,14 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" ) -func (db *cachingDB) ActivatedAsm(moduleHash common.Hash) ([]byte, error) { - if asm, _ := db.activatedAsmCache.Get(moduleHash); len(asm) > 0 { +func (db *cachingDB) ActivatedAsm(target rawdb.Target, moduleHash common.Hash) ([]byte, error) { + cacheKey := activatedAsmCacheKey{moduleHash, target} + if asm, _ := db.activatedAsmCache.Get(cacheKey); len(asm) > 0 { return asm, nil } - wasmKey := rawdb.ActivatedAsmKey(moduleHash) - asm, err := db.wasmdb.Get(wasmKey[:]) - if err != nil { - return nil, err - } - if len(asm) > 0 { - db.activatedAsmCache.Add(moduleHash, asm) + if asm := rawdb.ReadActivatedAsm(db.wasmdb, target, moduleHash); len(asm) > 0 { + db.activatedAsmCache.Add(cacheKey, asm) return asm, nil } return nil, errors.New("not found") } - -func (db *cachingDB) ActivatedModule(moduleHash common.Hash) ([]byte, error) { - if module, _ := db.activatedModuleCache.Get(moduleHash); len(module) > 0 { - return module, nil - } - wasmKey := rawdb.ActivatedModuleKey(moduleHash) - module, err := db.wasmdb.Get(wasmKey[:]) - if err != nil { - return nil, err - } - if len(module) > 0 { - db.activatedModuleCache.Add(moduleHash, module) - return module, nil - } - return nil, errors.New("not found") -} diff --git a/core/state/journal_arbitrum.go b/core/state/journal_arbitrum.go index 69b2415b79..d074d6df7c 100644 --- a/core/state/journal_arbitrum.go +++ b/core/state/journal_arbitrum.go @@ -2,6 +2,7 @@ package state import ( "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" ) type wasmActivation struct { @@ -43,7 +44,7 @@ type EvictWasm struct { } func (ch EvictWasm) revert(s *StateDB) { - asm, err := s.TryGetActivatedAsm(ch.ModuleHash) // only happens in native mode + asm, err := s.TryGetActivatedAsm(rawdb.LocalTarget(), ch.ModuleHash) // only happens in native mode if err == nil && len(asm) != 0 { //if we failed to get it - it's not in the current rust cache CacheWasmRust(asm, ch.ModuleHash, ch.Version, ch.Tag, ch.Debug) diff --git a/core/state/statedb.go b/core/state/statedb.go index d1aaaf591a..cf4b2c5ba7 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -157,7 +157,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) unexpectedBalanceDelta: new(big.Int), openWasmPages: 0, everWasmPages: 0, - activatedWasms: make(map[common.Hash]*ActivatedWasm), + activatedWasms: make(map[common.Hash]ActivatedWasm), recentWasms: NewRecentWasms(), }, @@ -727,7 +727,7 @@ func (s *StateDB) Copy() *StateDB { state := &StateDB{ arbExtraData: &ArbitrumExtraData{ unexpectedBalanceDelta: new(big.Int).Set(s.arbExtraData.unexpectedBalanceDelta), - activatedWasms: make(map[common.Hash]*ActivatedWasm, len(s.arbExtraData.activatedWasms)), + activatedWasms: make(map[common.Hash]ActivatedWasm, len(s.arbExtraData.activatedWasms)), recentWasms: s.arbExtraData.recentWasms.Copy(), openWasmPages: s.arbExtraData.openWasmPages, everWasmPages: s.arbExtraData.everWasmPages, @@ -817,9 +817,9 @@ func (s *StateDB) Copy() *StateDB { state.arbExtraData.userWasms[call] = wasm } } - for moduleHash, info := range s.arbExtraData.activatedWasms { + for moduleHash, asmMap := range s.arbExtraData.activatedWasms { // It's fine to skip a deep copy since activations are immutable. - state.arbExtraData.activatedWasms[moduleHash] = info + state.arbExtraData.activatedWasms[moduleHash] = asmMap } // If there's a prefetcher running, make an inactive copy of it that can @@ -1277,11 +1277,11 @@ func (s *StateDB) Commit(block uint64, deleteEmptyObjects bool) (common.Hash, er } // Arbitrum: write Stylus programs to disk - for moduleHash, info := range s.arbExtraData.activatedWasms { - rawdb.WriteActivation(wasmCodeWriter, moduleHash, info.Asm, info.Module) + for moduleHash, asmMap := range s.arbExtraData.activatedWasms { + rawdb.WriteActivation(wasmCodeWriter, moduleHash, asmMap) } if len(s.arbExtraData.activatedWasms) > 0 { - s.arbExtraData.activatedWasms = make(map[common.Hash]*ActivatedWasm) + s.arbExtraData.activatedWasms = make(map[common.Hash]ActivatedWasm) } if codeWriter.ValueSize() > 0 { diff --git a/core/state/statedb_arbitrum.go b/core/state/statedb_arbitrum.go index e459ad4570..99e7835231 100644 --- a/core/state/statedb_arbitrum.go +++ b/core/state/statedb_arbitrum.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/lru" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" @@ -49,10 +50,7 @@ var ( StylusDiscriminant = []byte{stylusEOFMagic, stylusEOFMagicSuffix, stylusEOFVersion} ) -type ActivatedWasm struct { - Asm []byte - Module []byte -} +type ActivatedWasm map[rawdb.Target][]byte // checks if a valid Stylus prefix is present func IsStylusProgram(b []byte) bool { @@ -76,38 +74,48 @@ func NewStylusPrefix(dictionary byte) []byte { return append(prefix, dictionary) } -func (s *StateDB) ActivateWasm(moduleHash common.Hash, asm, module []byte) { +func (s *StateDB) ActivateWasm(moduleHash common.Hash, asmMap map[rawdb.Target][]byte) { _, exists := s.arbExtraData.activatedWasms[moduleHash] if exists { return } - s.arbExtraData.activatedWasms[moduleHash] = &ActivatedWasm{ - Asm: asm, - Module: module, - } + s.arbExtraData.activatedWasms[moduleHash] = asmMap s.journal.append(wasmActivation{ moduleHash: moduleHash, }) } -func (s *StateDB) TryGetActivatedAsm(moduleHash common.Hash) ([]byte, error) { - info, exists := s.arbExtraData.activatedWasms[moduleHash] +func (s *StateDB) TryGetActivatedAsm(target rawdb.Target, moduleHash common.Hash) ([]byte, error) { + asmMap, exists := s.arbExtraData.activatedWasms[moduleHash] if exists { - return info.Asm, nil + if asm, exists := asmMap[target]; exists { + return asm, nil + } } - return s.db.ActivatedAsm(moduleHash) + return s.db.ActivatedAsm(target, moduleHash) } -func (s *StateDB) GetActivatedModule(moduleHash common.Hash) []byte { - info, exists := s.arbExtraData.activatedWasms[moduleHash] - if exists { - return info.Module +func (s *StateDB) TryGetActivatedAsmMap(targets []rawdb.Target, moduleHash common.Hash) (map[rawdb.Target][]byte, error) { + asmMap := s.arbExtraData.activatedWasms[moduleHash] + if asmMap != nil { + for _, target := range targets { + if _, exists := asmMap[target]; !exists { + return nil, fmt.Errorf("newly activated wasms for module %v exist, but they don't contain asm for target %v", moduleHash, target) + } + } + return asmMap, nil } - code, err := s.db.ActivatedModule(moduleHash) - if err != nil { - s.setError(fmt.Errorf("failed to load module for %x: %v", moduleHash, err)) + var err error + asmMap = make(map[rawdb.Target][]byte, len(targets)) + for _, target := range targets { + asm, dbErr := s.db.ActivatedAsm(target, moduleHash) + if dbErr == nil { + asmMap[target] = asm + } else { + err = errors.Join(fmt.Errorf("failed to read activated asm from database for target %v and module %v: %w", target, moduleHash, dbErr), err) + } } - return code + return asmMap, err } func (s *StateDB) GetStylusPages() (uint16, uint16) { @@ -148,11 +156,11 @@ func (s *StateDB) Deterministic() bool { } type ArbitrumExtraData struct { - unexpectedBalanceDelta *big.Int // total balance change across all accounts - userWasms UserWasms // user wasms encountered during execution - openWasmPages uint16 // number of pages currently open - everWasmPages uint16 // largest number of pages ever allocated during this tx's execution - activatedWasms map[common.Hash]*ActivatedWasm // newly activated WASMs + unexpectedBalanceDelta *big.Int // total balance change across all accounts + userWasms UserWasms // user wasms encountered during execution + openWasmPages uint16 // number of pages currently open + everWasmPages uint16 // largest number of pages ever allocated during this tx's execution + activatedWasms map[common.Hash]ActivatedWasm // newly activated WASMs recentWasms RecentWasms } @@ -233,16 +241,17 @@ func (s *StateDB) StartRecording() { s.arbExtraData.userWasms = make(UserWasms) } -func (s *StateDB) RecordProgram(moduleHash common.Hash) { - asm, err := s.TryGetActivatedAsm(moduleHash) +func (s *StateDB) RecordProgram(targets []rawdb.Target, moduleHash common.Hash) { + if len(targets) == 0 { + // nothing to record + return + } + asmMap, err := s.TryGetActivatedAsmMap(targets, moduleHash) if err != nil { - log.Crit("can't find activated wasm while recording", "modulehash", moduleHash) + log.Crit("can't find activated wasm while recording", "modulehash", moduleHash, "err", err) } if s.arbExtraData.userWasms != nil { - s.arbExtraData.userWasms[moduleHash] = ActivatedWasm{ - Asm: asm, - Module: s.GetActivatedModule(moduleHash), - } + s.arbExtraData.userWasms[moduleHash] = asmMap } } diff --git a/core/vm/interface.go b/core/vm/interface.go index 132d8d4ec1..2f8822d913 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -20,6 +20,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" @@ -30,9 +31,9 @@ import ( // StateDB is an EVM database for full state querying. type StateDB interface { // Arbitrum: manage Stylus wasms - ActivateWasm(moduleHash common.Hash, asm, module []byte) - TryGetActivatedAsm(moduleHash common.Hash) (asm []byte, err error) - GetActivatedModule(moduleHash common.Hash) (module []byte) + ActivateWasm(moduleHash common.Hash, asmMap map[rawdb.Target][]byte) + TryGetActivatedAsm(target rawdb.Target, moduleHash common.Hash) (asm []byte, err error) + TryGetActivatedAsmMap(targets []rawdb.Target, moduleHash common.Hash) (asmMap map[rawdb.Target][]byte, err error) RecordCacheWasm(wasm state.CacheWasm) RecordEvictWasm(wasm state.EvictWasm) GetRecentWasms() state.RecentWasms diff --git a/node/config.go b/node/config.go index 0aef1f18b2..8ac4c955af 100644 --- a/node/config.go +++ b/node/config.go @@ -221,6 +221,12 @@ type Config struct { EnablePersonal bool `toml:"-"` DBEngine string `toml:",omitempty"` + + // HTTPBodyLimit is the maximum number of bytes allowed in the HTTP request body. + HTTPBodyLimit int `toml:",omitempty"` + + // WSReadLimit is the maximum number of bytes allowed in the websocket request body. + WSReadLimit int64 `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into diff --git a/node/node.go b/node/node.go index 9a3a50e1c1..a320a016ab 100644 --- a/node/node.go +++ b/node/node.go @@ -420,6 +420,12 @@ func (n *Node) startRPC() error { batchResponseSizeLimit: n.config.BatchResponseMaxSize, apiFilter: n.apiFilter, } + if n.config.HTTPBodyLimit != 0 { + rpcConfig.httpBodyLimit = n.config.HTTPBodyLimit + } + if n.config.WSReadLimit != 0 { + rpcConfig.wsReadLimit = n.config.WSReadLimit + } initHttp := func(server *httpServer, port int) error { if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { @@ -467,6 +473,12 @@ func (n *Node) startRPC() error { batchResponseSizeLimit: engineAPIBatchResponseSizeLimit, httpBodyLimit: engineAPIBodyLimit, } + if n.config.HTTPBodyLimit != 0 { + sharedConfig.httpBodyLimit = n.config.HTTPBodyLimit + } + if n.config.WSReadLimit != 0 { + sharedConfig.wsReadLimit = n.config.WSReadLimit + } err := server.enableRPC(allAPIs, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, diff --git a/node/rpcstack.go b/node/rpcstack.go index 33085869b1..4f01668af5 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -59,6 +59,7 @@ type rpcEndpointConfig struct { batchResponseSizeLimit int apiFilter map[string]bool httpBodyLimit int + wsReadLimit int64 } type rpcHandler struct { @@ -363,7 +364,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret), + Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins, config.wsReadLimit), config.jwtSecret), server: srv, }) return nil diff --git a/rpc/client_test.go b/rpc/client_test.go index 01c326afb0..1a6b826415 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -586,7 +586,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) { var ( srv = NewServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -747,7 +747,7 @@ func TestClientReconnect(t *testing.T) { if err != nil { t.Fatal("can't listen:", err) } - go http.Serve(l, srv.WebsocketHandler([]string{"*"})) + go http.Serve(l, srv.WebsocketHandler([]string{"*"}, 0)) return srv, l } @@ -813,7 +813,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, var hs *httptest.Server switch transport { case "ws": - hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) + hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, 0)) case "http": hs = httptest.NewUnstartedServer(srv) default: diff --git a/rpc/websocket.go b/rpc/websocket.go index 9f67caf859..1f555a2f5f 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -47,7 +47,10 @@ var wsBufferPool = new(sync.Pool) // // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". -func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { +func (s *Server) WebsocketHandler(allowedOrigins []string, wsReadLimit int64) http.Handler { + if wsReadLimit == 0 { + wsReadLimit = wsDefaultReadLimit + } var upgrader = websocket.Upgrader{ ReadBufferSize: wsReadBuffer, WriteBufferSize: wsWriteBuffer, @@ -60,7 +63,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { log.Debug("WebSocket upgrade failed", "err", err) return } - codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) + codec := newWebsocketCodec(conn, r.Host, r.Header, wsReadLimit) s.ServeCodec(codec, 0) }) } diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index c6ea325d29..1288315ac9 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -53,7 +53,7 @@ func TestWebsocketOriginCheck(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -83,7 +83,7 @@ func TestWebsocketLargeCall(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -119,7 +119,7 @@ func TestWebsocketLargeRead(t *testing.T) { var ( srv = newTestServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() @@ -176,7 +176,7 @@ func TestWebsocketLargeRead(t *testing.T) { func TestWebsocketPeerInfo(t *testing.T) { var ( s = newTestServer() - ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) + ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}, 0)) tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") ) defer s.Stop() @@ -261,7 +261,7 @@ func TestClientWebsocketPing(t *testing.T) { func TestClientWebsocketLargeMessage(t *testing.T) { var ( srv = NewServer() - httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) + httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop()