From 3e0e3348580d5c2ea70f3799eb03a692545162b0 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 07:35:30 -0800 Subject: [PATCH] [FEAT] MapArray (#1959) Closes #1847 Introduces a new MapArray type, which will improve our iceberg/parquet compatibility. MapArray is implemented as a logical type on top of ListArray, where the fields are Structs with key and value entries. --- Cargo.lock | 2 +- Cargo.toml | 2 +- daft/daft.pyi | 3 + daft/datatype.py | 24 +++++++ daft/series.py | 2 +- .../src/array/growable/logical_growable.rs | 3 +- src/daft-core/src/array/growable/mod.rs | 3 +- src/daft-core/src/array/ops/cast.rs | 15 ++++- src/daft-core/src/array/ops/from_arrow.rs | 55 +++++++++++++--- src/daft-core/src/array/ops/get.rs | 10 ++- src/daft-core/src/array/ops/repr.rs | 13 +++- src/daft-core/src/array/ops/sort.rs | 8 ++- src/daft-core/src/array/ops/take.rs | 3 +- src/daft-core/src/datatypes/dtype.rs | 26 +++++++- src/daft-core/src/datatypes/logical.rs | 25 ++++++-- src/daft-core/src/datatypes/matching.rs | 1 + src/daft-core/src/datatypes/mod.rs | 21 ++++++ src/daft-core/src/python/datatype.rs | 13 ++++ .../src/series/array_impl/binary_ops.rs | 6 +- .../src/series/array_impl/logical_array.rs | 4 +- src/daft-core/src/series/serdes.rs | 10 ++- src/daft-core/src/utils/arrow.rs | 50 ++++++++++++--- src/daft-core/src/utils/supertype.rs | 5 ++ src/daft-stats/src/column_stats/mod.rs | 2 +- tests/dataframe/test_creation.py | 21 ++++++ tests/integration/iceberg/test_table_load.py | 2 +- tests/io/delta_lake/conftest.py | 9 ++- tests/io/test_parquet.py | 2 + tests/io/test_parquet_roundtrip.py | 5 ++ tests/series/test_concat.py | 25 ++++++++ tests/series/test_if_else.py | 64 +++++++++++++++++++ tests/table/test_from_py.py | 47 ++++++++++++++ tests/test_schema.py | 2 + 33 files changed, 440 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7f670752de..1b3d855dc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,7 +105,7 @@ dependencies = [ [[package]] 
name = "arrow2" version = "0.17.4" -source = "git+https://github.com/Eventual-Inc/arrow2?rev=d5685eebf1d65c3f3d854370ad39f93dcd91971a#d5685eebf1d65c3f3d854370ad39f93dcd91971a" +source = "git+https://github.com/Eventual-Inc/arrow2?rev=c0764b00cc05126c80c7ce17ebd7a95d87f815c1#c0764b00cc05126c80c7ce17ebd7a95d87f815c1" dependencies = [ "ahash", "arrow-format", diff --git a/Cargo.toml b/Cargo.toml index eaf82e3000..e24c1d2dbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,7 +115,7 @@ url = "2.4.0" # branch = "daft-fork" git = "https://github.com/Eventual-Inc/arrow2" package = "arrow2" -rev = "d5685eebf1d65c3f3d854370ad39f93dcd91971a" +rev = "c0764b00cc05126c80c7ce17ebd7a95d87f815c1" [workspace.dependencies.bincode] version = "1.3.3" diff --git a/daft/daft.pyi b/daft/daft.pyi index 853d490e27..a0dd41890d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -826,6 +826,8 @@ class PyDataType: @staticmethod def fixed_size_list(data_type: PyDataType, size: int) -> PyDataType: ... @staticmethod + def map(key_type: PyDataType, value_type: PyDataType) -> PyDataType: ... + @staticmethod def struct(fields: dict[str, PyDataType]) -> PyDataType: ... @staticmethod def extension(name: str, storage_data_type: PyDataType, metadata: str | None = None) -> PyDataType: ... @@ -842,6 +844,7 @@ class PyDataType: def is_fixed_shape_image(self) -> builtins.bool: ... def is_tensor(self) -> builtins.bool: ... def is_fixed_shape_tensor(self) -> builtins.bool: ... + def is_map(self) -> builtins.bool: ... def is_logical(self) -> builtins.bool: ... def is_temporal(self) -> builtins.bool: ... def is_equal(self, other: Any) -> builtins.bool: ... 
diff --git a/daft/datatype.py b/daft/datatype.py index 06b5443a89..6c0cf6c1d4 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -236,6 +236,15 @@ def fixed_size_list(cls, dtype: DataType, size: int) -> DataType: raise ValueError("The size for a fixed-size list must be a positive integer, but got: ", size) return cls._from_pydatatype(PyDataType.fixed_size_list(dtype._dtype, size)) + @classmethod + def map(cls, key_type: DataType, value_type: DataType) -> DataType: + """Create a Map DataType: A map is a nested type of key-value pairs that is implemented as a list of structs with two fields, key and value. + Args: + key_type: DataType of the keys in the map + value_type: DataType of the values in the map + """ + return cls._from_pydatatype(PyDataType.map(key_type._dtype, value_type._dtype)) + @classmethod def struct(cls, fields: dict[str, DataType]) -> DataType: """Create a Struct DataType: a nested type which has names mapped to child types @@ -387,6 +396,12 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: assert isinstance(arrow_type, pa.StructType) fields = [arrow_type[i] for i in range(arrow_type.num_fields)] return cls.struct({field.name: cls.from_arrow_type(field.type) for field in fields}) + elif pa.types.is_map(arrow_type): + assert isinstance(arrow_type, pa.MapType) + return cls.map( + key_type=cls.from_arrow_type(arrow_type.key_type), + value_type=cls.from_arrow_type(arrow_type.item_type), + ) elif _RAY_DATA_EXTENSIONS_AVAILABLE and isinstance(arrow_type, tuple(_TENSOR_EXTENSION_TYPES)): scalar_dtype = cls.from_arrow_type(arrow_type.scalar_type) shape = arrow_type.shape if isinstance(arrow_type, ArrowTensorType) else None @@ -464,12 +479,21 @@ def _is_image_type(self) -> builtins.bool: def _is_fixed_shape_image_type(self) -> builtins.bool: return self._dtype.is_fixed_shape_image() + def _is_map(self) -> builtins.bool: + return self._dtype.is_map() + def _is_logical_type(self) -> builtins.bool: return self._dtype.is_logical() def 
_is_temporal_type(self) -> builtins.bool: return self._dtype.is_temporal() + def _should_cast_to_python(self) -> builtins.bool: + # NOTE: This is used to determine if we should cast a column to a Python object type when converting to PyList. + # Map is a logical type, but we don't want to cast it to Python because the underlying physical type is a List, + # which we can handle without casting to Python. + return self._is_logical_type() and not self._is_map() + def __repr__(self) -> str: return self._dtype.__repr__() diff --git a/daft/series.py b/daft/series.py index dac7aca5ef..75680eb1ee 100644 --- a/daft/series.py +++ b/daft/series.py @@ -279,7 +279,7 @@ def to_pylist(self) -> list: """ if self.datatype()._is_python_type(): return self._series.to_pylist() - elif self.datatype()._is_logical_type(): + elif self.datatype()._should_cast_to_python(): return self._series.cast(DataType.python()._dtype).to_pylist() else: return self._series.to_arrow().to_pylist() diff --git a/src/daft-core/src/array/growable/logical_growable.rs b/src/daft-core/src/array/growable/logical_growable.rs index 5f0770789b..c272e8cd8f 100644 --- a/src/daft-core/src/array/growable/logical_growable.rs +++ b/src/daft-core/src/array/growable/logical_growable.rs @@ -6,7 +6,7 @@ use crate::{ datatypes::{ logical::LogicalArray, DaftDataType, DaftLogicalType, DateType, Decimal128Type, DurationType, EmbeddingType, Field, FixedShapeImageType, FixedShapeTensorType, ImageType, - TensorType, TimeType, TimestampType, + MapType, TensorType, TimeType, TimestampType, }, DataType, IntoSeries, Series, }; @@ -84,3 +84,4 @@ impl_logical_growable!(LogicalFixedShapeTensorGrowable, FixedShapeTensorType); impl_logical_growable!(LogicalImageGrowable, ImageType); impl_logical_growable!(LogicalDecimal128Growable, Decimal128Type); impl_logical_growable!(LogicalTensorGrowable, TensorType); +impl_logical_growable!(LogicalMapGrowable, MapType); diff --git a/src/daft-core/src/array/growable/mod.rs 
b/src/daft-core/src/array/growable/mod.rs index b0de488352..9c7a93d515 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -5,7 +5,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, @@ -211,3 +211,4 @@ impl_growable_array!( Decimal128Array, logical_growable::LogicalDecimal128Growable<'a> ); +impl_growable_array!(MapArray, logical_growable::LogicalMapGrowable<'a>); diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 9b23ef1c3b..5d2ce193c9 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -10,8 +10,8 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray, - TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, MapArray, + TensorArray, TimeArray, TimestampArray, }, DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, Int64Array, TimeUnit, UInt64Array, Utf8Array, @@ -1727,11 +1727,22 @@ impl ListArray { } } } + DataType::Map(..) 
=> Ok(MapArray::new( + Field::new(self.name(), dtype.clone()), + self.clone(), + ) + .into_series()), _ => unimplemented!("List casting not implemented for dtype: {}", dtype), } } } +impl MapArray { + pub fn cast(&self, dtype: &DataType) -> DaftResult { + self.physical.cast(dtype) + } +} + impl StructArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match (self.data_type(), dtype) { diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index f22f81e2ca..8b03781d70 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -30,7 +30,13 @@ where { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { let data_array_field = Arc::new(Field::new(field.name.clone(), field.dtype.to_physical())); - let physical_arrow_arr = arrow_arr.to_type(data_array_field.dtype.to_arrow()?); + let physical_arrow_arr = match field.dtype { + // TODO: Consolidate Map to use the same .to_type conversion as other logical types + // Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent, + // but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray + DataType::Map(..) 
=> arrow_arr, + _ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?), + }; let physical = ::ArrayType::from_arrow( data_array_field, physical_arrow_arr, @@ -64,13 +70,26 @@ impl FromArrow for FixedSizeListArray { impl FromArrow for ListArray { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { match (&field.dtype, arrow_arr.data_type()) { - (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::List(arrow_child_field)) | - (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::LargeList(arrow_child_field)) - => { - let arrow_arr = arrow_arr.to_type(arrow2::datatypes::DataType::LargeList(arrow_child_field.clone())); - let arrow_arr = arrow_arr.as_any().downcast_ref::>().unwrap(); + ( + DataType::List(daft_child_dtype), + arrow2::datatypes::DataType::List(arrow_child_field), + ) + | ( + DataType::List(daft_child_dtype), + arrow2::datatypes::DataType::LargeList(arrow_child_field), + ) => { + let arrow_arr = arrow_arr.to_type(arrow2::datatypes::DataType::LargeList( + arrow_child_field.clone(), + )); + let arrow_arr = arrow_arr + .as_any() + .downcast_ref::>() + .unwrap(); let arrow_child_array = arrow_arr.values(); - let child_series = Series::from_arrow(Arc::new(Field::new("list", daft_child_dtype.as_ref().clone())), arrow_child_array.clone())?; + let child_series = Series::from_arrow( + Arc::new(Field::new("list", daft_child_dtype.as_ref().clone())), + arrow_child_array.clone(), + )?; Ok(ListArray::new( field.clone(), child_series, @@ -78,7 +97,27 @@ impl FromArrow for ListArray { arrow_arr.validity().cloned(), )) } - (d, a) => Err(DaftError::TypeError(format!("Attempting to create Daft FixedSizeListArray with type {} from arrow array with type {:?}", d, a))) + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => { + let map_arr = arrow_arr + .as_any() + .downcast_ref::() + .unwrap(); + let arrow_child_array = map_arr.field(); + let child_series = Series::from_arrow( + Arc::new(Field::new("map", 
daft_child_dtype.as_ref().clone())), + arrow_child_array.clone(), + )?; + Ok(ListArray::new( + field.clone(), + child_series, + map_arr.offsets().try_into().unwrap(), + arrow_arr.validity().cloned(), + )) + } + (d, a) => Err(DaftError::TypeError(format!( + "Attempting to create Daft ListArray with type {} from arrow array with type {:?}", + d, a + ))), } } } diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 6f972325ef..5776d19b31 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -2,7 +2,8 @@ use crate::{ array::{DataArray, FixedSizeListArray, ListArray}, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, TimeArray, TimestampArray, + DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, MapArray, TimeArray, + TimestampArray, }, BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, NullArray, Utf8Array, @@ -159,6 +160,13 @@ impl ListArray { } } +impl MapArray { + #[inline] + pub fn get(&self, idx: usize) -> Option { + self.physical.get(idx) + } +} + #[cfg(test)] mod tests { use common_error::DaftResult; diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 5ef14773e6..4c447ad3a5 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -5,7 +5,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, ImageFormat, NullArray, UInt64Array, Utf8Array, @@ -227,6 +227,16 @@ impl FixedSizeListArray { } } +impl MapArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let val = self.get(idx); + match val { + None => Ok("None".to_string()), + Some(v) => 
series_as_list_str(&v), + } + } +} + impl EmbeddingArray { pub fn str_value(&self, idx: usize) -> DaftResult { if self.physical.is_valid(idx) { @@ -338,6 +348,7 @@ impl_array_html_value!(NullArray); impl_array_html_value!(BinaryArray); impl_array_html_value!(ListArray); impl_array_html_value!(FixedSizeListArray); +impl_array_html_value!(MapArray); impl_array_html_value!(StructArray); impl_array_html_value!(ExtensionArray); impl_array_html_value!(Decimal128Array); diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index f1813c7a4c..ae3bcf012e 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -3,7 +3,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, Float32Array, Float64Array, NullArray, Utf8Array, @@ -574,6 +574,12 @@ impl ListArray { } } +impl MapArray { + pub fn sort(&self, _descending: bool) -> DaftResult { + todo!("impl sort for MapArray") + } +} + impl StructArray { pub fn sort(&self, _descending: bool) -> DaftResult { todo!("impl sort for StructArray") diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index e393d4342e..880fb2afeb 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -6,7 +6,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, NullArray, Utf8Array, @@ -78,6 +78,7 @@ 
impl_logicalarray_take!(ImageArray); impl_logicalarray_take!(FixedShapeImageArray); impl_logicalarray_take!(TensorArray); impl_logicalarray_take!(FixedShapeTensorArray); +impl_logicalarray_take!(MapArray); #[cfg(feature = "python")] impl crate::datatypes::PythonArray { diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 9e9815e8cb..04e28a5c9c 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -76,6 +76,8 @@ pub enum DataType { List(Box), /// A nested [`DataType`] with a given number of [`Field`]s. Struct(Vec), + /// A nested [`DataType`] that is represented as List>. + Map(Box), /// Extension type. Extension(String, Box, Option), // Stop ArrowTypes @@ -156,6 +158,14 @@ impl DataType { DataType::List(field) => Ok(ArrowType::LargeList(Box::new( arrow2::datatypes::Field::new("item", field.to_arrow()?, true), ))), + DataType::Map(field) => Ok(ArrowType::Map( + Box::new(arrow2::datatypes::Field::new( + "item", + field.to_arrow()?, + true, + )), + false, + )), DataType::Struct(fields) => Ok({ let fields = fields .iter() @@ -201,6 +211,7 @@ impl DataType { FixedSizeList(child_dtype, size) => { FixedSizeList(Box::new(child_dtype.to_physical()), *size) } + Map(child_dtype) => List(Box::new(child_dtype.to_physical())), Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size), Image(mode) => Struct(vec![ Field::new( @@ -310,6 +321,11 @@ impl DataType { matches!(self, DataType::FixedShapeImage(..)) } + #[inline] + pub fn is_map(&self) -> bool { + matches!(self, DataType::Map(..)) + } + #[inline] pub fn is_null(&self) -> bool { match self { @@ -385,6 +401,7 @@ impl DataType { | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) + | DataType::Map(..) ) } @@ -398,7 +415,10 @@ impl DataType { let p: DataType = self.to_physical(); matches!( p, - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Struct(..) + DataType::List(..) 
+ | DataType::FixedSizeList(..) + | DataType::Struct(..) + | DataType::Map(..) ) } @@ -449,6 +469,7 @@ impl From<&ArrowType> for DataType { ArrowType::FixedSizeList(field, size) => { DataType::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } + ArrowType::Map(field, ..) => DataType::Map(Box::new(field.as_ref().data_type().into())), ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); DataType::Struct(fields) @@ -493,6 +514,9 @@ impl Display for DataType { DataType::FixedSizeList(inner, size) => { write!(f, "FixedSizeList[{}; {}]", inner, size) } + DataType::Map(inner, ..) => { + write!(f, "Map[{}]", inner) + } DataType::Struct(fields) => { let fields: String = fields .iter() diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index ec00c609b7..d3a2c9306d 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -1,7 +1,7 @@ use std::{marker::PhantomData, sync::Arc}; use crate::{ - array::StructArray, + array::{ListArray, StructArray}, datatypes::{DaftLogicalType, DateType, Field}, with_match_daft_logical_primitive_types, }; @@ -9,8 +9,8 @@ use common_error::DaftResult; use super::{ DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, - FixedShapeImageType, FixedShapeTensorType, FixedSizeListArray, ImageType, TensorType, TimeType, - TimestampType, + FixedShapeImageType, FixedShapeTensorType, FixedSizeListArray, ImageType, MapType, TensorType, + TimeType, TimestampType, }; /// A LogicalArray is a wrapper on top of some underlying array, applying the semantic meaning of its @@ -62,6 +62,9 @@ impl LogicalArrayImpl { macro_rules! 
impl_logical_type { ($physical_array_type:ident) => { + // Clippy triggers false positives here for the MapArray implementation + // This is added to suppress the warning + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { self.physical.len() } @@ -144,6 +147,20 @@ impl LogicalArrayImpl { } } +impl MapArray { + impl_logical_type!(ListArray); + + pub fn to_arrow(&self) -> Box { + let arrow_dtype = self.data_type().to_arrow().unwrap(); + Box::new(arrow2::array::MapArray::new( + arrow_dtype, + self.physical.offsets().try_into().unwrap(), + self.physical.flat_child.to_arrow(), + self.physical.validity().cloned(), + )) + } +} + pub type LogicalArray = LogicalArrayImpl::PhysicalType as DaftDataType>::ArrayType>; pub type Decimal128Array = LogicalArray; @@ -156,7 +173,7 @@ pub type TensorArray = LogicalArray; pub type EmbeddingArray = LogicalArray; pub type FixedShapeTensorArray = LogicalArray; pub type FixedShapeImageArray = LogicalArray; - +pub type MapArray = LogicalArray; pub trait DaftImageryType: DaftLogicalType {} impl DaftImageryType for ImageType {} diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index 42778f8e85..ab1817528d 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -30,6 +30,7 @@ macro_rules! with_match_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, + Map(_) => __with_ty__! { MapType }, Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 1b91b499ca..68d9c55464 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -132,6 +132,26 @@ macro_rules! impl_daft_logical_fixed_size_list_datatype { }; } +macro_rules! 
impl_daft_logical_list_datatype { + ($ca:ident, $variant:ident) => { + #[derive(Clone, Debug)] + pub struct $ca {} + + impl DaftDataType for $ca { + #[inline] + fn get_dtype() -> DataType { + DataType::$variant + } + + type ArrayType = logical::LogicalArray<$ca>; + } + + impl DaftLogicalType for $ca { + type PhysicalType = ListType; + } + }; +} + macro_rules! impl_nested_datatype { ($ca:ident, $array_type:ident) => { #[derive(Clone, Debug)] @@ -181,6 +201,7 @@ impl_daft_logical_data_array_datatype!(TensorType, Unknown, StructType); impl_daft_logical_fixed_size_list_datatype!(EmbeddingType, Unknown); impl_daft_logical_fixed_size_list_datatype!(FixedShapeImageType, Unknown); impl_daft_logical_fixed_size_list_datatype!(FixedShapeTensorType, Unknown); +impl_daft_logical_list_datatype!(MapType, Unknown); #[cfg(feature = "python")] impl_daft_non_arrow_datatype!(PythonType, Python); diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index aacdd647ad..4e4fa3de87 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -200,6 +200,15 @@ impl PyDataType { Ok(DataType::FixedSizeList(Box::new(data_type.dtype), usize::try_from(size)?).into()) } + #[staticmethod] + pub fn map(key_type: Self, value_type: Self) -> PyResult { + Ok(DataType::Map(Box::new(DataType::Struct(vec![ + Field::new("key", key_type.dtype), + Field::new("value", value_type.dtype), + ]))) + .into()) + } + #[staticmethod] pub fn r#struct(fields: &PyDict) -> PyResult { Ok(DataType::Struct( @@ -350,6 +359,10 @@ impl PyDataType { Ok(self.dtype.is_fixed_shape_tensor()) } + pub fn is_map(&self) -> PyResult { + Ok(self.dtype.is_map()) + } + pub fn is_logical(&self) -> PyResult { Ok(self.dtype.is_logical()) } diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index 5fd46e106a..443bdc6c0a 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ 
b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -7,7 +7,10 @@ use crate::{ ops::{DaftCompare, DaftLogical}, FixedSizeListArray, ListArray, StructArray, }, - datatypes::{logical::Decimal128Array, Int128Array}, + datatypes::{ + logical::{Decimal128Array, MapArray}, + Int128Array, + }, series::series_like::SeriesLike, with_match_comparable_daft_types, with_match_numeric_daft_types, DataType, }; @@ -226,6 +229,7 @@ impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} +impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index f8dc7a2e94..91b31a78e5 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -1,6 +1,7 @@ use crate::datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, TensorArray, TimeArray, + TimestampArray, }; use crate::datatypes::{BooleanArray, DaftLogicalType, Field}; @@ -234,3 +235,4 @@ impl_series_like_for_logical_array!(TensorArray); impl_series_like_for_logical_array!(EmbeddingArray); impl_series_like_for_logical_array!(FixedShapeImageArray); impl_series_like_for_logical_array!(FixedShapeTensorArray); +impl_series_like_for_logical_array!(MapArray); diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index 6503a6ebae..69a0ca33ec 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -10,7 +10,7 @@ use crate::{ }, datatypes::logical::{ DateArray, Decimal128Array, 
DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, + FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, with_match_daft_types, DataType, IntoSeries, Series, }; @@ -154,6 +154,14 @@ impl<'d> serde::Deserialize<'d> for Series { .unwrap() .into_series()) } + Map(..) => { + let physical = map.next_value::()?; + Ok(MapArray::new( + Arc::new(field), + physical.downcast::().unwrap().clone(), + ) + .into_series()) + } Struct(..) => { let mut all_series = map.next_value::>>()?; let validity = all_series diff --git a/src/daft-core/src/utils/arrow.rs b/src/daft-core/src/utils/arrow.rs index 60c11459f8..a18b387994 100644 --- a/src/daft-core/src/utils/arrow.rs +++ b/src/daft-core/src/utils/arrow.rs @@ -43,6 +43,20 @@ fn coerce_to_daft_compatible_type( .with_metadata(field.metadata.clone()), ))) } + arrow2::datatypes::DataType::Map(field, sorted) => { + let new_field = match coerce_to_daft_compatible_type(field.data_type()) { + Some(new_inner_dtype) => Box::new( + arrow2::datatypes::Field::new( + field.name.clone(), + new_inner_dtype, + field.is_nullable, + ) + .with_metadata(field.metadata.clone()), + ), + None => field.clone(), + }; + Some(arrow2::datatypes::DataType::Map(new_field, *sorted)) + } arrow2::datatypes::DataType::FixedSizeList(field, size) => { let new_inner_dtype = coerce_to_daft_compatible_type(field.data_type())?; Some(arrow2::datatypes::DataType::FixedSizeList( @@ -95,15 +109,33 @@ pub fn cast_array_for_daft_if_needed( arrow_array: Box, ) -> Box { match coerce_to_daft_compatible_type(arrow_array.data_type()) { - Some(coerced_dtype) => cast::cast( - arrow_array.as_ref(), - &coerced_dtype, - cast::CastOptions { - wrapped: true, - partial: false, - }, - ) - .unwrap(), + Some(coerced_dtype) => match coerced_dtype { + // TODO: Consolidate Map to use the same cast::cast method as other datatypes. 
+ // Currently, Map is not supported in Arrow2::compute::cast, so this workaround is necessary. + // A known limitation of this workaround is that it does not handle nested maps. + arrow2::datatypes::DataType::Map(to_field, sorted) => { + let map_array = arrow_array + .as_any() + .downcast_ref::() + .unwrap(); + let casted = cast_array_for_daft_if_needed(map_array.field().clone()); + Box::new(arrow2::array::MapArray::new( + arrow2::datatypes::DataType::Map(to_field.clone(), sorted), + map_array.offsets().clone(), + casted, + arrow_array.validity().cloned(), + )) + } + _ => cast::cast( + arrow_array.as_ref(), + &coerced_dtype, + cast::CastOptions { + wrapped: true, + partial: false, + }, + ) + .unwrap(), + }, None => arrow_array, } } diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs index 4ffec69b2c..52e426a779 100644 --- a/src/daft-core/src/utils/supertype.rs +++ b/src/daft-core/src/utils/supertype.rs @@ -195,6 +195,11 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { let inner_st = get_supertype(inner_left_dtype.as_ref(), inner_right_dtype.as_ref())?; Some(DataType::List(Box::new(inner_st))) } + // TODO(Colin): Add support for getting supertype for two maps once StructArray supports such a cast. + // (Map(inner_left_dtype), Map(inner_right_dtype)) => { + // let inner_st = get_supertype(inner_left_dtype.as_ref(), inner_right_dtype.as_ref())?; + // Some(DataType::Map(Box::new(inner_st))) + // } // TODO(Clark): Add support for getting supertype for two fixed size lists once Arrow2 supports such a cast. 
// (FixedSizeList(inner_left_field, inner_left_size), FixedSizeList(inner_right_field, inner_right_size)) if inner_left_size == inner_right_size => { // let inner_st = inner(&inner_left_field.dtype, &inner_right_field.dtype)?; diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index ed2d2dfa72..6da6bce270 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -75,7 +75,7 @@ impl ColumnRangeStatistics { // UNSUPPORTED TYPES: // Types that don't support comparisons and can't be used as ColumnRangeStatistics - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, + DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) 
| DataType::Unknown => false, #[cfg(feature = "python")] DataType::Python => false, } diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 52ab451824..b330925848 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -350,6 +350,27 @@ def test_create_dataframe_pandas_tensor(valid_data: list[dict[str, float]]) -> N id="numpy_1d_arrays", ), pytest.param(pa.array([[1, 2, 3], [1, 2], [1]]), DataType.list(DataType.int64()), id="pa_nested"), + pytest.param( + pa.array([[("a", 1), ("b", 2)], [("c", 3), ("d", 4)]], type=pa.map_(pa.string(), pa.int64())), + DataType.map(DataType.string(), DataType.int64()), + id="pa_map", + ), + # TODO(Colin): Enable this test once cast_array_for_daft_if_needed in src/daft-core/src/utils/arrow.rs supports nested maps + # pytest.param( + # pa.array( + # [{"a": {"b": 1}, "c": {"d": 2}}, {"e": {"f": 3}, "g": {"h": 4}}], + # type=pa.map_(pa.string(), pa.map_(pa.string(), pa.int64())), + # ), + # DataType.map( + # DataType.struct( + # { + # "key": DataType.string(), + # "value": DataType.map(DataType.struct({"key": DataType.string(), "value": DataType.int64()})), + # } + # ) + # ), + # id="pa_nested_map", + # ), pytest.param( pa.chunked_array([pa.array([[1, 2, 3], [1, 2], [1]])]), DataType.list(DataType.int64()), diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py index 4b7692e3c4..4e106c47a8 100644 --- a/tests/integration/iceberg/test_table_load.py +++ b/tests/integration/iceberg/test_table_load.py @@ -21,7 +21,7 @@ def test_daft_iceberg_table_open(local_iceberg_tables): WORKING_SHOW_COLLECT = [ - # "test_all_types", # ValueError: DaftError::ArrowError Not yet implemented: Deserializing type Decimal(10, 2) from parquet + "test_all_types", "test_limit", "test_null_nan", "test_null_nan_rewritten", diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py index a5e7e6ee43..8f25a187ed 100644 --- 
a/tests/io/delta_lake/conftest.py +++ b/tests/io/delta_lake/conftest.py @@ -58,11 +58,10 @@ def base_table() -> pa.Table: ), "i": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "j": [{"x": 1, "y": False}, {"y": True, "z": "foo"}, {"x": 5, "z": "bar"}], - # TODO(Clark): Uncomment test case when MapArray support is merged. - # "k": pa.array( - # [[("x", 1), ("y", 0)], [("a", 2), ("b", 45)], [("c", 4), ("d", 18)]], - # type=pa.map_(pa.string(), pa.int64()), - # ), + "k": pa.array( + [[("x", 1), ("y", 0)], [("a", 2), ("b", 45)], [("c", 4), ("d", 18)]], + type=pa.map_(pa.string(), pa.int64()), + ), # TODO(Clark): Wait for more temporal type support in Delta Lake. # "l": [ # datetime.time(hour=1, minute=2, second=4, microsecond=5), diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py index 39df7581ab..7d6aee6a4b 100644 --- a/tests/io/test_parquet.py +++ b/tests/io/test_parquet.py @@ -118,12 +118,14 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema) "struct_nested_timestamp": pa.array( [{"foo": [dt]} for _ in range(3)], type=pa.struct({"foo": pa.list_(pa.timestamp("ns"))}) ), + "map_timestamp": pa.array([[("foo", dt)] for _ in range(3)], type=pa.map_(pa.string(), pa.timestamp("ns"))), } schema = [ ("timestamp", DataType.timestamp(coerce_to)), ("nested_timestamp", DataType.list(DataType.timestamp(coerce_to))), ("struct_timestamp", DataType.struct({"foo": DataType.timestamp(coerce_to)})), ("struct_nested_timestamp", DataType.struct({"foo": DataType.list(DataType.timestamp(coerce_to))})), + ("map_timestamp", DataType.map(DataType.string(), DataType.timestamp(coerce_to))), ] expected = Schema._from_field_name_and_types(schema) diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 79050f3fcc..81569f1be2 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -51,6 +51,11 @@ # TODO: Crashes when parsing fixed size lists # ([[1, 2, 3], [4, 5, 6], None], pa.list_(pa.int64(), 
list_size=3), DataType.fixed_size_list(DataType.int64(), 3)), ([{"bar": 1}, {"bar": None}, None], pa.struct({"bar": pa.int64()}), DataType.struct({"bar": DataType.int64()})), + ( + [[("a", 1), ("b", 2)], [], None], + pa.map_(pa.large_string(), pa.int64()), + DataType.map(DataType.string(), DataType.int64()), + ), ], ) def test_roundtrip_simple_arrow_types(tmp_path, data, pa_type, expected_dtype): diff --git a/tests/series/test_concat.py b/tests/series/test_concat.py index 6599097d5e..f04e00b187 100644 --- a/tests/series/test_concat.py +++ b/tests/series/test_concat.py @@ -87,6 +87,31 @@ def test_series_concat_list_array(chunks, fixed) -> None: counter += 1 +@pytest.mark.parametrize("chunks", [1, 2, 3, 10]) +def test_series_concat_map_array(chunks) -> None: + series = [] + for i in range(chunks): + series.append( + Series.from_arrow( + pa.array( + [[("a", i + j), ("b", float(i * j))] for j in range(i)], + type=pa.map_(pa.string(), pa.float64()), + ) + ) + ) + + concated = Series.concat(series) + + assert concated.datatype() == DataType.map(DataType.string(), DataType.float64()) + concated_list = concated.to_pylist() + counter = 0 + for i in range(chunks): + for j in range(i): + assert concated_list[counter][0][1] == i + j + assert concated_list[counter][1][1] == float(i * j) + counter += 1 + + @pytest.mark.parametrize("chunks", [1, 2, 3, 10]) def test_series_concat_struct_array(chunks) -> None: series = [] diff --git a/tests/series/test_if_else.py b/tests/series/test_if_else.py index 591cd449bf..482fc3a841 100644 --- a/tests/series/test_if_else.py +++ b/tests/series/test_if_else.py @@ -231,6 +231,70 @@ def test_series_if_else_fixed_size_list(if_true, if_false, expected) -> None: assert result.to_pylist() == expected +@pytest.mark.parametrize( + ["if_true", "if_false", "expected"], + [ + # Same length, same type + ( + pa.array( + [[("a", 1), ("b", 2)], [("b", 3), ("c", 4)], None, [("a", 5), ("c", 7)]], + type=pa.map_(pa.string(), pa.int64()), + ), + pa.array( + 
[[("a", 8), ("b", 9)], [("c", 10)], None, [("a", 12), ("b", 13)]], + type=pa.map_(pa.string(), pa.int64()), + ), + [[("a", 1), ("b", 2)], [("c", 10)], None, [("a", 5), ("c", 7)]], + ), + # TODO(Colin): Uncomment this case when StructArrays are supported. + # Same length, different super-castable data type + # ( + # pa.array( + # [[("a", 1), ("b", 2)], [("b", 3), ("c", 4)], None, [("a", 5), ("c", 7)]], + # type=pa.map_(pa.string(), pa.int64()), + # ), + # pa.array( + # [[("a", 8), ("b", 9)], [("c", 10)], None, [("a", 12), ("b", 13)]], + # type=pa.map_(pa.string(), pa.int64()), + # ), + # [[("a", 1), ("b", 2)], [("c", 10)], None, [("a", 5), ("c", 7)]], + # ), + # ), + # Broadcast left + ( + pa.array([[("a", 1), ("b", 2)]], type=pa.map_(pa.string(), pa.int64())), + pa.array( + [[("a", 8), ("b", 9)], [("c", 10)], None, [("a", 12), ("b", 13)]], + type=pa.map_(pa.string(), pa.int64()), + ), + [[("a", 1), ("b", 2)], [("c", 10)], None, [("a", 1), ("b", 2)]], + ), + # Broadcast right + ( + pa.array( + [[("a", 1), ("b", 2)], [("b", 3), ("c", 4)], None, [("a", 5), ("c", 7)]], + type=pa.map_(pa.string(), pa.int64()), + ), + pa.array([[("a", 8), ("b", 9)]], type=pa.map_(pa.string(), pa.int64())), + [[("a", 1), ("b", 2)], [("a", 8), ("b", 9)], None, [("a", 5), ("c", 7)]], + ), + # Broadcast both + ( + pa.array([[("a", 1), ("b", 2)]], type=pa.map_(pa.string(), pa.int64())), + pa.array([[("a", 8), ("b", 9)]], type=pa.map_(pa.string(), pa.int64())), + [[("a", 1), ("b", 2)], [("a", 8), ("b", 9)], None, [("a", 1), ("b", 2)]], + ), + ], +) +def test_series_if_else_map(if_true, if_false, expected) -> None: + if_true_series = Series.from_arrow(if_true) + if_false_series = Series.from_arrow(if_false) + predicate_series = Series.from_arrow(pa.array([True, False, None, True])) + result = predicate_series.if_else(if_true_series, if_false_series) + assert result.datatype() == DataType.map(DataType.string(), DataType.int64()) + assert result.to_pylist() == expected + + 
@pytest.mark.parametrize(
     ["if_true", "if_false", "expected"],
     [
diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py
index ec89cd87be..95012309be 100644
--- a/tests/table/test_from_py.py
+++ b/tests/table/test_from_py.py
@@ -98,6 +98,10 @@
     "time64_nanoseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("ns")),
     "list": pa.array(PYTHON_TYPE_ARRAYS["list"], pa.list_(pa.int64())),
     "fixed_size_list": pa.array([[1, 2], [3, 4]], pa.list_(pa.int64(), 2)),
+    "map": pa.array(
+        [[(1.0, 1), (2.0, 2)], [(3.0, 3), (4.0, 4)]],
+        pa.map_(pa.float32(), pa.int32()),
+    ),
     "struct": pa.array(PYTHON_TYPE_ARRAYS["struct"], pa.struct([("a", pa.int64()), ("b", pa.float64())])),
     "empty_struct": pa.array(PYTHON_TYPE_ARRAYS["empty_struct"], pa.struct([])),
     "nested_struct": pa.array(
@@ -149,6 +153,7 @@
     "time64_nanoseconds": pa.time64("ns"),
     "list": pa.large_list(pa.int64()),
     "fixed_size_list": pa.list_(pa.int64(), 2),
+    "map": pa.map_(pa.float32(), pa.int32()),
     "struct": pa.struct([("a", pa.int64()), ("b", pa.float64())]),
     "empty_struct": pa.struct([]),
     "nested_struct": pa.struct([("a", pa.struct([("b", pa.int64())])), ("c", pa.struct([]))]),
@@ -285,6 +290,17 @@ def test_from_pydict_arrow_fixed_size_list_array() -> None:
     assert daft_table.to_arrow()["a"].combine_chunks() == expected
 
 
+def test_from_pydict_arrow_map_array() -> None:
+    data = [[(1, 2.0), (3, 4.0)], None, [(5, 6.0), (7, 8.0)]]
+    arrow_arr = pa.array(data, pa.map_(pa.int64(), pa.float64()))
+    daft_table = MicroPartition.from_pydict({"a": arrow_arr})
+    assert "a" in daft_table.column_names()
+    # Perform expected Daft cast; the int64 keys and float64 values need no large-type coercion, so the map type is unchanged.
+ expected = arrow_arr.cast(pa.map_(pa.int64(), pa.float64())) + assert daft_table.to_arrow()["a"].combine_chunks() == expected + assert daft_table.to_pydict()["a"] == data + + def test_from_pydict_arrow_struct_array() -> None: data = [{"a": "foo", "b": "bar"}, {"b": "baz", "c": "quux"}] arrow_arr = pa.array(data) @@ -342,6 +358,10 @@ def test_from_pydict_arrow_deeply_nested() -> None: pa.array([{"a": 1, "b": 2}, {"b": 3, "c": 4}, None, {"a": 5, "c": 6}]), pa.struct([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())]), ), + ( + pa.array([[(1, 2), (3, 4)], None, [(5, 6), (7, 8)]], pa.map_(pa.int64(), pa.int64())), + pa.map_(pa.int64(), pa.int64()), + ), ], ) @pytest.mark.parametrize("chunked", [False, True]) @@ -368,6 +388,10 @@ def test_from_pydict_arrow_with_nulls_roundtrip(data, out_dtype, chunked) -> Non pa.array([{"a": 1, "b": 2}, {"b": 3, "c": 4}, {"a": 5}, {"a": 6, "c": 7}]), pa.struct([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())]), ), + ( + pa.array([[(1, 2), (3, 4)], [(5, 6)], [(7, 8)]], pa.map_(pa.int64(), pa.int64())), + pa.map_(pa.int64(), pa.int64()), + ), # Contains nulls. 
(pa.array([1, 2, None, 4], type=pa.int64()), pa.int64()),
         (pa.array(["a", "b", None, "d"], type=pa.string()), pa.large_string()),
@@ -378,6 +402,10 @@ def test_from_pydict_arrow_with_nulls_roundtrip(data, out_dtype, chunked) -> Non
             pa.array([{"a": 1, "b": 2}, {"b": 3, "c": 4}, None, {"a": 5, "c": 6}]),
             pa.struct([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())]),
         ),
+        (
+            pa.array([[(1, 2), (3, 4)], None, [(7, 8)]], pa.map_(pa.int64(), pa.int64())),
+            pa.map_(pa.int64(), pa.int64()),
+        ),
     ],
 )
 @pytest.mark.parametrize("chunked", [False, True])
