Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Encoder state machine tolerates being wrapped by an AsyncWrite #309

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ proptest-derive = "0.5"
rand = "0.8.5"
tokio = { version = "1.24.2", default-features = false, features = ["io-util", "macros", "rt-multi-thread", "io-std"] }
tokio-util = { version = "0.7", default-features = false, features = ["io"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

[[test]]
name = "brotli"
Expand Down
145 changes: 36 additions & 109 deletions src/tokio/write/generic/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,13 @@ use futures_core::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

/// Lifecycle of an `Encoder`, consulted by `do_poll_write`, `do_poll_flush`
/// and `do_poll_shutdown` to decide which operations are still permitted.
#[derive(Debug)]
enum State {
// Initial state (set by `Encoder::new`); writes and flushes are accepted.
Encoding,
// `do_poll_shutdown` has called `encoder.finish` at least once and it has
// not yet reported completion; writes and flushes are rejected with an error.
Finishing,
// `encoder.finish` reported completion; `do_poll_shutdown` returns ready
// immediately, writes and flushes are still rejected.
Done,
}

pin_project! {
#[derive(Debug)]
pub struct Encoder<W, E> {
#[pin]
writer: BufWriter<W>,
encoder: E,
state: State,
finished: bool
NobodyXu marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand All @@ -35,7 +28,7 @@ impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
Self {
writer: BufWriter::new(writer),
encoder,
state: State::Encoding,
finished: false,
}
}
}
Expand All @@ -62,97 +55,6 @@ impl<W, E> Encoder<W, E> {
}
}

// Internal drivers for the encoder state machine. Each helper repeatedly
// obtains writable space from the buffered writer, lets the encoder fill it,
// and records how many bytes were produced.
impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
/// Encodes `input` into the buffered writer until all of `input` has been
/// consumed.
///
/// Returns `Poll::Pending` (via `ready!`) whenever the writer cannot hand
/// out buffer space yet; `input` keeps track of how much was consumed so
/// far, so the caller can inspect progress even in that case.
///
/// # Errors
///
/// Propagates errors from the writer and the encoder, and returns a
/// "Write after shutdown" error when the state is `Finishing` or `Done`.
fn do_poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
input: &mut PartialBuffer<&[u8]>,
) -> Poll<io::Result<()>> {
let mut this = self.project();

loop {
// NOTE(review): `poll_partial_flush_buf` is defined elsewhere; it
// presumably flushes enough buffered data to expose free space.
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
let mut output = PartialBuffer::new(output);

*this.state = match this.state {
State::Encoding => {
this.encoder.encode(input, &mut output)?;
State::Encoding
}

// Shutdown has already started (or completed): refuse new input.
State::Finishing | State::Done => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"Write after shutdown",
)))
}
};

// Tell the writer how many bytes the encoder placed into its buffer.
let produced = output.written().len();
this.writer.as_mut().produce(produced);

if input.unwritten().is_empty() {
return Poll::Ready(Ok(()));
}
}
}

/// Drives `encoder.flush` until it reports that flushing is complete,
/// handing the produced bytes to the buffered writer.
///
/// This does not flush the underlying writer itself; it only moves the
/// encoder's pending output into the writer's buffer.
///
/// # Errors
///
/// Propagates errors from the writer and the encoder, and returns a
/// "Flush after shutdown" error when the state is `Finishing` or `Done`.
fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut this = self.project();

loop {
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
let mut output = PartialBuffer::new(output);

// `done` is the encoder's own report of whether the flush finished.
let done = match this.state {
State::Encoding => this.encoder.flush(&mut output)?,

// Flushing is rejected once shutdown has begun.
State::Finishing | State::Done => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"Flush after shutdown",
)))
}
};

let produced = output.written().len();
this.writer.as_mut().produce(produced);

if done {
return Poll::Ready(Ok(()));
}
}
}

/// Drives `encoder.finish` until it reports completion, moving the state
/// from `Encoding`/`Finishing` to `Done`.
///
/// Calling this again once the state is `Done` does not call
/// `encoder.finish` again; it still polls the writer for buffer space and
/// then returns ready.
///
/// # Errors
///
/// Propagates errors from the writer and the encoder.
fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut this = self.project();

