diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index 96b6544b..9a7b8232 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -140,14 +140,18 @@ func (c *Conn) CompressionEnabled() bool { return c.compress } -func (c *Conn) handleDataFrame(opcode MessageType, fin bool, data []byte) { +func (c *Conn) handleDataFrame(opcode MessageType, fin bool, body []byte) { h := c.dataFrameHandler if c.isBlockingMod { - h(c, opcode, fin, data) + h(c, opcode, fin, body) } else { - c.Execute(func() { - h(c, opcode, fin, data) - }) + if !c.Execute(func() { + h(c, opcode, fin, body) + }) { + if len(body) > 0 { + c.Engine.BodyAllocator.Free(body) + } + } } } @@ -343,6 +347,7 @@ func (c *Conn) Parse(data []byte) error { var err error var body []byte var frame []byte + var message []byte var protocolMessage []byte var opcode MessageType var ok, fin, compress bool @@ -367,7 +372,7 @@ func (c *Conn) Parse(data []byte) error { bl := len(body) if c.dataFrameHandler != nil { if bl > 0 { - frame = c.Engine.BodyAllocator.Malloc(bl) + frame = allocator.Malloc(bl) copy(frame, body) } if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) { @@ -379,16 +384,20 @@ func (c *Conn) Parse(data []byte) error { if c.messageHandler != nil { if bl > 0 { if c.message == nil { - c.message = c.Engine.BodyAllocator.Malloc(len(body)) + c.message = allocator.Malloc(len(body)) copy(c.message, body) } else { - c.message = c.Engine.BodyAllocator.Append(c.message, body...) + c.message = allocator.Append(c.message, body...) } } + if fin { + message = c.message + c.message = nil + } } case PingMessage, PongMessage, CloseMessage: if len(body) > 0 { - protocolMessage = c.Engine.BodyAllocator.Malloc(len(body)) + protocolMessage = allocator.Malloc(len(body)) copy(protocolMessage, body) } default: @@ -413,23 +422,23 @@ func (c *Conn) Parse(data []byte) error { var b []byte var rc io.ReadCloser if c.Engine.WebsocketDecompressor != nil { - rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail))) + rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail))) } else { - rc = decompressReader(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail))) + rc = decompressReader(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail))) } - b, err = c.readAll(rc, len(c.message)*2) - c.Engine.BodyAllocator.Free(c.message) - c.message = b + b, err = c.readAll(rc, len(message)*2) + allocator.Free(message) + message = b rc.Close() if err != nil { return err } } - c.handleMessage(c.msgType, c.message) + c.handleMessage(c.msgType, message) } c.compress = false c.expectingFragments = false - c.message = nil + message = nil c.msgType = 0 } else { c.expectingFragments = true