@@ -414,6 +442,10 @@ def test_from_pydict_series() -> None:
             pa.array([{"a": 1, "b": 2}, {"b": 3, "c": 4}, {"a": 5}, {"a": 6, "c": 7}]),
             pa.struct([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())]),
         ),
+        (
+            pa.array([[(1, 2), (3, 4)], [(5, 6)], [(7, 8)]], pa.map_(pa.int64(), pa.int64())),
+            pa.map_(pa.int64(), pa.int64()),
+        ),
         # Contains nulls.
         (pa.array([1, 2, None, 4], type=pa.int64()), pa.int64()),
         (pa.array(["a", "b", None, "d"], type=pa.string()), pa.large_string()),
@@ -424,6 +456,10 @@ def test_from_pydict_series() -> None:
             pa.array([{"a": 1, "b": 2}, {"b": 3, "c": 4}, None, {"a": 5, "c": 6}]),
             pa.struct([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())]),
         ),
+        (
+            pa.array([[(1, 2), (3, 4)], None, [(7, 8)]], pa.map_(pa.int64(), pa.int64())),
+            pa.map_(pa.int64(), pa.int64()),
+        ),
     ],
 )
 @pytest.mark.parametrize("slice_", list(itertools.combinations(range(4), 2)))
@@ -470,6 +506,17 @@ def test_from_arrow_struct_array() -> None:
     assert daft_table.to_arrow()["a"].combine_chunks() == expected
 
 
