Skip to content

Commit

Permalink
[FEAT] Cast SparseTensor and FixedShapeSparseTensor to Python (#3010)
Browse files Browse the repository at this point in the history
Addresses: #3009

This PR enables casting of SparseTensor and FixedShapeSparseTensor to
Python, allowing iteration over Daft DataFrames with sparse tensor
columns without converting to dense formats.
  • Loading branch information
sagiahrac authored Oct 8, 2024
1 parent 648b804 commit c2397bf
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,44 @@ impl SparseTensorArray {
);
Ok(sparse_tensor_array.into_series())
}
#[cfg(feature = "python")]
DataType::Python => Python::with_gil(|py| {
let mut pydicts: Vec<Py<PyAny>> = Vec::with_capacity(self.len());
let sa = self.shape_array();
let va = self.values_array();
let ia = self.indices_array();
let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?;
for ((shape_array, values_array), indices_array) in
sa.into_iter().zip(va.into_iter()).zip(ia.into_iter())
{
if let (Some(shape_array), Some(values_array), Some(indices_array)) =
(shape_array, values_array, indices_array)
{
let shape_array = shape_array.u64().unwrap().as_arrow();
let shape = shape_array.values().to_vec();
let py_values_array =
ffi::to_py_array(py, values_array.to_arrow(), &pyarrow)?
.call_method1(pyo3::intern!(py, "to_numpy"), (false,))?;
let py_indices_array =
ffi::to_py_array(py, indices_array.to_arrow(), &pyarrow)?
.call_method1(pyo3::intern!(py, "to_numpy"), (false,))?;
let pydict = pyo3::types::PyDict::new_bound(py);
pydict.set_item("values", py_values_array)?;
pydict.set_item("indices", py_indices_array)?;
pydict.set_item("shape", shape)?;
pydicts.push(pydict.unbind().into());
} else {
pydicts.push(py.None());
}
}
let py_objects_array =
PseudoArrowArray::new(pydicts.into(), self.physical.validity().cloned());
Ok(PythonArray::new(
Field::new(self.name(), dtype.clone()).into(),
py_objects_array.to_boxed(),
)?
.into_series())
}),
_ => self.physical.cast(dtype),
}
}
Expand Down Expand Up @@ -1786,6 +1824,13 @@ impl FixedShapeSparseTensorArray {
FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), physical);
Ok(fixed_shape_tensor_array.into_series())
}
#[cfg(feature = "python")]
(DataType::Python, DataType::FixedShapeSparseTensor(inner_dtype, _)) => {
let sparse_tensor_series =
self.cast(&DataType::SparseTensor(inner_dtype.clone()))?;
let sparse_pytensor_series = sparse_tensor_series.cast(&DataType::Python)?;
Ok(sparse_pytensor_series)
}
(_, _) => self.physical.cast(dtype),
}
}
Expand Down
35 changes: 35 additions & 0 deletions tests/series/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,3 +1161,38 @@ def test_series_cast_fixed_size_list_to_list() -> None:
assert data.datatype() == DataType.fixed_size_list(DataType.int64(), 2)
casted = data.cast(DataType.list(DataType.int64()))
assert casted.to_pylist() == [[1, 2], [3, 4], [5, 6]]


### Sparse ###


def to_coo_sparse_dict(ndarray: np.ndarray) -> dict[str, np.ndarray]:
flat_array = ndarray.ravel()
indices = np.flatnonzero(flat_array).astype(np.uint64)
values = flat_array[indices]
shape = list(ndarray.shape)
return {"values": values, "indices": indices, "shape": shape}


def test_series_cast_sparse_to_python() -> None:
data = [np.zeros(shape=(1, 2), dtype=np.uint8), None, np.ones(shape=(2, 2), dtype=np.uint8)]
series = Series.from_pylist(data).cast(DataType.sparse_tensor(DataType.uint8()))
assert series.datatype() == DataType.sparse_tensor(DataType.uint8())

given = series.to_pylist()
expected = [to_coo_sparse_dict(ndarray) if ndarray is not None else None for ndarray in data]
np.testing.assert_equal(given, expected)


def test_series_cast_fixed_shape_sparse_to_python() -> None:
data = [np.zeros(shape=(2, 2), dtype=np.uint8), None, np.ones(shape=(2, 2), dtype=np.uint8)]
series = (
Series.from_pylist(data)
.cast(DataType.tensor(DataType.uint8(), shape=(2, 2))) # TODO: direct cast to fixed shape sparse
.cast(DataType.sparse_tensor(DataType.uint8(), shape=(2, 2)))
)
assert series.datatype() == DataType.sparse_tensor(DataType.uint8(), shape=(2, 2))

given = series.to_pylist()
expected = [to_coo_sparse_dict(ndarray) if ndarray is not None else None for ndarray in data]
np.testing.assert_equal(given, expected)

0 comments on commit c2397bf

Please sign in to comment.