Skip to content

Commit

Permalink
[FEAT] Add sample function for Dataframe (#1770)
Browse files Browse the repository at this point in the history
Closes #1759 

Added new sample function based on
https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.DataFrame.sample.html

Changes:
- Modified existing sampling logic in table and micropartition to
instead accept fraction, with_replacement, and seed.
- Added logical + physical ops for sample
- Added end to end tests
  • Loading branch information
colin-ho authored Jan 11, 2024
1 parent cb9e134 commit c4bd1b3
Show file tree
Hide file tree
Showing 25 changed files with 489 additions and 36 deletions.
9 changes: 7 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,8 @@ class PyTable:
def join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ...
def explode(self, to_explode: list[PyExpr]) -> PyTable: ...
def head(self, num: int) -> PyTable: ...
def sample(self, num: int) -> PyTable: ...
def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyTable: ...
def sample_by_size(self, size: int, with_replacement: bool, seed: int | None) -> PyTable: ...
def quantiles(self, num: int) -> PyTable: ...
def partition_by_hash(self, exprs: list[PyExpr], num_partitions: int) -> list[PyTable]: ...
def partition_by_random(self, num_partitions: int, seed: int) -> list[PyTable]: ...
Expand Down Expand Up @@ -1002,7 +1003,8 @@ class PyMicroPartition:
def join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ...
def explode(self, to_explode: list[PyExpr]) -> PyMicroPartition: ...
def head(self, num: int) -> PyMicroPartition: ...
def sample(self, num: int) -> PyMicroPartition: ...
def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyMicroPartition: ...
def sample_by_size(self, size: int, with_replacement: bool, seed: int | None) -> PyMicroPartition: ...
def quantiles(self, num: int) -> PyMicroPartition: ...
def partition_by_hash(self, exprs: list[PyExpr], num_partitions: int) -> list[PyMicroPartition]: ...
def partition_by_random(self, num_partitions: int, seed: int) -> list[PyMicroPartition]: ...
Expand Down Expand Up @@ -1101,6 +1103,7 @@ class LogicalPlanBuilder:
) -> LogicalPlanBuilder: ...
def coalesce(self, num_partitions: int) -> LogicalPlanBuilder: ...
def distinct(self) -> LogicalPlanBuilder: ...
def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> LogicalPlanBuilder: ...
def aggregate(self, agg_exprs: list[PyExpr], groupby_exprs: list[PyExpr]) -> LogicalPlanBuilder: ...
def join(
self, right: LogicalPlanBuilder, left_on: list[PyExpr], right_on: list[PyExpr], join_type: JoinType
Expand Down Expand Up @@ -1132,6 +1135,8 @@ class PyDaftExecutionConfig:
def merge_scan_tasks_max_size_bytes(self): ...
@property
def broadcast_join_size_bytes_threshold(self): ...
@property
def sample_size_for_sort(self): ...

class PyDaftPlanningConfig:
def with_config_values(
Expand Down
21 changes: 21 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,27 @@ def distinct(self) -> "DataFrame":
builder = self._builder.distinct()
return DataFrame(builder)

@DataframePublicAPI
def sample(self, fraction: float, with_replacement: bool = False, seed: Optional[int] = None) -> "DataFrame":
"""Samples a fraction of rows from the DataFrame
Example:
>>> sampled_df = df.sample(0.5)
Args:
fraction (float): fraction of rows to sample.
with_replacement (bool, optional): whether to sample with replacement. Defaults to False.
seed (Optional[int], optional): random seed. Defaults to None.
Returns:
DataFrame: DataFrame with a fraction of rows.
"""
if fraction < 0.0 or fraction > 1.0:
raise ValueError(f"fraction should be between 0.0 and 1.0, but got {fraction}")

builder = self._builder.sample(fraction, with_replacement, seed)
return DataFrame(builder)

@DataframePublicAPI
def exclude(self, *names: str) -> "DataFrame":
"""Drops columns from the current DataFrame by name
Expand Down
20 changes: 13 additions & 7 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,19 +580,25 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])

@dataclass(frozen=True)
class Sample(SingleOutputInstruction):
sort_by: ExpressionsProjection
num_samples: int = 20
fraction: float | None = None
size: int | None = None
with_replacement: bool = False
seed: int | None = None
sort_by: ExpressionsProjection | None = None

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
return self._sample(inputs)

def _sample(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
[input] = inputs
result = (
input.sample(self.num_samples)
.eval_expression_list(self.sort_by)
.filter(ExpressionsProjection([~col(e.name()).is_null() for e in self.sort_by]))
)
if self.sort_by:
result = (
input.sample(self.fraction, self.size, self.with_replacement, self.seed)
.eval_expression_list(self.sort_by)
.filter(ExpressionsProjection([~col(e.name()).is_null() for e in self.sort_by]))
)
else:
result = input.sample(self.fraction, self.size, self.with_replacement, self.seed)
return [result]

def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
Expand Down
4 changes: 3 additions & 1 deletion daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import deque
from typing import Generator, Iterator, TypeVar, Union

from daft.context import get_context
from daft.daft import (
FileFormat,
FileFormatConfig,
Expand Down Expand Up @@ -766,6 +767,7 @@ def sort(
sample_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque()
stage_id_sampling = next(stage_id_counter)

sample_size = get_context().daft_execution_config.sample_size_for_sort
for source in source_materializations:
while not source.done():
logger.debug("sort blocked on completion of source: %s", source)
Expand All @@ -777,7 +779,7 @@ def sort(
partial_metadatas=None,
)
.add_instruction(
instruction=execution_step.Sample(sort_by=sort_by),
instruction=execution_step.Sample(size=sample_size, sort_by=sort_by),
)
.finalize_partition_task_single_output(stage_id=stage_id_sampling)
)
Expand Down
10 changes: 10 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ def local_aggregate(
)


def sample(
input: physical_plan.InProgressPhysicalPlan[PartitionT], fraction: float, with_replacement: bool, seed: int | None
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
return physical_plan.pipeline_instruction(
child_plan=input,
pipeable_instruction=execution_step.Sample(fraction=fraction, with_replacement=with_replacement, seed=seed),
resource_request=ResourceRequest(),
)


def sort(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
sort_by: list[PyExpr],
Expand Down
4 changes: 4 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def distinct(self) -> LogicalPlanBuilder:
builder = self._builder.distinct()
return LogicalPlanBuilder(builder)

def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> LogicalPlanBuilder:
builder = self._builder.sample(fraction, with_replacement, seed)
return LogicalPlanBuilder(builder)

def sort(self, sort_by: list[Expression], descending: list[bool] | bool = False) -> LogicalPlanBuilder:
sort_by_pyexprs = [expr._expr for expr in sort_by]
if not isinstance(descending, list):
Expand Down
21 changes: 19 additions & 2 deletions daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,25 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return MicroPartition._from_pymicropartition(self._micropartition.sort(pyexprs, descending))

def sample(self, num: int) -> MicroPartition:
return MicroPartition._from_pymicropartition(self._micropartition.sample(num))
def sample(
self,
fraction: float | None = None,
size: int | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> MicroPartition:
if fraction is not None and size is not None:
raise ValueError("Must specify either `fraction` or `size`, but not both")
elif fraction is not None:
return MicroPartition._from_pymicropartition(
self._micropartition.sample_by_fraction(float(fraction), with_replacement, seed)
)
elif size is not None:
return MicroPartition._from_pymicropartition(
self._micropartition.sample_by_size(size, with_replacement, seed)
)
else:
raise ValueError("Must specify either `fraction` or `size`")

def agg(self, to_agg: list[Expression], group_by: ExpressionsProjection | None = None) -> MicroPartition:
to_agg_pyexprs = [e._expr for e in to_agg]
Expand Down
17 changes: 15 additions & 2 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,21 @@ def sort(self, sort_keys: ExpressionsProjection, descending: bool | list[bool] |
raise TypeError(f"Expected a bool, list[bool] or None for `descending` but got {type(descending)}")
return Table._from_pytable(self._table.sort(pyexprs, descending))

def sample(self, num: int) -> Table:
return Table._from_pytable(self._table.sample(num))
def sample(
self,
fraction: float | None = None,
size: int | None = None,
with_replacement: bool = False,
seed: int | None = None,
) -> Table:
if fraction is not None and size is not None:
raise ValueError("Must specify either `fraction` or `size`, but not both")
elif fraction is not None:
return Table._from_pytable(self._table.sample_by_fraction(fraction, with_replacement, seed))
elif size is not None:
return Table._from_pytable(self._table.sample_by_size(size, with_replacement, seed))
else:
raise ValueError("Must specify either `fraction` or `size`")

def agg(self, to_agg: list[Expression], group_by: ExpressionsProjection | None = None) -> Table:
to_agg_pyexprs = [e._expr for e in to_agg]
Expand Down
2 changes: 2 additions & 0 deletions src/common/daft-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct DaftExecutionConfig {
pub merge_scan_tasks_min_size_bytes: usize,
pub merge_scan_tasks_max_size_bytes: usize,
pub broadcast_join_size_bytes_threshold: usize,
pub sample_size_for_sort: usize,
}

impl Default for DaftExecutionConfig {
Expand All @@ -33,6 +34,7 @@ impl Default for DaftExecutionConfig {
merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64MB
merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512MB
broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB
sample_size_for_sort: 20,
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ impl PyDaftExecutionConfig {
Ok(self.config.broadcast_join_size_bytes_threshold)
}

#[getter]
fn get_sample_size_for_sort(&self) -> PyResult<usize> {
Ok(self.config.sample_size_for_sort)
}

fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec<u8>,))> {
let bin_data = bincode::serialize(self.config.as_ref())
.expect("DaftExecutionConfig should be serializable to bytes");
Expand Down
41 changes: 37 additions & 4 deletions src/daft-micropartition/src/ops/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@ impl MicroPartition {
}
}

pub fn sample(&self, num: usize) -> DaftResult<Self> {
let io_stats = IOStatsContext::new(format!("MicroPartition::sample({num})"));
pub fn sample_by_fraction(
&self,
fraction: f64,
with_replacement: bool,
seed: Option<u64>,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new(format!("MicroPartition::sample({fraction})"));

if num == 0 {
if fraction == 0.0 {
return Ok(Self::empty(Some(self.schema.clone())));
}

Expand All @@ -51,7 +56,35 @@ impl MicroPartition {
match tables.as_slice() {
[] => Ok(Self::empty(Some(self.schema.clone()))),
[single] => {
let taken = single.sample(num)?;
let taken = single.sample_by_fraction(fraction, with_replacement, seed)?;
Ok(Self::new_loaded(
self.schema.clone(),
Arc::new(vec![taken]),
self.statistics.clone(),
))
}
_ => unreachable!(),
}
}

pub fn sample_by_size(
&self,
size: usize,
with_replacement: bool,
seed: Option<u64>,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new(format!("MicroPartition::sample({size})"));

if size == 0 {
return Ok(Self::empty(Some(self.schema.clone())));
}

let tables = self.concat_or_get(io_stats)?;

match tables.as_slice() {
[] => Ok(Self::empty(Some(self.schema.clone()))),
[single] => {
let taken = single.sample(size, with_replacement, seed)?;
Ok(Self::new_loaded(
self.schema.clone(),
Arc::new(vec![taken]),
Expand Down
42 changes: 38 additions & 4 deletions src/daft-micropartition/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,48 @@ impl PyMicroPartition {
})
}

pub fn sample(&self, py: Python, num: i64) -> PyResult<Self> {
pub fn sample_by_fraction(
&self,
py: Python,
fraction: f64,
with_replacement: bool,
seed: Option<u64>,
) -> PyResult<Self> {
py.allow_threads(|| {
if num < 0 {
if fraction < 0.0 {
return Err(PyValueError::new_err(format!(
"Can not sample table with negative fraction: {fraction}"
)));
}
if fraction > 1.0 {
return Err(PyValueError::new_err(format!(
"Can not sample table with negative number: {num}"
"Can not sample table with fraction greater than 1.0: {fraction}"
)));
}
Ok(self.inner.sample(num as usize)?.into())
Ok(self
.inner
.sample_by_fraction(fraction, with_replacement, seed)?
.into())
})
}

pub fn sample_by_size(
&self,
py: Python,
size: i64,
with_replacement: bool,
seed: Option<u64>,
) -> PyResult<Self> {
py.allow_threads(|| {
if size < 0 {
return Err(PyValueError::new_err(format!(
"Can not sample table with negative size: {size}"
)));
}
Ok(self
.inner
.sample_by_size(size as usize, with_replacement, seed)?
.into())
})
}

Expand Down
23 changes: 23 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ impl LogicalPlanBuilder {
Ok(logical_plan.into())
}

pub fn sample(
&self,
fraction: f64,
with_replacement: bool,
seed: Option<u64>,
) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
logical_ops::Sample::new(self.plan.clone(), fraction, with_replacement, seed).into();
Ok(logical_plan.into())
}

pub fn aggregate(&self, agg_exprs: Vec<Expr>, groupby_exprs: Vec<Expr>) -> DaftResult<Self> {
let agg_exprs = agg_exprs
.iter()
Expand Down Expand Up @@ -379,6 +390,18 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.distinct()?.into())
}

pub fn sample(
&self,
fraction: f64,
with_replacement: bool,
seed: Option<u64>,
) -> PyResult<Self> {
Ok(self
.builder
.sample(fraction, with_replacement, seed)?
.into())
}

pub fn aggregate(&self, agg_exprs: Vec<PyExpr>, groupby_exprs: Vec<PyExpr>) -> PyResult<Self> {
let agg_exprs = agg_exprs
.iter()
Expand Down
2 changes: 2 additions & 0 deletions src/daft-plan/src/logical_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod join;
mod limit;
mod project;
mod repartition;
mod sample;
mod sink;
mod sort;
mod source;
Expand All @@ -20,6 +21,7 @@ pub use join::Join;
pub use limit::Limit;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Loading

0 comments on commit c4bd1b3

Please sign in to comment.