loop {
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
let mut output = PartialBuffer::new(output);

*this.state = match this.state {
State::Encoding | State::Finishing => {
// `finish` returns true once the encoder has emitted everything.
if this.encoder.finish(&mut output)? {
State::Done
} else {
State::Finishing
}
}

State::Done => State::Done,
};

let produced = output.written().len();
this.writer.as_mut().produce(produced);

if let State::Done = this.state {
return Poll::Ready(Ok(()));
}
}
}
}

impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
fn poll_write(
self: Pin<&mut Self>,
Expand All @@ -163,24 +65,49 @@ impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
return Poll::Ready(Ok(0));
}

let mut input = PartialBuffer::new(buf);
let mut this = self.project();

let mut encodeme = PartialBuffer::new(buf);

match self.do_poll_write(cx, &mut input)? {
Poll::Pending if input.written().is_empty() => Poll::Pending,
_ => Poll::Ready(Ok(input.written().len())),
loop {
let mut space =
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
this.encoder.encode(&mut encodeme, &mut space)?;
let bytes_encoded = space.written().len();
this.writer.as_mut().produce(bytes_encoded);
if encodeme.unwritten().is_empty() {
break;
}
}

Poll::Ready(Ok(encodeme.written().len()))
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().do_poll_flush(cx))?;
ready!(self.project().writer.as_mut().poll_flush(cx))?;
let mut this = self.project();
loop {
let mut space =
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
let flushed = this.encoder.flush(&mut space)?;
let bytes_encoded = space.written().len();
this.writer.as_mut().produce(bytes_encoded);
if flushed {
break;
}
}
Poll::Ready(Ok(()))
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().do_poll_shutdown(cx))?;
ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
Poll::Ready(Ok(()))
let mut this = self.project();
while !*this.finished {
let mut space =
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
*this.finished = this.encoder.finish(&mut space)?;
let bytes_encoded = space.written().len();
this.writer.as_mut().produce(bytes_encoded);
}
this.writer.poll_shutdown(cx)
}
}

Expand Down
153 changes: 153 additions & 0 deletions tests/issues.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#![cfg(all(feature = "tokio", feature = "zstd"))]
NobodyXu marked this conversation as resolved.
Show resolved Hide resolved

use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};

use async_compression::tokio::write::ZstdEncoder;
use tokio::io::{AsyncWrite, AsyncWriteExt as _};
use tracing_subscriber::fmt::format::FmtSpan;

/// This issue covers our state machine being invalid when using adapters
/// like [`tokio_util::codec`].
///
/// After the first [`poll_shutdown`] call,
/// we must expect any number of [`poll_flush`] and [`poll_shutdown`] calls,
/// until [`poll_shutdown`] returns [`Poll::Ready`],
/// according to the documentation on [`AsyncWrite`].
///
/// <https://github.com/Nullus157/async-compression/issues/246>
///
/// [`tokio_util::codec`]: https://docs.rs/tokio-util/latest/tokio_util/codec
/// [`poll_shutdown`]: AsyncWrite::poll_shutdown
/// [`poll_flush`]: AsyncWrite::poll_flush
#[test]
fn issue_246() {
tracing_subscriber::fmt()
.without_time()
.with_ansi(false)
.with_level(false)
.with_test_writer()
.with_target(false)
.with_span_events(FmtSpan::NEW)
.init();
let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default())));
futures::executor::block_on(zstd_encoder.shutdown()).unwrap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should test poll_shutdown by keeping track of the number of calls to the underlying poll_flush?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that feels a bit excessive - we know that poll_shutdown has been called twice because of how DelayedShutdown is implemented (only returns Poll::Ready the second time), and that's what we really want to test.

i.e. I don't think a change like this actually does anything for the test:

-    let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default())));
+    let mut delayed_shutdown = DelayedShutdown::default();
+    let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(&mut delayed_shutdown)));
     futures::executor::block_on(zstd_encoder.shutdown()).unwrap();
+    assert_eq!(delayed_shutdown.num_times_shutdown_called, 1);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd want the poll_flush to be tested, by making sure it is called if and only if poll_shutdown is called.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:) could you help me write this test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit busy now, but in general, I think we want something like this:

struct DummyWriter {
    is_poll_flush_called: bool,
    is_poll_shutdown_called: bool,
}

impl AsyncWrite {
    fn poll_flush(...) {
        assert!(!self.is_poll_shutdown_called);
        self.is_poll_flush_called = true;
        Ready(Ok())
    }

    fn poll_shutdown(...) {
        self.is_poll_shutdown_called = true;
        Ready(Ok())
    }
}

