Skip to content

Commit

Permalink
Fix io close (#12)
Browse files Browse the repository at this point in the history
* Fix io close for Framed
* Fix connection shutdown for h1 dispatcher
* Enable client disconnect for http server by default
* Add connection disconnect timeout to framed service
  • Loading branch information
fafhrd91 authored Apr 7, 2020
1 parent 8a753a7 commit 3b12a77
Show file tree
Hide file tree
Showing 21 changed files with 528 additions and 168 deletions.
6 changes: 6 additions & 0 deletions ntex-codec/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changes

## [0.1.1] - 2020-04-07

* Optimize io operations

* Fix framed close method

## [0.1.0] - 2020-03-31

* Fork crate to ntex namespace
Expand Down
12 changes: 8 additions & 4 deletions ntex-codec/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[package]
name = "ntex-codec"
version = "0.1.0"
version = "0.1.1"
authors = ["Nikolay Kim <[email protected]>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
homepage = "https://ntex.rs"
repository = "https://github.com/ntex-rs/ntex.git"
documentation = "https://docs.rs/ntex-codec/"
categories = ["network-programming", "asynchronous"]
license = "MIT/Apache-2.0"
license = "MIT"
edition = "2018"

[lib]
Expand All @@ -20,6 +20,10 @@ bitflags = "1.2.1"
bytes = "0.5.4"
futures-core = "0.3.4"
futures-sink = "0.3.4"
tokio = { version = "0.2.4", default-features=false }
tokio = { version = "0.2.6", default-features=false }
tokio-util = { version = "0.2.0", default-features=false, features=["codec"] }
log = "0.4"
log = "0.4"

[dev-dependencies]
ntex = "0.1.4"
futures = "0.3.4"
259 changes: 211 additions & 48 deletions ntex-codec/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ const HW: usize = 8 * 1024;

bitflags::bitflags! {
struct Flags: u8 {
const EOF = 0b0001;
const READABLE = 0b0010;
const EOF = 0b0001;
const READABLE = 0b0010;
const DISCONNECTED = 0b0100;
const SHUTDOWN = 0b1000;
}
}

/// A unified `Stream` and `Sink` interface to an underlying I/O object, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
/// `Framed` is heavily optimized for streaming io.
pub struct Framed<T, U> {
io: T,
codec: U,
Expand All @@ -28,8 +31,6 @@ pub struct Framed<T, U> {
write_buf: BytesMut,
}

impl<T, U> Unpin for Framed<T, U> {}

impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite,
Expand Down Expand Up @@ -123,6 +124,18 @@ impl<T, U> Framed<T, U> {
&mut self.io
}

#[inline]
/// Get read buffer.
pub fn read_buf_mut(&mut self) -> &mut BytesMut {
&mut self.read_buf
}

#[inline]
/// Get write buffer.
pub fn write_buf_mut(&mut self) -> &mut BytesMut {
&mut self.write_buf
}

#[inline]
/// Check if write buffer is empty.
pub fn is_write_buf_empty(&self) -> bool {
Expand All @@ -135,6 +148,12 @@ impl<T, U> Framed<T, U> {
self.write_buf.len() >= HW
}

#[inline]
/// Check if framed object is closed
pub fn is_closed(&self) -> bool {
self.flags.contains(Flags::DISCONNECTED)
}

#[inline]
/// Consume the `Frame`, returning `Frame` with different codec.
pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
Expand Down Expand Up @@ -227,34 +246,87 @@ where
pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
log::trace!("flushing framed transport");

while !self.write_buf.is_empty() {
log::trace!("writing; remaining={}", self.write_buf.len());
let len = self.write_buf.len();
if len == 0 {
return Poll::Ready(Ok(()));
}

let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.write_buf))?;
if n == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)
.into()));
let mut written = 0;
while written < len {
match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("Disconnected during flush, written {}", written);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)
.into()));
} else {
written += n
}
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(e.into()));
}
}

// remove written data
self.write_buf.advance(n);
}

// Try flushing the underlying IO
ready!(Pin::new(&mut self.io).poll_flush(cx))?;

