Skip to content

Commit

Permalink
Renaming on column expressions only
Browse files Browse the repository at this point in the history
Rename join keys only for column expressions; include original
expression name in the renamed expression.
  • Loading branch information
AnmolS authored and AnmolS committed Oct 3, 2024
1 parent 111995e commit 653b43c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 21 deletions.
86 changes: 66 additions & 20 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftError;
use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand Down Expand Up @@ -55,15 +55,31 @@ impl Join {
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (unique_left_on, unique_right_on) = Self::rename_join_keys(left_on, right_on);
let (left_on, left_fields) =
resolve_exprs(unique_left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(unique_right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) {
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, _) =
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

let (unique_left_on, unique_right_on) =
Self::rename_join_keys(left_on.clone(), right_on.clone());

let left_fields: Vec<Field> = unique_left_on
.iter()
.map(|e| e.to_field(&left.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

let right_fields: Vec<Field> = unique_right_on
.iter()
.map(|e| e.to_field(&right.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

for (on_exprs, on_fields) in [
(&unique_left_on, &left_fields),
(&unique_right_on, &right_fields),
] {
for (field, expr) in on_fields.iter().zip(on_exprs.iter()) {
// Null type check for both fields and expressions
if matches!(field.dtype, DataType::Null) {
return Err(DaftError::ValueError(format!(
"Can't join on null type expressions: {expr}"
Expand Down Expand Up @@ -169,23 +185,53 @@ impl Join {
}
}

/// Renames join keys for the given left and right expressions. This is required to
/// prevent errors when the join keys on the left and right expressions have the same key
/// name.
///
/// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and
/// checks for pairs of column expressions that differ. If both expressions in a pair
/// are column expressions and they are not identical, it generates a unique identifier
/// and renames both expressions by appending this identifier to their original names.
///
/// The function returns two vectors of expressions, where the renamed expressions are
/// substituted for the original expressions in the cases where renaming occurred.
///
/// # Parameters
/// - `left_exprs`: A vector of expressions from the left side of a join.
/// - `right_exprs`: A vector of expressions from the right side of a join.
///
/// # Returns
/// A tuple containing two vectors of expressions, one for the left side and one for the
/// right side, where expressions that needed to be renamed have been modified.
///
/// # Example
/// ```ignore
/// let (renamed_left, renamed_right) = Join::rename_join_keys(left_expressions, right_expressions);
/// ```
///
/// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649).
fn rename_join_keys(
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
left_exprs
.into_iter()
.zip(right_exprs)
.map(|(left_expr, right_expr)| {
if left_expr != right_expr {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr = left_expr.alias(unique_id.clone());
let renamed_right_expr = right_expr.alias(unique_id);
(renamed_left_expr, renamed_right_expr)
} else {
(left_expr, right_expr)
}
})
.map(
|(left_expr, right_expr)| match (&*left_expr, &*right_expr) {
(Expr::Column(_), Expr::Column(_)) if left_expr != right_expr => {
let unique_id = Uuid::new_v4().to_string();
let renamed_left_expr =
left_expr.alias(format!("{}_{}", left_expr.name(), unique_id));
let renamed_right_expr =
right_expr.alias(format!("{}_{}", right_expr.name(), unique_id));
(renamed_left_expr, renamed_right_expr)
}
_ => (left_expr, right_expr),
},
)
.unzip()
}

Expand Down
2 changes: 1 addition & 1 deletion tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_columns_after_join(make_df):
assert set(joined_df2.schema().column_names()) == set(["A", "B"])


def test_duplicate_join_keys_in_dataframe(make_df):
def test_rename_join_keys_in_dataframe(make_df):
df1 = make_df({"A": [1, 2], "B": [2, 2]})

df2 = make_df({"A": [1, 2]})
Expand Down

0 comments on commit 653b43c

Please sign in to comment.