Skip to content

Commit

Permalink
cast logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sagiahrac committed Oct 7, 2024
1 parent 396c004 commit 926cc21
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,43 @@ impl SparseTensorArray {
);
Ok(sparse_tensor_array.into_series())
}
DataType::Python => Python::with_gil(|py| {
let mut pydicts: Vec<Py<PyAny>> = Vec::with_capacity(self.len());
let sa = self.shape_array();
let va = self.values_array();
let ia = self.indices_array();
let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?;
for ((shape_array, values_array), indices_array) in sa
.into_iter()
.zip(va.into_iter())
.zip(ia.into_iter())
{
if let (Some(shape_array), Some(values_array), Some(indices_array)) =
(shape_array, values_array, indices_array)
{
let shape_array = shape_array.u64().unwrap().as_arrow();
let shape = shape_array.values().to_vec();
let py_values_array = ffi::to_py_array(py, values_array.to_arrow(), &pyarrow)?
.call_method1(pyo3::intern!(py, "to_numpy"), (false,))?;
let py_indices_array = ffi::to_py_array(py, indices_array.to_arrow(), &pyarrow)?
.call_method1(pyo3::intern!(py, "to_numpy"), (false,))?;
let pydict = pyo3::types::PyDict::new_bound(py);
pydict.set_item("values", py_values_array)?;
pydict.set_item("indices", py_indices_array)?;
pydict.set_item("shape", shape)?;
pydicts.push(pydict.unbind().into());
} else {
pydicts.push(py.None());
}
}
let py_objects_array =
PseudoArrowArray::new(pydicts.into(), self.physical.validity().cloned());
Ok(PythonArray::new(
Field::new(self.name(), dtype.clone()).into(),
py_objects_array.to_boxed(),
)?
.into_series())
}),
_ => self.physical.cast(dtype),
}
}
Expand Down Expand Up @@ -1793,6 +1830,14 @@ impl FixedShapeSparseTensorArray {
FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), physical);
Ok(fixed_shape_tensor_array.into_series())
}
(
DataType::Python,
DataType::FixedShapeSparseTensor(inner_dtype, _),
) => {
let sparse_tensor_series = self.cast(&DataType::SparseTensor(inner_dtype.clone()))?;
let sparse_pytensor_series = sparse_tensor_series.cast(&DataType::Python)?;
Ok(sparse_pytensor_series)
}
(_, _) => self.physical.cast(dtype),
}
}
Expand Down

0 comments on commit 926cc21

Please sign in to comment.