Skip to content

Commit

Permalink
[REFACTOR] map datatype → key, value
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Sep 24, 2024
1 parent 1560b4f commit 7c40c82
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2091,7 +2091,7 @@ impl ListArray {
}
}
}
DataType::Map{..} => Ok(MapArray::new(
DataType::Map { .. } => Ok(MapArray::new(
Field::new(self.name(), dtype.clone()),
self.clone(),
)
Expand Down
4 changes: 2 additions & 2 deletions src/daft-core/src/array/ops/from_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
// TODO: Consolidate Map to use the same .to_type conversion as other logical types
// Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent,
// but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray
DataType::Map{..} => arrow_arr,
DataType::Map { .. } => arrow_arr,
_ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?),
};
let physical = <L::PhysicalType as DaftDataType>::ArrayType::from_arrow(
Expand Down Expand Up @@ -98,7 +98,7 @@ impl FromArrow for ListArray {
arrow_arr.validity().cloned(),
))
}
(DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => {
(DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map { .. }) => {
let map_arr = arrow_arr
.as_any()
.downcast_ref::<arrow2::array::MapArray>()
Expand Down
9 changes: 6 additions & 3 deletions src/daft-core/src/array/ops/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult<Series> {

impl MapArray {
pub fn map_get(&self, key_to_get: &Series) -> DaftResult<Series> {
let DataType::Map { value: value_type, .. } = self.data_type() else {
let DataType::Map {
value: value_type, ..
} = self.data_type()
else {
return Err(DaftError::TypeError(format!(
"Expected input to be a map type, got {:?}",
self.data_type()
Expand All @@ -37,7 +40,7 @@ impl MapArray {
for series in self.physical.into_iter() {
match series {
Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
_ => result.push(Series::full_null("value", value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
Expand All @@ -47,7 +50,7 @@ impl MapArray {
for (i, series) in self.physical.into_iter().enumerate() {
match (series, key_to_get.slice(i, i + 1)?) {
(Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
_ => result.push(Series::full_null("value", value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/series/ops/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{datatypes::DataType, series::Series};
impl Series {
pub fn map_get(&self, key: &Series) -> DaftResult<Series> {
match self.data_type() {
DataType::Map{..} => self.map()?.map_get(key),
DataType::Map { .. } => self.map()?.map_get(key),
dt => Err(DaftError::TypeError(format!(
"map.get not implemented for {}",
dt
Expand Down
52 changes: 26 additions & 26 deletions src/daft-core/src/series/serdes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,83 +78,83 @@ impl<'d> serde::Deserialize<'d> for Series {
&field.dtype,
map.next_value::<usize>()?,
)
.into_series()),
.into_series()),
DataType::Boolean => Ok(BooleanArray::from((
field.name.as_str(),
map.next_value::<Vec<Option<bool>>>()?.as_slice(),
))
.into_series()),
.into_series()),
DataType::Int8 => Ok(Int8Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<i8>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Int16 => Ok(Int16Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<i16>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Int32 => Ok(Int32Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<i32>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Int64 => Ok(Int64Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<i64>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Int128 => Ok(Int128Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<i128>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::UInt8 => Ok(UInt8Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<u8>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::UInt16 => Ok(UInt16Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<u16>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::UInt32 => Ok(UInt32Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<u32>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::UInt64 => Ok(UInt64Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<u64>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Float32 => Ok(Float32Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<f32>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Float64 => Ok(Float64Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<f64>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Utf8 => Ok(Utf8Array::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<Cow<str>>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::Binary => Ok(BinaryArray::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<Cow<[u8]>>>>()?.into_iter(),
)
.into_series()),
.into_series()),
DataType::FixedSizeBinary(size) => Ok(FixedSizeBinaryArray::from_iter(
field.name.as_str(),
map.next_value::<Vec<Option<Cow<[u8]>>>>()?.into_iter(),
*size,
)
.into_series()),
.into_series()),
DataType::Extension(..) => {
let physical = map.next_value::<Series>()?;
let physical = physical.to_arrow();
Expand All @@ -169,7 +169,7 @@ impl<'d> serde::Deserialize<'d> for Series {
Arc::new(field),
physical.downcast::<ListArray>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::Struct(..) => {
let mut all_series = map.next_value::<Vec<Option<Series>>>()?;
Expand Down Expand Up @@ -198,7 +198,7 @@ impl<'d> serde::Deserialize<'d> for Series {
let offsets = OffsetsBuffer::<i64>::try_from(
offsets_array.as_arrow().values().clone(),
)
.unwrap();
.unwrap();
let flat_child = all_series
.pop()
.ok_or_else(|| serde::de::Error::missing_field("flat_child"))?
Expand All @@ -225,7 +225,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::Timestamp(..) => {
type PType = <<TimestampType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -234,7 +234,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::Date => {
type PType = <<DateType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -260,7 +260,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series(),
.into_series(),
)
}
DataType::Embedding(..) => {
Expand All @@ -270,7 +270,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::Image(..) => {
type PType = <<ImageType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -287,7 +287,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::FixedShapeTensor(..) => {
type PType = <<FixedShapeTensorType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -296,7 +296,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::SparseTensor(..) => {
type PType = <<SparseTensorType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -305,7 +305,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::FixedShapeSparseTensor(..) => {
type PType = <<FixedShapeSparseTensorType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand All @@ -314,7 +314,7 @@ impl<'d> serde::Deserialize<'d> for Series {
field,
physical.downcast::<PType>().unwrap().clone(),
)
.into_series())
.into_series())
}
DataType::Tensor(..) => {
type PType = <<TensorType as DaftLogicalType>::PhysicalType as DaftDataType>::ArrayType;
Expand Down
2 changes: 0 additions & 2 deletions src/daft-dsl/src/functions/map/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ impl FunctionEvaluator for GetEvaluator {

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {

// `input` is the expression expected to resolve to a map-typed field;
// `key` is the expression for the key to look up in that map.

[input, key] => match (input.to_field(schema), key.to_field(schema)) {
(Ok(input_field), Ok(_)) => match input_field.dtype {
DataType::Map { value, .. } => {
Expand Down
17 changes: 9 additions & 8 deletions src/daft-schema/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ pub enum DataType {

/// A nested [`DataType`] that is represented as List<entries: Struct<key: K, value: V>>.
#[display("Map[{key}: {value}]")]
Map { key: Box<DataType>, value: Box<DataType> },
Map {
key: Box<DataType>,
value: Box<DataType>,
},

/// Extension type.
#[display("{_1}")]
Expand Down Expand Up @@ -240,18 +243,16 @@ impl DataType {
]);

// entries
let struct_field = arrow2::datatypes::Field::new("entries", struct_type.clone(), true);
let struct_field =
arrow2::datatypes::Field::new("entries", struct_type.clone(), true);

let list_type = ArrowType::LargeList(Box::new(struct_field));

// list field wrapping the "entries" struct field
// TODO: settle on this field's name: "item", "items", or something else?
let list_field = arrow2::datatypes::Field::new("item", list_type.clone(), true);

Ok(ArrowType::Map(
Box::new(list_field),
false,
))
Ok(ArrowType::Map(Box::new(list_field), false))
}
DataType::Struct(fields) => Ok({
let fields = fields
Expand Down Expand Up @@ -454,7 +455,7 @@ impl DataType {

#[inline]
pub fn is_map(&self) -> bool {
matches!(self, DataType::Map{ .. })
matches!(self, DataType::Map { .. })
}

#[inline]
Expand Down Expand Up @@ -660,7 +661,7 @@ impl From<&ArrowType> for DataType {
let [key, value] = fields.as_slice() else {
panic!("Map should have two fields")
};

let key = &key.data_type;
let value = &value.data_type;

Expand Down
12 changes: 8 additions & 4 deletions src/daft-schema/src/python/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,11 @@ impl PyDataType {

#[staticmethod]
pub fn map(key_type: Self, value_type: Self) -> PyResult<Self> {
Ok(DataType::Map { key: Box::new(key_type.dtype), value: Box::new(value_type.dtype) }.into())
Ok(DataType::Map {
key: Box::new(key_type.dtype),
value: Box::new(value_type.dtype),
}
.into())
}

#[staticmethod]
Expand All @@ -220,7 +224,7 @@ impl PyDataType {
.map(|(name, dtype)| Field::new(name, dtype.dtype))
.collect::<Vec<Field>>(),
)
.into()
.into()
}

#[staticmethod]
Expand All @@ -234,7 +238,7 @@ impl PyDataType {
Box::new(storage_data_type.dtype),
metadata.map(|s| s.to_string()),
)
.into())
.into())
}

#[staticmethod]
Expand Down Expand Up @@ -325,7 +329,7 @@ impl PyDataType {
Self {
dtype: *dtype.clone(),
}
.to_arrow(py)?,
.to_arrow(py)?,
pyo3::types::PyTuple::new_bound(py, shape.clone()),
))
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-stats/src/column_stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl ColumnRangeStatistics {

// UNSUPPORTED TYPES:
// Types that don't support comparisons and can't be used as ColumnRangeStatistics
DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false,
DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map { .. } | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false,
#[cfg(feature = "python")]
DataType::Python => false,
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/repr_html.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn html_value(s: &Series, idx: usize) -> String {
let arr = s.struct_().unwrap();
arr.html_value(idx)
}
DataType::Map(_) => {
DataType::Map { .. } => {
let arr = s.map().unwrap();
arr.html_value(idx)
}
Expand Down

0 comments on commit 7c40c82

Please sign in to comment.