Skip to content

Commit

Permalink
Merge pull request #135 from lesismal/websocket_engine_should_attache…
Browse files Browse the repository at this point in the history
…d_to_wsstate

move engine to websocket  state to avoid race condition
  • Loading branch information
lesismal authored Oct 22, 2021
2 parents b3e24f2 + 938cfef commit f94b40f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nbhttp/websocket/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
wsConn.OnClose(upgrader.onClose)

state.conn = wsConn
upgrader.Engine = parser.Engine
state.Engine = parser.Engine

if upgrader.openHandler != nil {
upgrader.openHandler(wsConn)
Expand Down
23 changes: 11 additions & 12 deletions nbhttp/websocket/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ type Upgrader struct {
messageHandler func(c *Conn, messageType MessageType, data []byte)
dataFrameHandler func(c *Conn, messageType MessageType, fin bool, data []byte)
onClose func(c *Conn, err error)

Engine *nbhttp.Engine
}
type connState struct {
common *Upgrader
Expand All @@ -60,6 +58,7 @@ type connState struct {
opcode MessageType
buffer []byte
message []byte
Engine *nbhttp.Engine
}

// CompressionEnabled .
Expand Down Expand Up @@ -282,7 +281,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}

state.conn = newConn(u, conn, subprotocol, compress)
u.Engine = parser.Engine
state.Engine = parser.Engine
state.conn.Engine = parser.Engine

if u.openHandler != nil {
Expand Down Expand Up @@ -364,18 +363,18 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error {
err = ErrMessageTooLarge
break
}
frame = u.common.Engine.BodyAllocator.Malloc(bl)
frame = u.Engine.BodyAllocator.Malloc(bl)
copy(frame, body)
}
if u.opcode == TextMessage && len(frame) > 0 && !u.common.Engine.CheckUtf8(frame) {
if u.opcode == TextMessage && len(frame) > 0 && !u.Engine.CheckUtf8(frame) {
u.conn.Close()
} else {
u.common.handleDataFrame(p, u.conn, u.opcode, fin, frame)
}
}
if bl > 0 && u.common.messageHandler != nil {
if u.message == nil {
u.message = u.common.Engine.BodyAllocator.Malloc(len(body))
u.message = u.Engine.BodyAllocator.Malloc(len(body))
if u.isMessageTooLarge(len(body)) {
err = ErrMessageTooLarge
break
Expand All @@ -395,7 +394,7 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error {
var b []byte
rc := decompressReader(io.MultiReader(bytes.NewBuffer(u.message), strings.NewReader(flateReaderTail)))
b, err = u.readAll(rc, len(u.message)*2)
u.common.Engine.BodyAllocator.Free(u.message)
u.Engine.BodyAllocator.Free(u.message)
u.message = b
rc.Close()
if err != nil {
Expand All @@ -418,7 +417,7 @@ func (u *connState) Read(p *nbhttp.Parser, data []byte) error {
err = ErrMessageTooLarge
break
}
frame = u.common.Engine.BodyAllocator.Malloc(len(body))
frame = u.Engine.BodyAllocator.Malloc(len(body))
copy(frame, body)
}
u.handleProtocolMessage(p, opcode, frame)
Expand Down Expand Up @@ -468,7 +467,7 @@ func (u *Upgrader) handleDataFrame(p *nbhttp.Parser, c *Conn, opcode MessageType
}

func (u *connState) handleMessage(p *nbhttp.Parser, opcode MessageType, body []byte) {
if u.opcode == TextMessage && !u.common.Engine.CheckUtf8(u.message) {
if u.opcode == TextMessage && !u.Engine.CheckUtf8(u.message) {
u.conn.Close()
return
}
Expand All @@ -482,8 +481,8 @@ func (u *connState) handleMessage(p *nbhttp.Parser, opcode MessageType, body []b
func (u *connState) handleProtocolMessage(p *nbhttp.Parser, opcode MessageType, body []byte) {
p.Execute(func() {
u.common.handleWsMessage(u.conn, opcode, body)
if len(body) > 0 && u.common.Engine.ReleaseWebsocketPayload {
u.common.Engine.BodyAllocator.Free(body)
if len(body) > 0 && u.Engine.ReleaseWebsocketPayload {
u.Engine.BodyAllocator.Free(body)
}
})
}
Expand Down Expand Up @@ -871,7 +870,7 @@ func nextTokenOrQuoted(s string) (value string, rest string) {

func (u *connState) readAll(r io.Reader, size int) ([]byte, error) {
const maxAppendSize = 1024 * 1024 * 4
buf := u.common.Engine.BodyAllocator.Malloc(size)[0:0]
buf := u.Engine.BodyAllocator.Malloc(size)[0:0]
for {
n, err := r.Read(buf[len(buf):cap(buf)])
if n > 0 {
Expand Down

0 comments on commit f94b40f

Please sign in to comment.