[BUG] [New Executor] Drain channel in limit sink (#2401)
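A simple scan followed by a limit currently breaks on the new executor.
This is because the sink stage runner stops reading as soon as the task
scheduler has completed execution and never drains the receiving channel,
so any partitions still buffered in the channel are not forwarded to the
output.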

Example failure:
```
import daft
import tempfile
import daft.context

with tempfile.TemporaryDirectory() as td:
    # Write to Parquet with AQE and the native executor disabled, because the native executor does not support writes yet.
    daft.context.set_execution_config(enable_aqe=False, enable_native_executor=False)
    df = daft.from_pydict({"a": [1]})
    df.write_parquet(td)

    # Read the Parquet data back with AQE and the native executor enabled.
    daft.context.set_execution_config(enable_aqe=True, enable_native_executor=True)
    df = daft.read_parquet(td)
    df.show()
```

Error: 
```
daft.exceptions.DaftCoreException: DaftError::ValueError Need at least 1 MicroPartition to perform concat
```

Todo:
- Add better error handling
colin-ho authored Jun 19, 2024
1 parent a8876f0 commit 4f2ffb0
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/daft-execution/src/stage/sink/limit.rs
@@ -71,7 +71,9 @@ impl<T: PartitionRef, E: Executor<T> + 'static> Sink<T> for LimitSink<T, E> {
         // TODO(Clark): Send remaining rows limit to each new partition task.
         loop {
             tokio::select! {
+                // Task scheduler has completed execution, break the loop.
                 _ = &mut exec_fut => break,
+                // Task scheduler has produced a new output, send partitions received to the output channel.
                 Some(part) = rx.recv() => {
                     if let Ok(ref p) = part {
                         assert!(p.len() == 1);
@@ -85,5 +87,20 @@ impl<T: PartitionRef, E: Executor<T> + 'static> Sink<T> for LimitSink<T, E> {
                 },
             };
         }
+
+        // If the limit has not been met, check to see if there are remaining outputs from the task scheduler and send them to the output channel.
+        if running_num_rows < limit {
+            while let Some(part) = rx.recv().await {
+                if let Ok(ref p) = part {
+                    assert!(p.len() == 1);
+                    running_num_rows += p[0].metadata().num_rows.unwrap_or(0);
+                }
+                output_channel.send(part).await.unwrap();
+                // If we've met the limit, early-terminate execution.
+                if running_num_rows >= limit {
+                    break;
+                }
+            }
+        }
     }
 }
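
The race this change addresses can be modeled without Tokio. The sketch below is a simplified, hypothetical stand-in and not the code in `limit.rs`: it uses a synchronous `std::sync::mpsc` channel, a `scheduler_done` flag in place of `exec_fut` resolving, plain `usize` row counts in place of partitions, and an assumed `drain` switch so the behavior with and without the fix can be compared. The names `run_sink`, `scheduler_done`, and `drain` are illustrative only.

```rust
use std::sync::mpsc::{channel, Receiver};

// Hypothetical model of the limit sink. Each `usize` received stands in for
// the row count of one single-partition output.
fn run_sink(rx: &Receiver<usize>, scheduler_done: bool, limit: usize, drain: bool) -> Vec<usize> {
    let mut out = Vec::new();
    let mut running_num_rows = 0;

    // Main loop: models the case where the "scheduler completed" branch wins
    // even though an output is still buffered in the channel.
    loop {
        if scheduler_done {
            break;
        }
        match rx.try_recv() {
            Ok(rows) => {
                running_num_rows += rows;
                out.push(rows);
                if running_num_rows >= limit {
                    break;
                }
            }
            // Channel is empty or disconnected: nothing more to forward here.
            Err(_) => break,
        }
    }

    // The fix: if the limit has not been met, forward whatever is still
    // buffered, stopping early once the limit is reached.
    if drain && running_num_rows < limit {
        while let Ok(rows) = rx.try_recv() {
            running_num_rows += rows;
            out.push(rows);
            if running_num_rows >= limit {
                break;
            }
        }
    }
    out
}

fn main() {
    // The scheduler produced one 1-row partition and finished before the sink polled.
    let (tx, rx) = channel();
    tx.send(1usize).unwrap();
    drop(tx);
    println!("without drain: {:?}", run_sink(&rx, true, 8, false)); // prints "without drain: []"

    let (tx, rx) = channel();
    tx.send(1usize).unwrap();
    drop(tx);
    println!("with drain: {:?}", run_sink(&rx, true, 8, true)); // prints "with drain: [1]"
}
```

In this model, the run without the drain step forwards nothing, which is consistent with a downstream concat receiving zero partitions as in the error above. The drain step still honors the limit, so it stops reading once enough rows have been forwarded.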
