Skip to content

Commit

Permalink
[BUG] Fix join errors with same key name joins (resolves #2649) (#2877)
Browse files Browse the repository at this point in the history
The issue fixed here had a workaround previously - aliasing the
duplicate column name. This is not needed anymore as the aliasing is
performed under the hood, taking care of uniqueness of individual column
keys to avoid the duplicate issue.

---------

Co-authored-by: AnmolS <[email protected]>
  • Loading branch information
anmolsingh20 and AnmolS authored Oct 3, 2024
1 parent f3b998e commit 62d0581
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
uuid = {version = "1", features = ["v4"]}

[dev-dependencies]
daft-dsl = {path = "../daft-dsl", features = ["test-utils"]}
Expand Down
86 changes: 79 additions & 7 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftError;
use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand All @@ -13,6 +13,7 @@ use daft_dsl::{
};
use itertools::Itertools;
use snafu::ResultExt;
use uuid::Uuid;

use crate::{
logical_ops::Project,
Expand Down Expand Up @@ -54,14 +55,31 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, _) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) {
let (unique_left_on, unique_right_on) =
Self::rename_join_keys(left_on.clone(), right_on.clone());

let left_fields: Vec<Field> = unique_left_on
.iter()
.map(|e| e.to_field(&left.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

let right_fields: Vec<Field> = unique_right_on
.iter()
.map(|e| e.to_field(&right.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

for (on_exprs, on_fields) in [
(&unique_left_on, &left_fields),
(&unique_right_on, &right_fields),
] {
for (field, expr) in on_fields.iter().zip(on_exprs.iter()) {
// Null type check for both fields and expressions
if matches!(field.dtype, DataType::Null) {
return Err(DaftError::ValueError(format!(
"Can't join on null type expressions: {expr}"
Expand Down Expand Up @@ -167,6 +185,60 @@ impl Join {
}
}

/// Renames join keys for the given left and right expressions. This is required to
/// prevent errors when the join keys on the left and right expressions have the same key
/// name.
///
/// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and
/// checks for pairs of column expressions that differ. If both expressions in a pair
/// are column expressions and they are not identical, it generates a unique identifier
/// and renames both expressions by appending this identifier to their original names.
///
/// The function returns two vectors of expressions, where the renamed expressions are
/// substituted for the original expressions in the cases where renaming occurred.
///
/// # Parameters
/// - `left_exprs`: A vector of expressions from the left side of a join.
/// - `right_exprs`: A vector of expressions from the right side of a join.
///
/// # Returns
/// A tuple containing two vectors of expressions, one for the left side and one for the
/// right side, where expressions that needed to be renamed have been modified.
///
/// # Example
/// ```
/// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions);
/// ```
///
/// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649).
fn rename_join_keys(
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
left_exprs
.into_iter()
.zip(right_exprs)
.map(
|(left_expr, right_expr)| match (&*left_expr, &*right_expr) {
(Expr::Column(left_name), Expr::Column(right_name))
if left_name == right_name =>
{
(left_expr, right_expr)
}
_ => {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr =
left_expr.alias(format!("{}_{}", left_expr.name(), unique_id));
let renamed_right_expr =
right_expr.alias(format!("{}_{}", right_expr.name(), unique_id));
(renamed_left_expr, renamed_right_expr)
}
},
)
.unzip()
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push(format!("Join: Type = {}", self.join_type));
Expand Down
11 changes: 11 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_rename_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"])
joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"])

assert set(joined_df1.schema().column_names()) == set(["A", "B"])
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
Expand Down

0 comments on commit 62d0581

Please sign in to comment.