And then after shutdown is called, we check that both are set to true to verify this fix.

}

pin_project_lite::pin_project! {
    /// Thin pass-through adapter that honours the [`AsyncWrite`] protocol.
    ///
    /// It plays the role that combinators such as the `tokio_util::codec`
    /// types would play in real code.
    struct Wrapper<T> {
        // The wrapped writer; pinned so its poll methods can be forwarded to.
        #[pin]
        inner: T,
    }
}

impl<T> Wrapper<T> {
fn new(inner: T) -> Self {
Self { inner }
}
}

impl<T: AsyncWrite> AsyncWrite for Wrapper<T> {
    /// Forwards the write to the wrapped writer unchanged.
    #[tracing::instrument(name = "Wrapper::poll_write", skip_all, ret)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().inner;
        inner.poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped writer unchanged.
    #[tracing::instrument(name = "Wrapper::poll_flush", skip_all, ret)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(cx)
    }

    /// Flushes the wrapped writer, then shuts it down.
    ///
    /// The explicit flush mirrors what the [`AsyncWrite`] docs say:
    /// > Invocation of a shutdown implies an invocation of flush.
    /// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened.
    /// > That is, callers don't need to call flush before calling shutdown.
    /// > They can rely that by calling shutdown any pending buffered data will be written out.
    ///
    /// Every call that gets past the flush stage re-issues the flush first, so
    /// the wrapped writer may see `poll_flush` after its `poll_shutdown` has
    /// already been polled.
    #[tracing::instrument(name = "Wrapper::poll_shutdown", skip_all, ret)]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let mut this = self.project();
        // Bail out with `Pending` until the flush completes, and with the
        // error if it fails.
        if let Err(flush_err) = ready!(this.inner.as_mut().poll_flush(cx)) {
            return Poll::Ready(Err(flush_err));
        }
        this.inner.poll_shutdown(cx)
    }
}

pin_project_lite::pin_project! {
    /// In-memory writer whose first [`AsyncWrite::poll_shutdown`] call yields
    /// [`Poll::Pending`].
    #[derive(Default)]
    struct DelayedShutdown {
        // Every byte handed to `poll_write`, in order.
        contents: Vec<u8>,
        // Bumped on the first `poll_shutdown` call only.
        num_times_shutdown_called: u8,
    }
}

impl AsyncWrite for DelayedShutdown {
    /// Appends `buf` to the in-memory contents and reports it fully written.
    #[tracing::instrument(name = "DelayedShutdown::poll_write", skip_all, ret)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let _ = cx;
        let accepted = buf.len();
        let this = self.project();
        this.contents.extend_from_slice(buf);
        Poll::Ready(Ok(accepted))
    }

    /// Nothing is buffered, so flushing always completes immediately.
    #[tracing::instrument(name = "DelayedShutdown::poll_flush", skip_all, ret)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let _ = cx;
        Poll::Ready(Ok(()))
    }

    /// Returns [`Poll::Pending`] on the first call (after waking the task so
    /// it is polled again) and [`Poll::Ready`] on every later call.
    #[tracing::instrument(name = "DelayedShutdown::poll_shutdown", skip_all, ret)]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let calls = self.project().num_times_shutdown_called;
        if *calls != 0 {
            return Poll::Ready(Ok(()));
        }
        *calls += 1;
        // Request an immediate re-poll; nothing else would wake the task.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

pin_project_lite::pin_project! {
    /// Transparent wrapper whose [`AsyncWrite`] methods are all instrumented
    /// with `tracing`, so every call is traced.
    struct Trace<T> {
        // The writer whose calls are being traced.
        #[pin]
        inner: T,
    }
}

impl<T> Trace<T> {
fn new(inner: T) -> Self {
Self { inner }
}
}

impl<T: AsyncWrite> AsyncWrite for Trace<T> {
    /// Traced pass-through to the inner writer's `poll_write`.
    #[tracing::instrument(name = "Trace::poll_write", skip_all, ret)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().inner;
        inner.poll_write(cx, buf)
    }

    /// Traced pass-through to the inner writer's `poll_flush`.
    #[tracing::instrument(name = "Trace::poll_flush", skip_all, ret)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().inner;
        inner.poll_flush(cx)
    }

    /// Traced pass-through to the inner writer's `poll_shutdown`.
    #[tracing::instrument(name = "Trace::poll_shutdown", skip_all, ret)]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().inner;
        inner.poll_shutdown(cx)
    }
}