From 4a60880ad7f81bf50dded9c516db65aef21aa8e0 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Mon, 26 Aug 2024 14:42:25 -0500 Subject: [PATCH] fix: fsl to list with validity --- src/daft-core/src/array/ops/cast.rs | 11 +---------- src/daft-core/src/array/ops/get.rs | 30 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 8e0919d39a..e2d2110456 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1631,16 +1631,7 @@ impl FixedSizeListArray { DataType::List(child_dtype) => { let element_size = self.fixed_element_len(); let casted_child = self.flat_child.cast(child_dtype.as_ref())?; - let offsets: Offsets = match self.validity() { - None => Offsets::try_from_iter(repeat(element_size).take(self.len()))?, - Some(validity) => Offsets::try_from_iter(validity.iter().map(|v| { - if v { - element_size - } else { - 0 - } - }))?, - }; + let offsets = Offsets::try_from_iter(repeat(element_size).take(self.len()))?; Ok(ListArray::new( Field::new(self.name().to_string(), dtype.clone()), casted_child, diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 1868669d0e..6a78b57e67 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -248,4 +248,34 @@ mod tests { Ok(()) } + + #[test] + fn test_list_get_some_valid() -> DaftResult<()> { + let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int32), 3)); + let flat_child = Int32Array::from(("foo", (0..9).collect::>())); + let raw_validity = vec![true, false, true]; + let validity = Some(arrow2::bitmap::Bitmap::from(raw_validity.as_slice())); + let arr = FixedSizeListArray::new(field, flat_child.into_series(), validity); + let list_dtype = DataType::List(Box::new(DataType::Int32)); + let list_arr = arr.cast(&list_dtype)?; + let l = list_arr.list()?; + let element = l.get(0).unwrap(); + let element = element.i32()?; + let data = element + .into_iter() + .map(|x| x.copied()) + .collect::>>(); + let expected = vec![Some(0), Some(1), Some(2)]; + assert_eq!(data, expected); + let element = l.get(2).unwrap(); + let element = element.i32()?; + let data = element + .into_iter() + .map(|x| x.copied()) + .collect::>>(); + let expected = vec![Some(6), Some(7), Some(8)]; + assert_eq!(data, expected); + + Ok(()) + } }