Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] map type → key/value #2910

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2091,7 +2091,7 @@ impl ListArray {
}
}
}
DataType::Map(..) => Ok(MapArray::new(
DataType::Map { .. } => Ok(MapArray::new(
Field::new(self.name(), dtype.clone()),
self.clone(),
)
Expand Down
4 changes: 2 additions & 2 deletions src/daft-core/src/array/ops/from_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ where
// TODO: Consolidate Map to use the same .to_type conversion as other logical types
// Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent,
// but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray
DataType::Map(..) => arrow_arr,
DataType::Map { .. } => arrow_arr,
_ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?),
};
let physical = <L::PhysicalType as DaftDataType>::ArrayType::from_arrow(
Expand Down Expand Up @@ -98,7 +98,7 @@ impl FromArrow for ListArray {
arrow_arr.validity().cloned(),
))
}
(DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => {
(DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map { .. }) => {
let map_arr = arrow_arr
.as_any()
.downcast_ref::<arrow2::array::MapArray>()
Expand Down
21 changes: 6 additions & 15 deletions src/daft-core/src/array/ops/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,10 @@ fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult<Series> {

impl MapArray {
pub fn map_get(&self, key_to_get: &Series) -> DaftResult<Series> {
let value_type = if let DataType::Map(inner_dtype) = self.data_type() {
match *inner_dtype.clone() {
DataType::Struct(fields) if fields.len() == 2 => {
fields[1].dtype.clone()
}
_ => {
return Err(DaftError::TypeError(format!(
"Expected inner type to be a struct type with two fields: key and value, got {:?}",
inner_dtype
)))
}
}
} else {
let DataType::Map {
value: value_type, ..
} = self.data_type()
else {
return Err(DaftError::TypeError(format!(
"Expected input to be a map type, got {:?}",
self.data_type()
Expand All @@ -49,7 +40,7 @@ impl MapArray {
for series in self.physical.into_iter() {
match series {
Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
_ => result.push(Series::full_null("value", value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
Expand All @@ -59,7 +50,7 @@ impl MapArray {
for (i, series) in self.physical.into_iter().enumerate() {
match (series, key_to_get.slice(i, i + 1)?) {
(Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
_ => result.push(Series::full_null("value", value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/matching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ macro_rules! with_match_daft_types {(
FixedSizeList(_, _) => __with_ty__! { FixedSizeListType },
List(_) => __with_ty__! { ListType },
Struct(_) => __with_ty__! { StructType },
Map(_) => __with_ty__! { MapType },
Map{..} => __with_ty__! { MapType },
Extension(_, _, _) => __with_ty__! { ExtensionType },
#[cfg(feature = "python")]
Python => __with_ty__! { PythonType },
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/series/ops/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{datatypes::DataType, series::Series};
impl Series {
pub fn map_get(&self, key: &Series) -> DaftResult<Series> {
match self.data_type() {
DataType::Map(_) => self.map()?.map_get(key),
DataType::Map { .. } => self.map()?.map_get(key),
dt => Err(DaftError::TypeError(format!(
"map.get not implemented for {}",
dt
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/series/serdes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl<'d> serde::Deserialize<'d> for Series {
.unwrap()
.into_series())
}
DataType::Map(..) => {
DataType::Map { .. } => {
let physical = map.next_value::<Series>()?;
Ok(MapArray::new(
Arc::new(field),
Expand Down
16 changes: 6 additions & 10 deletions src/daft-dsl/src/functions/map/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ impl FunctionEvaluator for GetEvaluator {

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
// what is input and what is key
// input is a map field
[input, key] => match (input.to_field(schema), key.to_field(schema)) {
(Ok(input_field), Ok(_)) => match input_field.dtype {
DataType::Map(inner) => match inner.as_ref() {
DataType::Struct(fields) if fields.len() == 2 => {
let value_dtype = &fields[1].dtype;
Ok(Field::new("value", value_dtype.clone()))
}
_ => Err(DaftError::TypeError(format!(
"Expected input map to have struct values with 2 fields, got {}",
inner
))),
},
DataType::Map { value, .. } => {
// todo: perhaps better naming
Ok(Field::new("value", *value))
}
_ => Err(DaftError::TypeError(format!(
"Expected input to be a map, got {}",
input_field.dtype
Expand Down
117 changes: 73 additions & 44 deletions src/daft-schema/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ pub enum DataType {
Struct(Vec<Field>),

/// A nested [`DataType`] that is represented as List<entries: Struct<key: K, value: V>>.
#[display("Map[{_0}]")]
Map(Box<DataType>),
#[display("Map[{key}: {value}]")]
Map {
key: Box<DataType>,
value: Box<DataType>,
},

/// Extension type.
#[display("{_1}")]
Expand Down Expand Up @@ -233,14 +236,24 @@ impl DataType {
DataType::List(field) => Ok(ArrowType::LargeList(Box::new(
arrow2::datatypes::Field::new("item", field.to_arrow()?, true),
))),
DataType::Map(field) => Ok(ArrowType::Map(
Box::new(arrow2::datatypes::Field::new(
"item",
field.to_arrow()?,
true,
)),
false,
)),
DataType::Map { key, value } => {
let struct_type = ArrowType::Struct(vec![
arrow2::datatypes::Field::new("key", key.to_arrow()?, true),
arrow2::datatypes::Field::new("value", value.to_arrow()?, true),
]);

// entries
let struct_field =
arrow2::datatypes::Field::new("entries", struct_type.clone(), true);

let list_type = ArrowType::LargeList(Box::new(struct_field));

// entries
// todo: item? items? something else?
let list_field = arrow2::datatypes::Field::new("item", list_type.clone(), true);

Ok(ArrowType::Map(Box::new(list_field), false))
}
DataType::Struct(fields) => Ok({
let fields = fields
.iter()
Expand Down Expand Up @@ -288,7 +301,10 @@ impl DataType {
FixedSizeList(child_dtype, size) => {
FixedSizeList(Box::new(child_dtype.to_physical()), *size)
}
Map(child_dtype) => List(Box::new(child_dtype.to_physical())),
Map { key, value } => List(Box::new(Struct(vec![
Field::new("key", key.to_physical()),
Field::new("value", value.to_physical()),
]))),
Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size),
Image(mode) => Struct(vec![
Field::new(
Expand Down Expand Up @@ -328,20 +344,6 @@ impl DataType {
}
}

#[inline]
pub fn nested_dtype(&self) -> Option<&DataType> {
match self {
DataType::Map(dtype)
| DataType::List(dtype)
| DataType::FixedSizeList(dtype, _)
| DataType::FixedShapeTensor(dtype, _)
| DataType::SparseTensor(dtype)
| DataType::FixedShapeSparseTensor(dtype, _)
| DataType::Tensor(dtype) => Some(dtype),
_ => None,
}
}

#[inline]
pub fn is_arrow(&self) -> bool {
self.to_arrow().is_ok()
Expand All @@ -350,21 +352,21 @@ impl DataType {
#[inline]
pub fn is_numeric(&self) -> bool {
match self {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
// DataType::Float16
| DataType::Float32
| DataType::Float64 => true,
DataType::Extension(_, inner, _) => inner.is_numeric(),
_ => false
}
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
// DataType::Float16
| DataType::Float32
| DataType::Float64 => true,
DataType::Extension(_, inner, _) => inner.is_numeric(),
_ => false
}
}

#[inline]
Expand Down Expand Up @@ -453,7 +455,7 @@ impl DataType {

#[inline]
pub fn is_map(&self) -> bool {
matches!(self, DataType::Map(..))
matches!(self, DataType::Map { .. })
}

#[inline]
Expand Down Expand Up @@ -576,7 +578,7 @@ impl DataType {
| DataType::FixedShapeTensor(..)
| DataType::SparseTensor(..)
| DataType::FixedShapeSparseTensor(..)
| DataType::Map(..)
| DataType::Map { .. }
)
}

Expand All @@ -593,7 +595,7 @@ impl DataType {
DataType::List(..)
| DataType::FixedSizeList(..)
| DataType::Struct(..)
| DataType::Map(..)
| DataType::Map { .. }
)
}

Expand Down Expand Up @@ -643,7 +645,34 @@ impl From<&ArrowType> for DataType {
ArrowType::FixedSizeList(field, size) => {
DataType::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size)
}
ArrowType::Map(field, ..) => DataType::Map(Box::new(field.as_ref().data_type().into())),
ArrowType::Map(field, ..) => {
// todo: TryFrom in future? want in second pass maybe

// field should be a list
let ArrowType::List(field) = &field.data_type else {
panic!("Map should have a list as its key")
};

// field should be a struct
let ArrowType::Struct(fields) = &field.data_type else {
panic!("Map should have a struct as its key")
};

let [key, value] = fields.as_slice() else {
panic!("Map should have two fields")
};

let key = &key.data_type;
let value = &value.data_type;

let key = DataType::from(key);
let value = DataType::from(value);

let key = Box::new(key);
let value = Box::new(value);

DataType::Map { key, value }
}
ArrowType::Struct(fields) => {
let fields: Vec<Field> = fields.iter().map(|fld| fld.into()).collect();
DataType::Struct(fields)
Expand Down
8 changes: 4 additions & 4 deletions src/daft-schema/src/python/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ impl PyDataType {

#[staticmethod]
pub fn map(key_type: Self, value_type: Self) -> PyResult<Self> {
Ok(DataType::Map(Box::new(DataType::Struct(vec![
Field::new("key", key_type.dtype),
Field::new("value", value_type.dtype),
])))
Ok(DataType::Map {
key: Box::new(key_type.dtype),
value: Box::new(value_type.dtype),
}
.into())
}

Expand Down
2 changes: 1 addition & 1 deletion src/daft-stats/src/column_stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl ColumnRangeStatistics {

// UNSUPPORTED TYPES:
// Types that don't support comparisons and can't be used as ColumnRangeStatistics
DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false,
DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map { .. } | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false,
#[cfg(feature = "python")]
DataType::Python => false,
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/repr_html.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn html_value(s: &Series, idx: usize) -> String {
let arr = s.struct_().unwrap();
arr.html_value(idx)
}
DataType::Map(_) => {
DataType::Map { .. } => {
let arr = s.map().unwrap();
arr.html_value(idx)
}
Expand Down
Loading