diff --git a/nbhttp/websocket/dialer.go b/nbhttp/websocket/dialer.go index 48506d01..abcf5b2f 100644 --- a/nbhttp/websocket/dialer.go +++ b/nbhttp/websocket/dialer.go @@ -245,7 +245,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h wsConn.OnClose(upgrader.onClose) state.conn = wsConn - upgrader.Engine = parser.Engine + state.Engine = parser.Engine if upgrader.openHandler != nil { upgrader.openHandler(wsConn) diff --git a/nbhttp/websocket/upgrader.go b/nbhttp/websocket/upgrader.go index 429ce25c..666eb46d 100644 --- a/nbhttp/websocket/upgrader.go +++ b/nbhttp/websocket/upgrader.go @@ -49,8 +49,6 @@ type Upgrader struct { messageHandler func(c *Conn, messageType MessageType, data []byte) dataFrameHandler func(c *Conn, messageType MessageType, fin bool, data []byte) onClose func(c *Conn, err error) - - Engine *nbhttp.Engine } type connState struct { common *Upgrader @@ -60,6 +58,7 @@ type connState struct { opcode MessageType buffer []byte message []byte + Engine *nbhttp.Engine } // CompressionEnabled . @@ -282,7 +281,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } state.conn = newConn(u, conn, subprotocol, compress) - u.Engine = parser.Engine + state.Engine = parser.Engine state.conn.Engine = parser.Engine if u.openHandler != nil { @@ -364,10 +363,10 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error { err = ErrMessageTooLarge break } - frame = u.common.Engine.BodyAllocator.Malloc(bl) + frame = u.Engine.BodyAllocator.Malloc(bl) copy(frame, body) } - if u.opcode == TextMessage && len(frame) > 0 && !u.common.Engine.CheckUtf8(frame) { + if u.opcode == TextMessage && len(frame) > 0 && !u.Engine.CheckUtf8(frame) { u.conn.Close() } else { u.common.handleDataFrame(p, u.conn, u.opcode, fin, frame) @@ -375,7 +374,7 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error { } if bl > 0 && u.common.messageHandler != nil { if u.message == nil { - u.message = u.common.Engine.BodyAllocator.Malloc(len(body)) + u.message = u.Engine.BodyAllocator.Malloc(len(body)) if u.isMessageTooLarge(len(body)) { err = ErrMessageTooLarge break @@ -395,7 +394,7 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error { var b []byte rc := decompressReader(io.MultiReader(bytes.NewBuffer(u.message), strings.NewReader(flateReaderTail))) b, err = u.readAll(rc, len(u.message)*2) - u.common.Engine.BodyAllocator.Free(u.message) + u.Engine.BodyAllocator.Free(u.message) u.message = b rc.Close() if err != nil { @@ -418,7 +417,7 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error { err = ErrMessageTooLarge break } - frame = u.common.Engine.BodyAllocator.Malloc(len(body)) + frame = u.Engine.BodyAllocator.Malloc(len(body)) copy(frame, body) } u.handleProtocolMessage(p, opcode, frame) @@ -468,7 +467,7 @@ func (u *Upgrader) handleDataFrame(p *nbhttp.Parser, c *Conn, opcode MessageType } func (u *connState) handleMessage(p *nbhttp.Parser, opcode MessageType, body []byte) { - if u.opcode == TextMessage && !u.common.Engine.CheckUtf8(u.message) { + if u.opcode == TextMessage && !u.Engine.CheckUtf8(u.message) { u.conn.Close() return } @@ -482,8 +481,8 @@ func (u *connState) handleMessage(p *nbhttp.Parser, opcode MessageType, body []b func (u *connState) handleProtocolMessage(p *nbhttp.Parser, opcode MessageType, body []byte) { p.Execute(func() { u.common.handleWsMessage(u.conn, opcode, body) - if len(body) > 0 && u.common.Engine.ReleaseWebsocketPayload { - u.common.Engine.BodyAllocator.Free(body) + if len(body) > 0 && u.Engine.ReleaseWebsocketPayload { + u.Engine.BodyAllocator.Free(body) } }) } @@ -871,7 +870,7 @@ func nextTokenOrQuoted(s string) (value string, rest string) { func (u *connState) readAll(r io.Reader, size int) ([]byte, error) { const maxAppendSize = 1024 * 1024 * 4 - buf := u.common.Engine.BodyAllocator.Malloc(size)[0:0] + buf := u.Engine.BodyAllocator.Malloc(size)[0:0] for { n, err := r.Read(buf[len(buf):cap(buf)]) if n > 0 {