From ab1b772ab2efc3374d5de30c62e55202ee1b107b Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 11 Oct 2024 08:28:49 -0700 Subject: [PATCH] [BUG] Register super extension on to_arrow (#3030) This fixes an issue where Daft Extension types were not getting converted to PyArrow properly. @jaychia discovered this while trying to write parquet with a tensor column, where the Extension metadata for the tensor was getting dropped. A simple test to reproduce the error: ``` import daft import numpy as np from daft import Series # Create sample tensor data with some null values tensor_data = [np.array([[1, 2], [3, 4]]), None, None] # Uncomment this and it will work # from daft.datatype import _ensure_registered_super_ext_type # _ensure_registered_super_ext_type() df_original = daft.from_pydict({"tensor_col": Series.from_pylist(tensor_data)}) print(df_original.to_arrow().schema) ``` Output: ``` tensor_col: struct<data: large_list<item: int64>, shape: large_list<item: uint64>> child 0, data: large_list<item: int64> child 0, item: int64 child 1, shape: large_list<item: uint64> child 0, item: uint64 ``` It's not a tensor type! However, if you uncomment the `_ensure_registered_super_ext_type()` call, you will now see: ``` tensor_col: extension<daft.super_extension<DaftExtension>> ``` The issue here is that the `class DaftExtension(pa.ExtensionType):` is not defined and registered with PyArrow by the time the FFI conversion runs, as it is now set up lazily and must be triggered via `_ensure_registered_super_ext_type()`. This PR adds calls to that function in `to_arrow` for series and in `to_pyarrow_schema` for schema. However, I do not know if this is exhaustive, and I will give this more thought. 
@desmondcheongzx @samster25 --------- Co-authored-by: Colin Ho --- daft/logical/schema.py | 3 ++- daft/series.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 401a235d53..93005ddae7 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -8,7 +8,7 @@ from daft.daft import read_csv_schema as _read_csv_schema from daft.daft import read_json_schema as _read_json_schema from daft.daft import read_parquet_schema as _read_parquet_schema -from daft.datatype import DataType, TimeUnit +from daft.datatype import DataType, TimeUnit, _ensure_registered_super_ext_type if TYPE_CHECKING: import pyarrow as pa @@ -82,6 +82,7 @@ def to_pyarrow_schema(self) -> pa.Schema: Returns: pa.Schema: PyArrow schema that corresponds to the provided Daft schema """ + _ensure_registered_super_ext_type() return self._schema.to_pyarrow_schema() @classmethod diff --git a/daft/series.py b/daft/series.py index 5cbcfe7ba0..fd85d33f13 100644 --- a/daft/series.py +++ b/daft/series.py @@ -213,6 +213,8 @@ def to_arrow(self) -> pa.Array: """ Convert this Series to an pyarrow array. """ + _ensure_registered_super_ext_type() + dtype = self.datatype() arrow_arr = self._series.to_arrow()