Skip to content

Commit

Permalink
Merge pull request #350 from OffchainLabs/request_size
Browse files Browse the repository at this point in the history
Add config option to change HTTPBodyLimit and WSReadLimit
  • Loading branch information
PlasmaPower authored Aug 12, 2024
2 parents 48de203 + a1fc200 commit 5b905ae
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 11 deletions.
6 changes: 6 additions & 0 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ type Config struct {
EnablePersonal bool `toml:"-"`

DBEngine string `toml:",omitempty"`

// HTTPBodyLimit is the maximum number of bytes allowed in the HTTP request body.
HTTPBodyLimit int `toml:",omitempty"`

// WSReadLimit is the maximum number of bytes allowed in the websocket request body.
WSReadLimit int64 `toml:",omitempty"`
}

// IPCEndpoint resolves an IPC endpoint based on a configured value, taking into
Expand Down
12 changes: 12 additions & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ func (n *Node) startRPC() error {
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
apiFilter: n.apiFilter,
}
if n.config.HTTPBodyLimit != 0 {
rpcConfig.httpBodyLimit = n.config.HTTPBodyLimit
}
if n.config.WSReadLimit != 0 {
rpcConfig.wsReadLimit = n.config.WSReadLimit
}

initHttp := func(server *httpServer, port int) error {
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
Expand Down Expand Up @@ -470,6 +476,12 @@ func (n *Node) startRPC() error {
batchResponseSizeLimit: engineAPIBatchResponseSizeLimit,
httpBodyLimit: engineAPIBodyLimit,
}
if n.config.HTTPBodyLimit != 0 {
sharedConfig.httpBodyLimit = n.config.HTTPBodyLimit
}
if n.config.WSReadLimit != 0 {
sharedConfig.wsReadLimit = n.config.WSReadLimit
}
err := server.enableRPC(allAPIs, httpConfig{
CorsAllowedOrigins: DefaultAuthCors,
Vhosts: n.config.AuthVirtualHosts,
Expand Down
3 changes: 2 additions & 1 deletion node/rpcstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type rpcEndpointConfig struct {
batchResponseSizeLimit int
apiFilter map[string]bool
httpBodyLimit int
wsReadLimit int64
}

type rpcHandler struct {
Expand Down Expand Up @@ -362,7 +363,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret),
Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins, config.wsReadLimit), config.jwtSecret),
server: srv,
})
return nil
Expand Down
6 changes: 3 additions & 3 deletions rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ func TestClientSubscriptionChannelClose(t *testing.T) {

var (
srv = NewServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil))
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -745,7 +745,7 @@ func TestClientReconnect(t *testing.T) {
if err != nil {
t.Fatal("can't listen:", err)
}
go http.Serve(l, srv.WebsocketHandler([]string{"*"}))
go http.Serve(l, srv.WebsocketHandler([]string{"*"}, 0))
return srv, l
}

Expand Down Expand Up @@ -811,7 +811,7 @@ func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client,
var hs *httptest.Server
switch transport {
case "ws":
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}))
hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}, 0))
case "http":
hs = httptest.NewUnstartedServer(srv)
default:
Expand Down
7 changes: 5 additions & 2 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ var wsBufferPool = new(sync.Pool)
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*".
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
func (s *Server) WebsocketHandler(allowedOrigins []string, wsReadLimit int64) http.Handler {
if wsReadLimit == 0 {
wsReadLimit = wsDefaultReadLimit
}
var upgrader = websocket.Upgrader{
ReadBufferSize: wsReadBuffer,
WriteBufferSize: wsWriteBuffer,
Expand All @@ -60,7 +63,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
codec := newWebsocketCodec(conn, r.Host, r.Header, wsReadLimit)
s.ServeCodec(codec, 0)
})
}
Expand Down
10 changes: 5 additions & 5 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestWebsocketOriginCheck(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}, 0))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestWebsocketLargeCall(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestWebsocketLargeRead(t *testing.T) {

var (
srv = newTestServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}, 0))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down Expand Up @@ -176,7 +176,7 @@ func TestWebsocketLargeRead(t *testing.T) {
func TestWebsocketPeerInfo(t *testing.T) {
var (
s = newTestServer()
ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}))
ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}, 0))
tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:")
)
defer s.Stop()
Expand Down Expand Up @@ -260,7 +260,7 @@ func TestClientWebsocketPing(t *testing.T) {
func TestClientWebsocketLargeMessage(t *testing.T) {
var (
srv = NewServer()
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil))
httpsrv = httptest.NewServer(srv.WebsocketHandler(nil, 0))
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
)
defer srv.Stop()
Expand Down

0 comments on commit 5b905ae

Please sign in to comment.