Skip to content

Commit

Permalink
[FEAT] connect: df.join
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 21, 2024
1 parent cbe9d3b commit dd5e05c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
use crate::translation::logical_plan::{
aggregate::aggregate, join::join, project::project, range::range,
};

mod aggregate;
mod join;
mod project;
mod range;

Expand All @@ -25,6 +28,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::Join(j) => join(*j).wrap_err("Failed to apply join to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
76 changes: 76 additions & 0 deletions src/daft-connect/src/translation/logical_plan/join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
use spark_connect::join::JoinType;
use tracing::warn;

use crate::translation::to_logical_plan;

pub fn join(join: spark_connect::Join) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::Join {
left,
right,
join_condition,
join_type,
using_columns,
join_data_type,
} = join;

let Some(left) = left else {
bail!("Left side of join is required");
};

let Some(right) = right else {
bail!("Right side of join is required");
};

if let Some(join_condition) = join_condition {
bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}");
}

let join_type = JoinType::try_from(join_type)
.wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?;

let join_type = to_daft_join_type(join_type)?;

let using_columns_exprs: Vec<_> = using_columns
.iter()
.map(|s| daft_dsl::col(s.as_str()))
.collect();

if let Some(join_data_type) = join_data_type {
warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented");
}

let left = to_logical_plan(*left)?;
let right = to_logical_plan(*right)?;

Ok(left.join(
&right,
// join_conditions.clone(), // todo(correctness): is this correct?
// join_conditions, // todo(correctness): is this correct?
using_columns_exprs.clone(),
using_columns_exprs,
join_type,
None,
None,
None,
false, // todo(correctness): we want join keys or not
)?)
}

fn to_daft_join_type(join_type: JoinType) -> eyre::Result<daft_core::join::JoinType> {
match join_type {
JoinType::Unspecified => {
bail!("Join type must be specified; got Unspecified")
}
JoinType::Inner => Ok(daft_core::join::JoinType::Inner),
JoinType::FullOuter => {
bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented
}
JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left), // todo(correctness): is this correct?
JoinType::RightOuter => Ok(daft_core::join::JoinType::Right),
JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti), // todo(correctness): is this correct?
JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented
JoinType::Cross => bail!("Cross joins not yet supported"), // todo(completeness): add support for cross joins if it is not already implemented
}
}
26 changes: 26 additions & 0 deletions tests/connect/test_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_join(spark_session):
# Create two DataFrames with overlapping IDs
df1 = spark_session.range(5)
df2 = spark_session.range(3, 7)

# Perform inner join on 'id' column
joined_df = df1.join(df2, "id", "inner")

# Verify join results using collect()
joined_ids = {row.id for row in joined_df.select("id").collect()}
assert joined_ids == {3, 4}, "Inner join should only contain IDs 3 and 4"

# Test left outer join
left_joined_df = df1.join(df2, "id", "left")
left_joined_ids = {row.id for row in left_joined_df.select("id").collect()}
assert left_joined_ids == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame"

# Test right outer join
right_joined_df = df1.join(df2, "id", "right")
right_joined_ids = {row.id for row in right_joined_df.select("id").collect()}
assert right_joined_ids == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame"

0 comments on commit dd5e05c

Please sign in to comment.