From e7008773c58dfef3bb18fa2633823a8fa8008eaf Mon Sep 17 00:00:00 2001 From: Maksym Ostroverkhov Date: Fri, 28 Jun 2024 16:22:21 +0300 Subject: [PATCH] Outbound text frames support (#6) * WebSocketFrameFactory: add text frames support. * WebSocketFrameFactory.Encoder: add text frames support. * WebSocketFrameFactory.BulkEncoder: add text frames support. --- .../http/websocketx/WebSocketCodecTest.java | 530 +++++++++++++++++- .../websocketx/MaskingWebSocketEncoder.java | 67 ++- .../NonMaskingWebSocketEncoder.java | 62 +- .../websocketx/WebSocketFrameFactory.java | 31 + 4 files changed, 668 insertions(+), 22 deletions(-) diff --git a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java index e4a4dd7..d725aac 100644 --- a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java +++ b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java @@ -84,7 +84,7 @@ void tearDown() { @ParameterizedTest void binaryFramesEncoder(boolean mask) throws Exception { int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE; - Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false); + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); BinaryFramesEncoderClientHandler clientHandler = new BinaryFramesEncoderClientHandler(maxFrameSize); Channel client = @@ -98,12 +98,31 @@ void binaryFramesEncoder(boolean mask) throws Exception { client.close(); } + @Timeout(300) + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void textFramesEncoder(boolean mask) throws Exception { + int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE; + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); + TextFramesEncoderClientHandler clientHandler = + new TextFramesEncoderClientHandler(maxFrameSize, 'a'); + Channel client = + webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler); + + WebSocketFrameFactory.Encoder encoder = clientHandler.onHandshakeCompleted().join(); + Assertions.assertThat(encoder).isNotNull(); + + CompletableFuture onComplete = clientHandler.startFramesExchange(); + onComplete.join(); + client.close(); + } + @Timeout(300) @ValueSource(booleans = {true, false}) @ParameterizedTest void binaryFramesBulkEncoder(boolean mask) throws Exception { int maxFrameSize = 1000; - Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false); + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); BinaryFramesEncoderClientBulkHandler clientHandler = new BinaryFramesEncoderClientBulkHandler(maxFrameSize); Channel client = @@ -117,6 +136,44 @@ void binaryFramesBulkEncoder(boolean mask) throws Exception { client.close(); } + @Timeout(300) + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void textFramesBulkEncoder(boolean mask) throws Exception { + int maxFrameSize = 1000; + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); + TextFramesEncoderClientBulkHandler clientHandler = + new TextFramesEncoderClientBulkHandler(maxFrameSize, 'a'); + Channel client = + webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler); + + WebSocketFrameFactory.BulkEncoder encoder = clientHandler.onHandshakeCompleted().join(); + Assertions.assertThat(encoder).isNotNull(); + + CompletableFuture onComplete = clientHandler.startFramesExchange(); + onComplete.join(); + client.close(); + } + + @Timeout(300) + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void textFramesFactory(boolean mask) throws Exception { + int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE; + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); + TextFramesFactoryClientHandler clientHandler = + new TextFramesFactoryClientHandler(maxFrameSize, 'a'); + Channel client = + webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler); + + WebSocketFrameFactory frameFactory = clientHandler.onHandshakeCompleted().join(); + Assertions.assertThat(frameFactory).isNotNull(); + + CompletableFuture onComplete = clientHandler.startFramesExchange(); + onComplete.join(); + client.close(); + } + @Timeout(300) @MethodSource("maskingArgs") @ParameterizedTest @@ -124,7 +181,7 @@ void allSizeBinaryFramesDefaultDecoder( boolean mask, Class webSocketFrameFactoryType, Class webSocketDecoderType) throws Exception { int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE; - Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false); + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); BinaryFramesTestClientHandler clientHandler = new BinaryFramesTestClientHandler(maxFrameSize); Channel client = webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler); @@ -142,7 +199,7 @@ void allSizeBinaryFramesDefaultDecoder( @Test void binaryFramesSmallDecoder() throws Exception { int maxFrameSize = SMALL_CODEC_MAX_FRAME_SIZE; - Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), false, false); + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), false, false); BinaryFramesTestClientHandler clientHandler = new BinaryFramesTestClientHandler(maxFrameSize); Channel client = webSocketCallbacksClient(s.localAddress(), false, false, maxFrameSize, clientHandler); @@ -450,7 +507,7 @@ protected void initChannel(SocketChannel ch) { WebSocketDecoderConfig.newBuilder() .expectMaskedFrames(expectMaskedFrames) .allowMaskMismatch(allowMaskMismatch) - .withUTF8Validator(false) + .withUTF8Validator(true) .allowExtensions(false) .maxFramePayloadLength(65535) .build(); @@ -610,6 +667,159 @@ private void sendFrames(ChannelHandlerContext c, int toSend) { } } + static class TextFramesEncoderClientBulkHandler + implements WebSocketCallbacksHandler, WebSocketFrameListener { + private final CompletableFuture onHandshakeComplete = + new CompletableFuture<>(); + private final CompletableFuture onFrameExchangeComplete = new CompletableFuture<>(); + private final int framesCount; + private final char expectedAsciiChar; + private WebSocketFrameFactory.BulkEncoder textFrameEncoder; + private int receivedFrames; + private int sentFrames; + private ByteBuf outBuffer; + private volatile ChannelHandlerContext ctx; + + TextFramesEncoderClientBulkHandler(int maxFrameSize, char expectedAsciiChar) { + this.framesCount = maxFrameSize; + this.expectedAsciiChar = expectedAsciiChar; + } + + @Override + public WebSocketFrameListener exchange( + ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) { + this.textFrameEncoder = webSocketFrameFactory.bulkEncoder(); + return this; + } + + @Override + public void onChannelRead( + ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) { + if (!finalFragment) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-final frame: " + finalFragment)); + payload.release(); + return; + } + if (rsv != 0) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received frame with non-zero rsv: " + rsv)); + payload.release(); + return; + } + if (opcode != WebSocketProtocol.OPCODE_TEXT) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-text frame: " + Long.toHexString(opcode))); + payload.release(); + return; + } + + int readableBytes = payload.readableBytes(); + + int expectedSize = receivedFrames; + if (expectedSize != readableBytes) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame of unexpected size: " + + expectedSize + + ", actual: " + + readableBytes)); + payload.release(); + return; + } + + for (int i = 0; i < readableBytes; i++) { + char ch = (char) payload.readByte(); + if (ch != expectedAsciiChar) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame with unexpected content: " + + ch + + ", expected: " + + expectedAsciiChar)); + payload.release(); + return; + } + } + payload.release(); + if (++receivedFrames == framesCount) { + onFrameExchangeComplete.complete(null); + } + } + + @Override + public void onOpen(ChannelHandlerContext ctx) { + this.ctx = ctx; + int bufferSize = 4 * framesCount; + this.outBuffer = ctx.alloc().buffer(bufferSize, bufferSize); + onHandshakeComplete.complete(textFrameEncoder); + } + + @Override + public void onClose(ChannelHandlerContext ctx) { + ByteBuf out = outBuffer; + if (out != null) { + outBuffer = null; + out.release(); + } + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(new ClosedChannelException()); + } + } + + @Override + public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(cause); + } + } + + CompletableFuture onHandshakeCompleted() { + return onHandshakeComplete; + } + + CompletableFuture startFramesExchange() { + ChannelHandlerContext c = ctx; + c.executor().execute(() -> sendFrames(c, framesCount - sentFrames)); + return onFrameExchangeComplete; + } + + private void sendFrames(ChannelHandlerContext c, int toSend) { + WebSocketFrameFactory.BulkEncoder frameEncoder = textFrameEncoder; + for (int frameIdx = 0; frameIdx < toSend; frameIdx++) { + if (!c.channel().isOpen()) { + return; + } + int payloadSize = sentFrames; + int frameSize = frameEncoder.sizeofTextFrame(payloadSize); + ByteBuf out = outBuffer; + if (frameSize > out.capacity() - out.writerIndex()) { + int readableBytes = out.readableBytes(); + int bufferSize = 4 * framesCount; + outBuffer = c.alloc().buffer(bufferSize, bufferSize); + if (c.channel().bytesBeforeUnwritable() < readableBytes) { + c.writeAndFlush(out, c.voidPromise()); + } else { + c.write(out, c.voidPromise()); + } + out = outBuffer; + } + int mask = frameEncoder.encodeTextFramePrefix(out, payloadSize); + for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) { + out.writeByte(expectedAsciiChar); + } + frameEncoder.maskTextFrame(out, mask, payloadSize); + sentFrames++; + } + ByteBuf out = outBuffer; + if (out.readableBytes() > 0) { + c.writeAndFlush(out, c.voidPromise()); + } else { + c.flush(); + } + } + } + static class BinaryFramesEncoderClientHandler implements WebSocketCallbacksHandler, WebSocketFrameListener { private final CompletableFuture onHandshakeComplete = @@ -759,6 +969,314 @@ private void sendFrames(ChannelHandlerContext c, int toSend) { } } + static class TextFramesEncoderClientHandler + implements WebSocketCallbacksHandler, WebSocketFrameListener { + private final CompletableFuture onHandshakeComplete = + new CompletableFuture<>(); + private final CompletableFuture onFrameExchangeComplete = new CompletableFuture<>(); + private WebSocketFrameFactory.Encoder textFrameEncoder; + private final int framesCount; + private final char expectedAsciiChar; + private int receivedFrames; + private int sentFrames; + private volatile ChannelHandlerContext ctx; + + TextFramesEncoderClientHandler(int maxFrameSize, char expectedAsciiChar) { + this.framesCount = maxFrameSize; + this.expectedAsciiChar = expectedAsciiChar; + } + + @Override + public WebSocketFrameListener exchange( + ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) { + this.textFrameEncoder = webSocketFrameFactory.encoder(); + return this; + } + + @Override + public void onChannelRead( + ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) { + if (!finalFragment) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-final frame: " + finalFragment)); + payload.release(); + return; + } + if (rsv != 0) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received frame with non-zero rsv: " + rsv)); + payload.release(); + return; + } + if (opcode != WebSocketProtocol.OPCODE_TEXT) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-text frame: " + Long.toHexString(opcode))); + payload.release(); + return; + } + + int readableBytes = payload.readableBytes(); + + int expectedSize = receivedFrames; + if (expectedSize != readableBytes) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame of unexpected size: " + + expectedSize + + ", actual: " + + readableBytes)); + payload.release(); + return; + } + + for (int i = 0; i < readableBytes; i++) { + char ch = (char) payload.readByte(); + if (ch != expectedAsciiChar) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame with unexpected content: " + + ch + + ", expected: " + + expectedAsciiChar)); + payload.release(); + return; + } + } + payload.release(); + if (++receivedFrames == framesCount) { + onFrameExchangeComplete.complete(null); + } + } + + @Override + public void onChannelWritabilityChanged(ChannelHandlerContext ctx) { + boolean writable = ctx.channel().isWritable(); + if (sentFrames > 0 && writable) { + int toSend = framesCount - sentFrames; + if (toSend > 0) { + sendFrames(ctx, toSend); + } + } + } + + @Override + public void onOpen(ChannelHandlerContext ctx) { + this.ctx = ctx; + onHandshakeComplete.complete(textFrameEncoder); + } + + @Override + public void onClose(ChannelHandlerContext ctx) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(new ClosedChannelException()); + } + } + + @Override + public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(cause); + } + } + + CompletableFuture onHandshakeCompleted() { + return onHandshakeComplete; + } + + CompletableFuture startFramesExchange() { + ChannelHandlerContext c = ctx; + c.executor().execute(() -> sendFrames(c, framesCount - sentFrames)); + return onFrameExchangeComplete; + } + + private void sendFrames(ChannelHandlerContext c, int toSend) { + Channel ch = c.channel(); + WebSocketFrameFactory.Encoder frameEncoder = textFrameEncoder; + boolean pendingFlush = false; + ByteBufAllocator allocator = c.alloc(); + for (int frameIdx = 0; frameIdx < toSend; frameIdx++) { + if (!c.channel().isOpen()) { + return; + } + int payloadSize = sentFrames; + int frameSize = frameEncoder.sizeofTextFrame(payloadSize); + ByteBuf textFrame = allocator.buffer(frameSize); + textFrame.writerIndex(frameSize - payloadSize); + for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) { + textFrame.writeByte(expectedAsciiChar); + } + ByteBuf maskedTextFrame = frameEncoder.encodeTextFrame(textFrame); + sentFrames++; + if (ch.bytesBeforeUnwritable() < textFrame.capacity()) { + c.writeAndFlush(maskedTextFrame, c.voidPromise()); + pendingFlush = false; + if (!ch.isWritable()) { + return; + } + } else { + c.write(maskedTextFrame, c.voidPromise()); + pendingFlush = true; + } + } + if (pendingFlush) { + c.flush(); + } + } + } + + static class TextFramesFactoryClientHandler + implements WebSocketCallbacksHandler, WebSocketFrameListener { + private final CompletableFuture onHandshakeComplete = + new CompletableFuture<>(); + private final CompletableFuture onFrameExchangeComplete = new CompletableFuture<>(); + private WebSocketFrameFactory frameFactory; + private final int framesCount; + private final char expectedAsciiChar; + private int receivedFrames; + private int sentFrames; + private volatile ChannelHandlerContext ctx; + + TextFramesFactoryClientHandler(int maxFrameSize, char expectedAsciiChar) { + this.framesCount = maxFrameSize; + this.expectedAsciiChar = expectedAsciiChar; + } + + @Override + public WebSocketFrameListener exchange( + ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) { + this.frameFactory = webSocketFrameFactory; + return this; + } + + @Override + public void onChannelRead( + ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) { + if (!finalFragment) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-final frame: " + finalFragment)); + payload.release(); + return; + } + if (rsv != 0) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received frame with non-zero rsv: " + rsv)); + payload.release(); + return; + } + if (opcode != WebSocketProtocol.OPCODE_TEXT) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-text frame: " + Long.toHexString(opcode))); + payload.release(); + return; + } + + int readableBytes = payload.readableBytes(); + + int expectedSize = receivedFrames; + if (expectedSize != readableBytes) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame of unexpected size: " + + expectedSize + + ", actual: " + + readableBytes)); + payload.release(); + return; + } + + for (int i = 0; i < readableBytes; i++) { + char ch = (char) payload.readByte(); + if (ch != expectedAsciiChar) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame with unexpected content: " + + ch + + ", expected: " + + expectedAsciiChar)); + payload.release(); + return; + } + } + payload.release(); + if (++receivedFrames == framesCount) { + onFrameExchangeComplete.complete(null); + } + } + + @Override + public void onChannelWritabilityChanged(ChannelHandlerContext ctx) { + boolean writable = ctx.channel().isWritable(); + if (sentFrames > 0 && writable) { + int toSend = framesCount - sentFrames; + if (toSend > 0) { + sendFrames(ctx, toSend); + } + } + } + + @Override + public void onOpen(ChannelHandlerContext ctx) { + this.ctx = ctx; + onHandshakeComplete.complete(frameFactory); + } + + @Override + public void onClose(ChannelHandlerContext ctx) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(new ClosedChannelException()); + } + } + + @Override + public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(cause); + } + } + + CompletableFuture onHandshakeCompleted() { + return onHandshakeComplete; + } + + CompletableFuture startFramesExchange() { + ChannelHandlerContext c = ctx; + c.executor().execute(() -> sendFrames(c, framesCount - sentFrames)); + return onFrameExchangeComplete; + } + + private void sendFrames(ChannelHandlerContext c, int toSend) { + Channel ch = c.channel(); + WebSocketFrameFactory factory = frameFactory; + boolean pendingFlush = false; + ByteBufAllocator allocator = c.alloc(); + for (int frameIdx = 0; frameIdx < toSend; frameIdx++) { + if (!c.channel().isOpen()) { + return; + } + int payloadSize = sentFrames; + ByteBuf textFrame = factory.createTextFrame(allocator, payloadSize); + for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) { + textFrame.writeByte(expectedAsciiChar); + } + ByteBuf maskedTextFrame = factory.mask(textFrame); + sentFrames++; + if (ch.bytesBeforeUnwritable() < textFrame.capacity()) { + c.writeAndFlush(maskedTextFrame, c.voidPromise()); + pendingFlush = false; + if (!ch.isWritable()) { + return; + } + } else { + c.write(maskedTextFrame, c.voidPromise()); + pendingFlush = true; + } + } + if (pendingFlush) { + c.flush(); + } + } + } + static class BinaryFramesTestClientHandler implements WebSocketCallbacksHandler, WebSocketFrameListener { private final CompletableFuture onHandshakeComplete = @@ -1186,7 +1704,7 @@ private void sendFrames(ChannelHandlerContext c, int toSend) { } } - static class BinaryFramesTestServerHandler extends ChannelInboundHandlerAdapter { + static class WebSocketFramesTestServerHandler extends ChannelInboundHandlerAdapter { boolean ready = true; boolean pendingFlush; diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java index a587daa..dde1868 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java @@ -20,6 +20,7 @@ import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_CLOSE; import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PING; import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PONG; +import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_TEXT; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -55,6 +56,8 @@ static class FrameFactory static final int PREFIX_SIZE_SMALL = 6; static final int BINARY_FRAME_SMALL = OPCODE_BINARY << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7; + static final int TEXT_FRAME_SMALL = + OPCODE_TEXT << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7; static final int CLOSE_FRAME = OPCODE_CLOSE << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7; @@ -65,27 +68,38 @@ static class FrameFactory static final int PREFIX_SIZE_MEDIUM = 8; static final int BINARY_FRAME_MEDIUM = (BINARY_FRAME_SMALL | /*LEN*/ (byte) 126) << 16; + static final int TEXT_FRAME_MEDIUM = (TEXT_FRAME_SMALL | /*LEN*/ (byte) 126) << 16; static final WebSocketFrameFactory INSTANCE = new FrameFactory(); - @Override - public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) { + static ByteBuf createDataFrame( + ByteBufAllocator allocator, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { return allocator .buffer(PREFIX_SIZE_SMALL + payloadSize) - .writeShort(BINARY_FRAME_SMALL | payloadSize) + .writeShort(prefixSmall | payloadSize) .readerIndex(2) .writeInt(mask()); } else if (payloadSize <= 65_535) { return allocator .buffer(PREFIX_SIZE_MEDIUM + payloadSize) - .writeLong((long) (BINARY_FRAME_MEDIUM | payloadSize) << 32 | mask()) + .writeLong((long) (prefixMedium | payloadSize) << 32 | mask()) .readerIndex(4); } else { throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); } } + @Override + public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) { + return createDataFrame(allocator, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public ByteBuf createTextFrame(ByteBufAllocator allocator, int payloadSize) { + return createDataFrame(allocator, payloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + @Override public ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason) { if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) { @@ -155,11 +169,20 @@ public BulkEncoder bulkEncoder() { @Override public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) { + return encodeDataFrame(binaryFrame, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public ByteBuf encodeTextFrame(ByteBuf textFrame) { + return encodeDataFrame(textFrame, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixMedium) { int frameSize = binaryFrame.readableBytes(); int smallPrefixSize = 6; if (frameSize <= 125 + smallPrefixSize) { int payloadSize = frameSize - smallPrefixSize; - binaryFrame.setShort(0, BINARY_FRAME_SMALL | payloadSize); + binaryFrame.setShort(0, prefixSmall | payloadSize); int mask = mask(); binaryFrame.setInt(2, mask); return mask(mask, binaryFrame, smallPrefixSize, binaryFrame.writerIndex()); @@ -169,7 +192,7 @@ public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) { if (frameSize <= 65_535 + mediumPrefixSize) { int payloadSize = frameSize - mediumPrefixSize; int mask = mask(); - binaryFrame.setLong(0, ((BINARY_FRAME_MEDIUM | (long) payloadSize) << 32) | mask); + binaryFrame.setLong(0, ((prefixMedium | (long) payloadSize) << 32) | mask); return mask(mask, binaryFrame, mediumPrefixSize, binaryFrame.writerIndex()); } int payloadSize = frameSize - 12; @@ -178,8 +201,18 @@ public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) { @Override public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { + return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static int encodeDataFramePrefix( + ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { - byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize); + byteBuf.writeShort(prefixSmall | payloadSize); int mask = mask(); byteBuf.writeInt(mask); return mask; @@ -187,7 +220,7 @@ public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { if (payloadSize <= 65_535) { int mask = mask(); - byteBuf.writeLong(((BINARY_FRAME_MEDIUM | (long) payloadSize) << 32) | mask); + byteBuf.writeLong(((prefixMedium | (long) payloadSize) << 32) | mask); return mask; } throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); @@ -195,6 +228,15 @@ public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { @Override public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) { + return maskDataFrame(byteBuf, mask, payloadSize); + } + + @Override + public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + return maskDataFrame(byteBuf, mask, textPayloadSize); + } + + static ByteBuf maskDataFrame(ByteBuf byteBuf, int mask, int payloadSize) { int end = byteBuf.writerIndex(); int start = end - payloadSize; return mask(mask, byteBuf, start, end); @@ -202,6 +244,15 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) { @Override public int sizeofBinaryFrame(int payloadSize) { + return sizeOfDataFrame(payloadSize); + } + + @Override + public int sizeofTextFrame(int textPayloadSize) { + return sizeOfDataFrame(textPayloadSize); + } + + static int sizeOfDataFrame(int payloadSize) { if (payloadSize <= 125) { return payloadSize + 6; } diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java index a120f7e..1734904 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java @@ -20,6 +20,7 @@ import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_CLOSE; import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PING; import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PONG; +import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_TEXT; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -53,6 +54,7 @@ static class FrameFactory WebSocketFrameFactory.BulkEncoder { static final int PREFIX_SIZE_SMALL = 2; static final int BINARY_FRAME_SMALL = OPCODE_BINARY << 8 | /*FIN*/ (byte) 1 << 15; + static final int TEXT_FRAME_SMALL = OPCODE_TEXT << 8 | /*FIN*/ (byte) 1 << 15; static final int CLOSE_FRAME = OPCODE_CLOSE << 8 | /*FIN*/ (byte) 1 << 15; static final int PING_FRAME = OPCODE_PING << 8 | /*FIN*/ (byte) 1 << 15; @@ -60,25 +62,36 @@ static class FrameFactory static final int PREFIX_SIZE_MEDIUM = 4; static final int BINARY_FRAME_MEDIUM = (BINARY_FRAME_SMALL | /*LEN*/ (byte) 126) << 16; + static final int TEXT_FRAME_MEDIUM = (TEXT_FRAME_SMALL | /*LEN*/ (byte) 126) << 16; static final WebSocketFrameFactory INSTANCE = new FrameFactory(); - @Override - public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) { + static ByteBuf createDataFrame( + ByteBufAllocator allocator, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { return allocator .buffer(PREFIX_SIZE_SMALL + payloadSize) - .writeShort(BINARY_FRAME_SMALL | payloadSize); + .writeShort(prefixSmall | payloadSize); } if (payloadSize <= 65_535) { return allocator .buffer(PREFIX_SIZE_MEDIUM + payloadSize) - .writeInt(BINARY_FRAME_MEDIUM | payloadSize); + .writeInt(prefixMedium | payloadSize); } throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); } + @Override + public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) { + return createDataFrame(allocator, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public ByteBuf createTextFrame(ByteBufAllocator allocator, int textDataSize) { + return createDataFrame(allocator, textDataSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + @Override public ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason) { if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) { @@ -136,17 +149,26 @@ public BulkEncoder bulkEncoder() { @Override public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) { + return encodeDataFrame(binaryFrame, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public ByteBuf encodeTextFrame(ByteBuf textFrame) { + return encodeDataFrame(textFrame, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixMedium) { int frameSize = binaryFrame.readableBytes(); int smallPrefixSize = 2; if (frameSize <= 125 + smallPrefixSize) { int payloadSize = frameSize - smallPrefixSize; - return binaryFrame.setShort(0, BINARY_FRAME_SMALL | payloadSize); + return binaryFrame.setShort(0, prefixSmall | payloadSize); } int mediumPrefixSize = 4; if (frameSize <= 65_535 + mediumPrefixSize) { int payloadSize = frameSize - mediumPrefixSize; - return binaryFrame.setInt(0, BINARY_FRAME_MEDIUM | payloadSize); + return binaryFrame.setInt(0, prefixMedium | payloadSize); } int payloadSize = frameSize - 8; throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); @@ -154,10 +176,20 @@ public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) { @Override public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { + return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static int encodeDataFramePrefix( + ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { - byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize); + byteBuf.writeShort(prefixSmall | payloadSize); } else if (payloadSize <= 65_535) { - byteBuf.writeInt(BINARY_FRAME_MEDIUM | payloadSize); + byteBuf.writeInt(prefixMedium | payloadSize); } else { throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); } @@ -169,8 +201,22 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) { return byteBuf; } + @Override + public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + return byteBuf; + } + @Override public int sizeofBinaryFrame(int payloadSize) { + return sizeOfDataFrame(payloadSize); + } + + @Override + public int sizeofTextFrame(int textPayloadSize) { + return sizeOfDataFrame(textPayloadSize); + } + + static int sizeOfDataFrame(int payloadSize) { if (payloadSize <= 125) { return payloadSize + 2; } diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java index bbc8ee8..63d19a7 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java @@ -27,6 +27,11 @@ public interface WebSocketFrameFactory { ByteBuf createBinaryFrame(ByteBufAllocator allocator, int binaryDataSize); + default ByteBuf createTextFrame(ByteBufAllocator allocator, int textDataSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.createTextFrame() not implemented"); + } + ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason); ByteBuf createPingFrame(ByteBufAllocator allocator, int binaryDataSize); @@ -47,6 +52,16 @@ interface Encoder { ByteBuf encodeBinaryFrame(ByteBuf binaryFrame); int sizeofBinaryFrame(int payloadSize); + + default ByteBuf encodeTextFrame(ByteBuf textFrame) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.Encoder.encodeTextFrame() not implemented"); + } + + default int sizeofTextFrame(int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.Encoder.sizeofTextFrame() not implemented"); + } } /** Encodes prefixes of multiple binary websocket frames into provided bytebuffer. */ @@ -58,5 +73,21 @@ interface BulkEncoder { ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize); int sizeofBinaryFrame(int payloadSize); + + /** @return frame mask, or -1 if masking not applicable */ + default int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.encodeTextFramePrefix() not implemented"); + } + + default ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.maskTextFrame() not implemented"); + } + + default int sizeofTextFrame(int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.sizeofTextFrame() not implemented"); + } } }