Skip to content

Commit

Permalink
[BUG] Fix intersection checking when unioning schemas (#3039)
Browse files Browse the repository at this point in the history
In the definition of `Schema::union`, the error message suggests that we
intended to throw errors when performing a union on two schemas with
overlapping keys. However, the original implementation took the set
difference of keys between one side of the union and itself, which would
never throw an error.

This bug was not noticed because the python tests went through the
python code path which would check for the intersection correctly. But
if one uses the Rust API directly, then this property is not upheld.

We fix this bug by instead checking that the two sides of the union have
distinct keys.
  • Loading branch information
desmondcheongzx authored Oct 15, 2024
1 parent a3453d1 commit c8871d0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
5 changes: 1 addition & 4 deletions daft/logical/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,11 @@ def _truncated_table_string(self) -> str:
def apply_hints(self, hints: Schema) -> Schema:
return Schema._from_pyschema(self._schema.apply_hints(hints._schema))

# Takes the unions between two schemas. Throws an error if the schemas contain overlapping keys.
def union(self, other: Schema) -> Schema:
if not isinstance(other, Schema):
raise ValueError(f"Expected Schema, got other: {type(other)}")

intersecting_names = self.to_name_set().intersection(other.to_name_set())
if intersecting_names:
raise ValueError(f"Cannot union schemas with overlapping names: {intersecting_names}")

return Schema._from_pyschema(self._schema.union(other._schema))

def __reduce__(self) -> tuple:
Expand Down
8 changes: 4 additions & 4 deletions src/daft-micropartition/src/micropartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ pub fn read_csv_into_micropartition(
let unioned_schema = tables
.iter()
.map(|tbl| tbl.schema.clone())
.try_reduce(|s1, s2| s1.union(s2.as_ref()).map(Arc::new))?
.reduce(|s1, s2| Arc::new(s1.non_distinct_union(s2.as_ref())))
.unwrap();
let tables = tables
.into_iter()
Expand Down Expand Up @@ -919,7 +919,7 @@ pub fn read_json_into_micropartition(
let unioned_schema = tables
.iter()
.map(|tbl| tbl.schema.clone())
.try_reduce(|s1, s2| s1.union(s2.as_ref()).map(Arc::new))?
.reduce(|s1, s2| Arc::new(s1.non_distinct_union(s2.as_ref())))
.unwrap();
let tables = tables
.into_iter()
Expand Down Expand Up @@ -1082,7 +1082,7 @@ fn _read_parquet_into_loaded_micropartition<T: AsRef<str>>(
let unioned_schema = all_tables
.iter()
.map(|t| t.schema.clone())
.try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?;
.reduce(|l, r| l.non_distinct_union(&r).into());
unioned_schema.expect("we need at least 1 schema")
};

Expand Down Expand Up @@ -1231,7 +1231,7 @@ pub fn read_parquet_into_micropartition<T: AsRef<str>>(
} else {
let unioned_schema = schemas
.into_iter()
.try_reduce(|l, r| l.union(&r).map(Arc::new))?;
.reduce(|l, r| Arc::new(l.non_distinct_union(&r)));
unioned_schema.expect("we need at least 1 schema")
};

Expand Down
39 changes: 27 additions & 12 deletions src/daft-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,38 @@ impl Schema {
self.fields.is_empty()
}

/// Takes the disjoint union over the `self` and `other` schemas, throwing an error if the
/// schemas contain overlapping keys.
pub fn union(&self, other: &Self) -> DaftResult<Self> {
let self_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys());
let other_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys());
match self_keys.difference(&other_keys).count() {
0 => {
let mut fields = IndexMap::new();
for (k, v) in self.fields.iter().chain(other.fields.iter()) {
fields.insert(k.clone(), v.clone());
}
Ok(Self { fields })
}
_ => Err(DaftError::ValueError(
"Cannot union two schemas with overlapping keys".to_string(),
)),
let other_keys: HashSet<&String> = HashSet::from_iter(other.fields.keys());
if self_keys.is_disjoint(&other_keys) {
let fields = self
.fields
.iter()
.chain(other.fields.iter())
.map(|(k, v)| (k.clone(), v.clone())) // Convert references to owned values
.collect();
Ok(Self { fields })
} else {
Err(DaftError::ValueError(
"Cannot disjoint union two schemas with overlapping keys".to_string(),
))
}
}

/// Takes the non-distinct union of two schemas. If there are overlapping keys, then we take the
/// corresponding field from one of the two schemas.
pub fn non_distinct_union(&self, other: &Self) -> Self {
let fields = self
.fields
.iter()
.chain(other.fields.iter())
.map(|(k, v)| (k.clone(), v.clone())) // Convert references to owned values
.collect();
Self { fields }
}

pub fn apply_hints(&self, hints: &Self) -> DaftResult<Self> {
let applied_fields = self
.fields
Expand Down

0 comments on commit c8871d0

Please sign in to comment.