Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix join errors with same key name joins (resolves #2649) #2877

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
uuid = {version = "1", features = ["v4"]}

[dev-dependencies]
daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
Expand Down
86 changes: 79 additions & 7 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftError;
use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand All @@ -13,6 +13,7 @@ use daft_dsl::{
};
use itertools::Itertools;
use snafu::ResultExt;
use uuid::Uuid;

use crate::{
logical_ops::Project,
Expand Down Expand Up @@ -54,14 +55,31 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, _) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) {
let (unique_left_on, unique_right_on) =
Self::rename_join_keys(left_on.clone(), right_on.clone());

let left_fields: Vec<Field> = unique_left_on
.iter()
.map(|e| e.to_field(&left.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

let right_fields: Vec<Field> = unique_right_on
.iter()
.map(|e| e.to_field(&right.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

for (on_exprs, on_fields) in [
(&unique_left_on, &left_fields),
(&unique_right_on, &right_fields),
] {
for (field, expr) in on_fields.iter().zip(on_exprs.iter()) {
// Null type check for both fields and expressions
if matches!(field.dtype, DataType::Null) {
return Err(DaftError::ValueError(format!(
"Can't join on null type expressions: {expr}"
Expand Down Expand Up @@ -167,6 +185,60 @@ impl Join {
}
}

/// Renames join keys for the given left and right expressions. This is required to
/// prevent errors when the join keys on the left and right expressions have the same key
/// name.
///
/// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and
/// checks for pairs of column expressions that differ. If both expressions in a pair
/// are column expressions and they are not identical, it generates a unique identifier
/// and renames both expressions by appending this identifier to their original names.
///
/// The function returns two vectors of expressions, where the renamed expressions are
/// substituted for the original expressions in the cases where renaming occurred.
///
/// # Parameters
/// - `left_exprs`: A vector of expressions from the left side of a join.
/// - `right_exprs`: A vector of expressions from the right side of a join.
///
/// # Returns
/// A tuple containing two vectors of expressions, one for the left side and one for the
/// right side, where expressions that needed to be renamed have been modified.
///
/// # Example
/// ```
/// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions);
/// ```
///
/// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649).

fn rename_join_keys(
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
left_exprs
.into_iter()
.zip(right_exprs)
.map(
|(left_expr, right_expr)| match (&*left_expr, &*right_expr) {
(Expr::Column(left_name), Expr::Column(right_name))
if left_name == right_name =>
{
(left_expr, right_expr)
}
_ => {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr =
left_expr.alias(format!("{}_{}", left_expr.name(), unique_id));
let renamed_right_expr =
right_expr.alias(format!("{}_{}", right_expr.name(), unique_id));
(renamed_left_expr, renamed_right_expr)
}
},
)
.unzip()
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!("Join: Type = {}", self.join_type));
Expand Down
11 changes: 11 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_rename_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])
joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"])

assert set(joined_df1.schema().column_names()) == set(["A", "B"])
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
Expand Down
Loading