Skip to content

Commit

Permalink
[FEAT]: Support intersect all and except distinct/all in DataFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Dec 19, 2024
1 parent ae74c10 commit eeb1282
Show file tree
Hide file tree
Showing 12 changed files with 582 additions and 102 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,7 @@ class LogicalPlanBuilder:
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def except_(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
def table_write(
self,
Expand Down
30 changes: 30 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,36 @@ def intersect(self, other: "DataFrame") -> "DataFrame":
builder = self._builder.intersect(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def intersect_all(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames, including duplicates.
:param other:
:return:
"""
builder = self._builder.intersect_all(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def except_distinct(self, other: "DataFrame") -> "DataFrame":
"""
:param other:
:return:
"""
builder = self._builder.except_distinct(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def except_all(self, other: "DataFrame") -> "DataFrame":
"""
:param other:
:return:
"""
builder = self._builder.except_all(other._builder)
return DataFrame(builder)

def _materialize_results(self) -> None:
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
context = get_context()
Expand Down
12 changes: 12 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, False)
return LogicalPlanBuilder(builder)

def intersect_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, True)
return LogicalPlanBuilder(builder)

def except_distinct(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.except_(other._builder, False)
return LogicalPlanBuilder(builder)

def except_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.except_(other._builder, True)
return LogicalPlanBuilder(builder)

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
builder = self._builder.add_monotonically_increasing_id(column_name)
return LogicalPlanBuilder(builder)
Expand Down
46 changes: 45 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{iter::repeat, sync::Arc};

use arrow2::offset::OffsetsBuffer;
use arrow2::offset::{Offsets, OffsetsBuffer};
use common_error::DaftResult;
use indexmap::{
map::{raw_entry_v1::RawEntryMut, RawEntryApiV1},
Expand Down Expand Up @@ -255,6 +255,31 @@ fn list_sort_helper_fixed_size(
.collect()
}

fn general_list_fill_helper(element: &Series, num_array: &Int64Array) -> DaftResult<Vec<Series>> {
let num_iter = create_iter(num_array, element.len());
let mut result = vec![];
let element_data = element.as_physical()?;
for (row_index, num) in num_iter.enumerate() {
let list_arr = if element.is_valid(row_index) {
let mut list_growable = make_growable(
element.name(),
element.data_type(),
vec![&element_data],
false,
num as usize,
);
for _ in 0..num {
list_growable.extend(0, row_index, 1);
}
list_growable.build()?
} else {
Series::full_null(element.name(), element.data_type(), num as usize)
};
result.push(list_arr);
}
Ok(result)
}

impl ListArray {
pub fn value_counts(&self) -> DaftResult<MapArray> {
struct IndexRef {
Expand Down Expand Up @@ -625,6 +650,25 @@ impl ListArray {
self.validity().cloned(),
))
}

pub fn list_fill(elem: &Series, num_array: &Int64Array) -> DaftResult<Self> {
let generated = general_list_fill_helper(elem, num_array)?;
let generated_refs: Vec<&Series> = generated.iter().collect();
let lengths = generated.iter().map(|arr| arr.len());
let offsets = Offsets::try_from_lengths(lengths)?;
let flat_child = if generated_refs.is_empty() {
// when there's no output, we should create an empty series
Series::empty(elem.name(), elem.data_type())
} else {
Series::concat(&generated_refs)?
};
Ok(Self::new(
elem.field().to_list_field()?,
flat_child,
offsets.into(),
None,
))
}
}

impl FixedSizeListArray {
Expand Down
13 changes: 12 additions & 1 deletion src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use common_error::{DaftError, DaftResult};
use daft_schema::field::Field;

use crate::{
array::ListArray,
datatypes::{DataType, UInt64Array, Utf8Array},
prelude::CountMode,
prelude::{CountMode, Int64Array},
series::{IntoSeries, Series},
};

Expand Down Expand Up @@ -217,4 +218,14 @@ impl Series {
))),
}
}

/// Given a series of data T, repeat each data T with num times to create a list, returns
/// a series of repeated list.
/// # Example
/// ```txt
/// repeat([1, 2, 3], [2, 0, 1]) --> [[1, 1], [], [3]]
/// ```
pub fn list_fill(&self, num: &Int64Array) -> DaftResult<Self> {
ListArray::list_fill(self, num).map(|arr| arr.into_series())
}
}
63 changes: 63 additions & 0 deletions src/daft-functions/src/list/list_fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::{DataType, Field},
prelude::{Schema, Series},
};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListFill {}

#[typetag::serde]
impl ScalarUDF for ListFill {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &'static str {
"fill"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[n, elem] => {
let num_field = n.to_field(schema)?;
let elem_field = elem.to_field(schema)?;
if !num_field.dtype.is_integer() {
return Err(DaftError::TypeError(format!(
"Expected num field to be of numeric type, received: {}",
num_field.dtype
)));
}
elem_field.to_list_field()
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[num, elem] => {
let num = num.cast(&DataType::Int64)?;
let num_array = num.i64()?;
elem.list_fill(num_array)
}
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}
}

#[must_use]
pub fn list_fill(n: ExprRef, elem: ExprRef) -> ExprRef {
ScalarFunction::new(ListFill {}, vec![n, elem]).into()
}
2 changes: 2 additions & 0 deletions src/daft-functions/src/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod count;
mod explode;
mod get;
mod join;
mod list_fill;
mod max;
mod mean;
mod min;
Expand All @@ -17,6 +18,7 @@ pub use count::{list_count as count, ListCount};
pub use explode::{explode, Explode};
pub use get::{list_get as get, ListGet};
pub use join::{list_join as join, ListJoin};
pub use list_fill::list_fill;
pub use max::{list_max as max, ListMax};
pub use mean::{list_mean as mean, ListMean};
pub use min::{list_min as min, ListMin};
Expand Down
15 changes: 14 additions & 1 deletion src/daft-logical-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,17 @@ impl LogicalPlanBuilder {
pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_optimized_join()?;
.to_logical_plan()?;
Ok(self.with_new_plan(logical_plan))
}

pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Except::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_logical_plan()?;
Ok(self.with_new_plan(logical_plan))
}

pub fn union(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)?
Expand Down Expand Up @@ -861,6 +869,11 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.intersect(&other.builder, is_all)?.into())
}

#[pyo3(name = "except_")]
pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
Ok(self.builder.except(&other.builder, is_all)?.into())
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
Ok(self
.builder
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::{Intersect, Union};
pub use set_operations::{Except, Intersect, Union};
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Expand Down
Loading

0 comments on commit eeb1282

Please sign in to comment.