-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cbe9d3b
commit dd5e05c
Showing
3 changed files
with
107 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
use daft_logical_plan::LogicalPlanBuilder; | ||
use eyre::{bail, WrapErr}; | ||
use spark_connect::join::JoinType; | ||
use tracing::warn; | ||
|
||
use crate::translation::to_logical_plan; | ||
|
||
pub fn join(join: spark_connect::Join) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::Join { | ||
left, | ||
right, | ||
join_condition, | ||
join_type, | ||
using_columns, | ||
join_data_type, | ||
} = join; | ||
|
||
let Some(left) = left else { | ||
bail!("Left side of join is required"); | ||
}; | ||
|
||
let Some(right) = right else { | ||
bail!("Right side of join is required"); | ||
}; | ||
|
||
if let Some(join_condition) = join_condition { | ||
bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}"); | ||
} | ||
|
||
let join_type = JoinType::try_from(join_type) | ||
.wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?; | ||
|
||
let join_type = to_daft_join_type(join_type)?; | ||
|
||
let using_columns_exprs: Vec<_> = using_columns | ||
.iter() | ||
.map(|s| daft_dsl::col(s.as_str())) | ||
.collect(); | ||
|
||
if let Some(join_data_type) = join_data_type { | ||
warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented"); | ||
} | ||
|
||
let left = to_logical_plan(*left)?; | ||
let right = to_logical_plan(*right)?; | ||
|
||
Ok(left.join( | ||
&right, | ||
// join_conditions.clone(), // todo(correctness): is this correct? | ||
// join_conditions, // todo(correctness): is this correct? | ||
using_columns_exprs.clone(), | ||
using_columns_exprs, | ||
join_type, | ||
None, | ||
None, | ||
None, | ||
false, // todo(correctness): we want join keys or not | ||
)?) | ||
} | ||
|
||
fn to_daft_join_type(join_type: JoinType) -> eyre::Result<daft_core::join::JoinType> { | ||
match join_type { | ||
JoinType::Unspecified => { | ||
bail!("Join type must be specified; got Unspecified") | ||
} | ||
JoinType::Inner => Ok(daft_core::join::JoinType::Inner), | ||
JoinType::FullOuter => { | ||
bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented | ||
} | ||
JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left), // todo(correctness): is this correct? | ||
JoinType::RightOuter => Ok(daft_core::join::JoinType::Right), | ||
JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti), // todo(correctness): is this correct? | ||
JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented | ||
JoinType::Cross => bail!("Cross joins not yet supported"), // todo(completeness): add support for cross joins if it is not already implemented | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from __future__ import annotations | ||
|
||
from pyspark.sql.functions import col | ||
|
||
|
||
def test_join(spark_session):
    """Join two overlapping ranges on ``id`` and check which ids each join type keeps."""
    # Left side holds ids 0..4 and right side holds ids 3..6, so only 3 and 4 overlap.
    left_df = spark_session.range(5)
    right_df = spark_session.range(3, 7)

    # (join type, ids expected in the result, assertion message), checked in this order.
    cases = [
        ("inner", {3, 4}, "Inner join should only contain IDs 3 and 4"),
        ("left", {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame"),
        ("right", {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame"),
    ]

    for how, expected_ids, failure_message in cases:
        joined_df = left_df.join(right_df, "id", how)
        # Collect the joined rows and compare as a set; row order is not checked.
        actual_ids = {row.id for row in joined_df.select("id").collect()}
        assert actual_ids == expected_ids, failure_message