log::trace!("framed transport flushed");
Poll::Ready(Ok(()))
// remove written data
if written == len {
// flushed same amount as in buffer, we dont need to reallocate
unsafe { self.write_buf.set_len(0) }
} else {
self.write_buf.advance(written);
}
if self.write_buf.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}

impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
/// Flush write buffer and shutdown underlying I/O stream.
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
ready!(Pin::new(&mut self.io).poll_shutdown(cx))?;
///
/// Close method shutdown write side of a io object and
/// then reads until disconnect or error, high level code must use
/// timeout for close operation.
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if !self.flags.contains(Flags::DISCONNECTED) {
// flush write buffer
ready!(Pin::new(&mut self.io).poll_flush(cx))?;

if !self.flags.contains(Flags::SHUTDOWN) {
// shutdown WRITE side
ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| {
self.flags.insert(Flags::DISCONNECTED);
e
})?;
self.flags.insert(Flags::SHUTDOWN);
}

// read until 0 or err
let mut buf = [0u8; 512];
loop {
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut buf)) {
Err(_) | Ok(0) => {
break;
}
_ => (),
}
}
self.flags.insert(Flags::DISCONNECTED);
}
log::trace!("framed transport flushed and closed");
Poll::Ready(Ok(()))
}
Expand All @@ -269,11 +341,9 @@ where
pub fn next_item(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<U::Item, U::Error>>>
where
T: AsyncRead,
U: Decoder,
{
) -> Poll<Option<Result<U::Item, U::Error>>> {
let mut done_read = false;

loop {
// Repeatedly call `decode` or `decode_eof` as long as it is
// "readable". Readable is defined as not having returned `None`. If
Expand Down Expand Up @@ -302,34 +372,53 @@ where
}

self.flags.remove(Flags::READABLE);
if done_read {
return Poll::Pending;
}
}

debug_assert!(!self.flags.contains(Flags::EOF));

// Otherwise, try to read more data and try again. Make sure we've got room
let remaining = self.read_buf.capacity() - self.read_buf.len();
if remaining < LW {
self.read_buf.reserve(HW - remaining)
}
let cnt = match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf)
{
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Ok(cnt)) => cnt,
};

if cnt == 0 {
self.flags.insert(Flags::EOF);
// read all data from socket
let mut updated = false;
loop {
// Otherwise, try to read more data and try again. Make sure we've got room
let remaining = self.read_buf.capacity() - self.read_buf.len();
if remaining < LW {
self.read_buf.reserve(HW - remaining)
}
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
Poll::Pending => {
if updated {
done_read = true;
self.flags.insert(Flags::READABLE);
break;
} else {
return Poll::Pending;
}
}
Poll::Ready(Ok(n)) => {
if n == 0 {
self.flags.insert(Flags::EOF | Flags::READABLE);
if updated {
done_read = true;
}
break;
} else {
updated = true;
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
}
}
self.flags.insert(Flags::READABLE);
}
}
}

impl<T, U> Stream for Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder,
U: Decoder + Unpin,
{
type Item = Result<U::Item, U::Error>;

Expand All @@ -344,8 +433,8 @@ where

impl<T, U> Sink<U::Item> for Framed<T, U>
where
T: AsyncWrite + Unpin,
U: Encoder,
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Unpin,
U::Error: From<io::Error>,
{
type Error = U::Error;
Expand Down Expand Up @@ -383,7 +472,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.close(cx)
self.close(cx).map_err(|e| e.into())
}
}

Expand Down Expand Up @@ -443,3 +532,77 @@ impl<T, U> FramedParts<T, U> {
}
}
}

#[cfg(test)]
mod tests {
use bytes::Bytes;
use futures::future::lazy;
use futures::Sink;
use ntex::testing::Io;

use super::*;
use crate::BytesCodec;

#[ntex::test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
let mut server = Framed::new(server, BytesCodec);

assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());

let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
assert_eq!(client.read_any(), b"".as_ref());

assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());

assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
}

#[ntex::test]
async fn test_write_pending() {
let (client, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);

assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();

client.remote_buffer_cap(3);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_pending());
assert_eq!(client.read_any(), b"GET".as_ref());

client.remote_buffer_cap(1024);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());

assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
assert!(server.is_closed());
}
}
Loading

0 comments on commit 3b12a77

Please sign in to comment.