From 926cc21c8e63573bdad03cc46ce5a03d13446f1f Mon Sep 17 00:00:00 2001 From: Sagi Ahrac Date: Mon, 7 Oct 2024 12:55:26 +0300 Subject: [PATCH] cast logic --- src/daft-core/src/array/ops/cast.rs | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index c3dbe0c209..0c35a44a5b 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1698,6 +1698,43 @@ impl SparseTensorArray { ); Ok(sparse_tensor_array.into_series()) } + DataType::Python => Python::with_gil(|py| { + let mut pydicts: Vec<Py<PyAny>> = Vec::with_capacity(self.len()); + let sa = self.shape_array(); + let va = self.values_array(); + let ia = self.indices_array(); + let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?; + for ((shape_array, values_array), indices_array) in sa + .into_iter() + .zip(va.into_iter()) + .zip(ia.into_iter()) + { + if let (Some(shape_array), Some(values_array), Some(indices_array)) = + (shape_array, values_array, indices_array) + { + let shape_array = shape_array.u64().unwrap().as_arrow(); + let shape = shape_array.values().to_vec(); + let py_values_array = ffi::to_py_array(py, values_array.to_arrow(), &pyarrow)? + .call_method1(pyo3::intern!(py, "to_numpy"), (false,))?; + let py_indices_array = ffi::to_py_array(py, indices_array.to_arrow(), &pyarrow)? + .call_method1(pyo3::intern!(py, "to_numpy"), (false,))?; + let pydict = pyo3::types::PyDict::new_bound(py); + pydict.set_item("values", py_values_array)?; + pydict.set_item("indices", py_indices_array)?; + pydict.set_item("shape", shape)?; + pydicts.push(pydict.unbind().into()); + } else { + pydicts.push(py.None()); + } + } + let py_objects_array = + PseudoArrowArray::new(pydicts.into(), self.physical.validity().cloned()); + Ok(PythonArray::new( + Field::new(self.name(), dtype.clone()).into(), + py_objects_array.to_boxed(), + )? 
+ .into_series()) + }), _ => self.physical.cast(dtype), } } @@ -1793,6 +1830,14 @@ impl FixedShapeSparseTensorArray { FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), physical); Ok(fixed_shape_tensor_array.into_series()) } + ( + DataType::Python, + DataType::FixedShapeSparseTensor(inner_dtype, _), + ) => { + let sparse_tensor_series = self.cast(&DataType::SparseTensor(inner_dtype.clone()))?; + let sparse_pytensor_series = sparse_tensor_series.cast(&DataType::Python)?; + Ok(sparse_pytensor_series) + } (_, _) => self.physical.cast(dtype), } }