diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 93005ddae7..5ee40dfb4b 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -146,14 +146,11 @@ def _truncated_table_string(self) -> str: def apply_hints(self, hints: Schema) -> Schema: return Schema._from_pyschema(self._schema.apply_hints(hints._schema)) + # Takes the unions between two schemas. Throws an error if the schemas contain overlapping keys. def union(self, other: Schema) -> Schema: if not isinstance(other, Schema): raise ValueError(f"Expected Schema, got other: {type(other)}") - intersecting_names = self.to_name_set().intersection(other.to_name_set()) - if intersecting_names: - raise ValueError(f"Cannot union schemas with overlapping names: {intersecting_names}") - return Schema._from_pyschema(self._schema.union(other._schema)) def __reduce__(self) -> tuple: diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index e8d349e55f..b1a71431ee 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -870,7 +870,7 @@ pub fn read_csv_into_micropartition( let unioned_schema = tables .iter() .map(|tbl| tbl.schema.clone()) - .try_reduce(|s1, s2| s1.union(s2.as_ref()).map(Arc::new))? + .reduce(|s1, s2| Arc::new(s1.non_distinct_union(s2.as_ref()))) .unwrap(); let tables = tables .into_iter() @@ -919,7 +919,7 @@ pub fn read_json_into_micropartition( let unioned_schema = tables .iter() .map(|tbl| tbl.schema.clone()) - .try_reduce(|s1, s2| s1.union(s2.as_ref()).map(Arc::new))? + .reduce(|s1, s2| Arc::new(s1.non_distinct_union(s2.as_ref()))) .unwrap(); let tables = tables .into_iter() @@ -1082,7 +1082,7 @@ fn _read_parquet_into_loaded_micropartition>( let unioned_schema = all_tables .iter() .map(|t| t.schema.clone()) - .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?; + .reduce(|l, r| l.non_distinct_union(&r).into()); unioned_schema.expect("we need at least 1 schema") }; @@ -1231,7 +1231,7 @@ pub fn read_parquet_into_micropartition>( } else { let unioned_schema = schemas .into_iter() - .try_reduce(|l, r| l.union(&r).map(Arc::new))?; + .reduce(|l, r| Arc::new(l.non_distinct_union(&r))); unioned_schema.expect("we need at least 1 schema") }; diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index d220897228..a36f6abe02 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -100,23 +100,38 @@ impl Schema { self.fields.is_empty() } + /// Takes the disjoint union over the `self` and `other` schemas, throwing an error if the + /// schemas contain overlapping keys. pub fn union(&self, other: &Self) -> DaftResult { let self_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys()); - let other_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys()); - match self_keys.difference(&other_keys).count() { - 0 => { - let mut fields = IndexMap::new(); - for (k, v) in self.fields.iter().chain(other.fields.iter()) { - fields.insert(k.clone(), v.clone()); - } - Ok(Self { fields }) - } - _ => Err(DaftError::ValueError( - "Cannot union two schemas with overlapping keys".to_string(), - )), + let other_keys: HashSet<&String> = HashSet::from_iter(other.fields.keys()); + if self_keys.is_disjoint(&other_keys) { + let fields = self + .fields + .iter() + .chain(other.fields.iter()) + .map(|(k, v)| (k.clone(), v.clone())) // Convert references to owned values + .collect(); + Ok(Self { fields }) + } else { + Err(DaftError::ValueError( + "Cannot disjoint union two schemas with overlapping keys".to_string(), + )) } } + /// Takes the non-distinct union of two schemas. If there are overlapping keys, then we take the + /// corresponding field from one of the two schemas. + pub fn non_distinct_union(&self, other: &Self) -> Self { + let fields = self + .fields + .iter() + .chain(other.fields.iter()) + .map(|(k, v)| (k.clone(), v.clone())) // Convert references to owned values + .collect(); + Self { fields } + } + pub fn apply_hints(&self, hints: &Self) -> DaftResult { let applied_fields = self .fields