+def test_from_arrow_map_array() -> None:
+    data = [[(1.0, 1), (2.0, 2)], [(3.0, 3), (4.0, 4)]]
+    arrow_arr = pa.array(data, pa.map_(pa.float32(), pa.int32()))
+    daft_table = MicroPartition.from_arrow(pa.table({"a": arrow_arr}))
+    assert "a" in daft_table.column_names()
+    # Perform expected Daft cast; the float32 keys and int32 values need no large-type coercion, so the map type is unchanged.
+ expected = arrow_arr.cast(pa.map_(pa.float32(), pa.int32())) + assert daft_table.to_arrow()["a"].combine_chunks() == expected + assert daft_table.to_pydict()["a"] == data + + @pytest.mark.skipif( get_context().runner_config.name == "ray", reason="pyarrow extension types aren't supported on Ray clusters.", diff --git a/tests/test_schema.py b/tests/test_schema.py index 84402f6a2b..d29ba69684 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -165,6 +165,7 @@ def test_schema_from_pyarrow(): "int": pa.int64(), "str": pa.string(), "list": pa.list_(pa.int64()), + "map": pa.map_(pa.string(), pa.int64()), } ) @@ -173,6 +174,7 @@ def test_schema_from_pyarrow(): ("int", DataType.int64()), ("str", DataType.string()), ("list", DataType.list(DataType.int64())), + ("map", DataType.map(DataType.string(), DataType.int64())), ] )