Skip to content

Commit

Permalink
Impl custom PartialEq and Hash on Field that ignores metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Feb 24, 2024
1 parent 412ffbc commit 079991f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
15 changes: 14 additions & 1 deletion src/daft-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};

pub type Metadata = std::collections::BTreeMap<String, String>;

#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
#[derive(Clone, Debug, Eq, Deserialize, Serialize)]
pub struct Field {
pub name: String,
pub dtype: DataType,
Expand Down Expand Up @@ -129,3 +129,16 @@ impl Display for Field {
write!(f, "{}#{}", self.name, self.dtype)
}
}

/// Equality for [`Field`] that looks only at the data type and the name.
///
/// Any other state carried by a `Field` (e.g. its metadata) is deliberately
/// left out of the comparison, so two fields that differ only in such extra
/// state compare equal. Keep this in sync with the `Hash` impl: values that
/// are equal here must hash identically.
impl PartialEq for Field {
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element and short-circuits,
        // so `dtype` is checked first and `name` only if the dtypes match.
        (&self.dtype, &self.name) == (&other.dtype, &other.name)
    }
}

/// Hashing for [`Field`] that feeds only the name and the data type into the
/// hasher, in that order.
///
/// Any other state carried by a `Field` (e.g. its metadata) is not hashed,
/// which keeps this impl consistent with the hand-written `PartialEq`:
/// fields that compare equal always produce the same hash.
impl std::hash::Hash for Field {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Pull out exactly the fields that participate; `..` skips the rest.
        let Self { name, dtype, .. } = self;
        // Order matters for the resulting hash value: name first, then dtype.
        std::hash::Hash::hash(name, state);
        std::hash::Hash::hash(dtype, state);
    }
}
9 changes: 2 additions & 7 deletions src/daft-micropartition/src/micropartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,8 @@ impl MicroPartition {
// Check and validate invariants with asserts
for table in tables.iter() {
assert!(
table.schema.fields.len() == schema.fields.len()
&& table.schema.fields.iter().zip(schema.fields.iter()).all(
|((s1, f1), (s2, f2))| s1 == s2
&& f1.name == f2.name
&& f1.dtype == f2.dtype
),
"Loaded MicroPartition's tables' schema must match its own schema exactly"
table.schema == schema,
"Loaded MicroPartition's tables' schema must match its own schema exactly",
);
}

Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Table {
let mut num_rows = 1;

for (field, series) in schema.fields.values().zip(columns.iter()) {
if field.name != series.field().name || field.dtype != series.field().dtype {
if field != series.field() {
return Err(DaftError::SchemaMismatch(format!("While building a Table, we found that the Schema Field and the Series Field did not match. schema field: {field} vs series field: {}", series.field())));
}
if (series.len() != 1) && (series.len() != num_rows) {
Expand Down

0 comments on commit 079991f

Please sign in to comment.