Commit e98a1f4: pr feedback on impl
wjones127 committed Oct 10, 2023
1 parent 42d46f1 commit e98a1f4
Showing 6 changed files with 136 additions and 26 deletions.
25 changes: 20 additions & 5 deletions python/pyarrow/array.pxi
@@ -252,7 +252,12 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
else:
requested_type = None
schema_capsule, array_capsule = obj.__arrow_c_array__(requested_type)
return Array._import_from_c_capsule(schema_capsule, array_capsule)
out_array = Array._import_from_c_capsule(schema_capsule, array_capsule)
if type is not None and out_array.type != type:
# PyCapsule interface type coercion is best effort, so we need to
# check the type of the returned array and cast if necessary
out_array = out_array.cast(type, safe=safe, memory_pool=memory_pool)
return out_array
elif _is_array_like(obj):
if mask is not None:
if _is_array_like(mask):
@@ -1730,8 +1735,9 @@ cdef class Array(_PandasConvertible):
respectively.
"""
cdef:
ArrowArray* c_array = <ArrowArray*> malloc(sizeof(ArrowArray))
ArrowSchema* c_schema = <ArrowSchema*> malloc(sizeof(ArrowSchema))
ArrowArray* c_array
ArrowSchema* c_schema
shared_ptr[CArray] inner_array

if requested_schema is not None:
target_type = DataType._import_from_c_capsule(requested_schema)
@@ -1750,8 +1756,17 @@
else:
inner_array = self.sp_array

with nogil:
check_status(ExportArray(deref(inner_array), c_array, c_schema))
c_array = <ArrowArray*> malloc(sizeof(ArrowArray))
c_schema = <ArrowSchema*> malloc(sizeof(ArrowSchema))

try:
with nogil:
check_status(ExportArray(deref(inner_array), c_array, c_schema))
except:
# Ensure we don't leak memory if the export fails.
free(c_array)
free(c_schema)
raise

schema_capsule = PyCapsule_New(c_schema, 'arrow_schema', &pycapsule_schema_deleter)
array_capsule = PyCapsule_New(c_array, 'arrow_array', &pycapsule_array_deleter)
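
The consumer-side fallback cast above can be exercised from pure Python. A minimal sketch (the ArrayWrapper class is hypothetical, written for illustration): a producer that ignores the requested type still round-trips through pa.array(), which then applies the cast itself.

import pyarrow as pa

class ArrayWrapper:
    # Hypothetical producer for illustration: it exposes its data through
    # the Arrow PyCapsule interface but ignores the requested schema,
    # which the interface permits (coercion is best effort).
    def __init__(self, array):
        self.array = array

    def __arrow_c_array__(self, requested_schema=None):
        return self.array.__arrow_c_array__()

data = pa.array([1, 2, 3], type=pa.int64())
result = pa.array(ArrayWrapper(data), type=pa.int32())
# The producer returned int64 data; pa.array() applied the fallback cast.
assert result.type == pa.int32()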
16 changes: 8 additions & 8 deletions python/pyarrow/ipc.pxi
@@ -834,7 +834,7 @@ cdef class RecordBatchReader(_Weakrefable):
A capsule containing a C ArrowArrayStream struct.
"""
cdef:
ArrowArrayStream* c_stream = <ArrowArrayStream*>malloc(sizeof(ArrowArrayStream))
ArrowArrayStream* c_stream

if requested_schema is not None:
out_schema = Schema._import_from_c_capsule(requested_schema)
@@ -844,8 +844,13 @@ cdef class RecordBatchReader(_Weakrefable):
if self.schema != out_schema:
raise NotImplementedError("Casting to requested_schema")

with nogil:
check_status(ExportRecordBatchReader(self.reader, c_stream))
c_stream = <ArrowArrayStream*>malloc(sizeof(ArrowArrayStream))
try:
with nogil:
check_status(ExportRecordBatchReader(self.reader, c_stream))
except:
free(c_stream)
raise

return PyCapsule_New(c_stream, "arrow_array_stream", &pycapsule_stream_deleter)

@@ -868,11 +873,6 @@ cdef class RecordBatchReader(_Weakrefable):
shared_ptr[CRecordBatchReader] c_reader
RecordBatchReader self

# sanity checks
if not PyCapsule_IsValid(stream, 'arrow_array_stream'):
raise ValueError(
"Not an ArrayArrayStream object"
)
c_stream = <ArrowArrayStream*>PyCapsule_GetPointer(
stream, 'arrow_array_stream'
)
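
The stream export now follows the same allocate-just-before-export pattern: the struct is only malloc'd once schema negotiation has succeeded, and freed if ExportRecordBatchReader fails. From Python, the capsule round-trip looks like this — a sketch using only the methods touched in this diff:

import pyarrow as pa

batches = [pa.record_batch([pa.array([1, 2, 3])], names=["a"])]
reader = pa.RecordBatchReader.from_batches(batches[0].schema, batches)

# Export as an ArrowArrayStream capsule, then import it back.
capsule = reader.__arrow_c_stream__()
reader2 = pa.RecordBatchReader._import_from_c_capsule(capsule)
assert reader2.read_all() == pa.Table.from_batches(batches)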
23 changes: 22 additions & 1 deletion python/pyarrow/table.pxi
@@ -5009,7 +5009,12 @@ def record_batch(data, names=None, schema=None, metadata=None):
else:
requested_schema = None
schema_capsule, array_capsule = data.__arrow_c_array__(requested_schema)
return RecordBatch._import_from_c_capsule(schema_capsule, array_capsule)
batch = RecordBatch._import_from_c_capsule(schema_capsule, array_capsule)
if schema is not None and batch.schema != schema:
# __arrow_c_array__ coerces the schema on a best-effort basis, so we might
# need to cast if the producer wasn't able to cast to the exact schema.
batch = Table.from_batches([batch]).cast(schema).to_batches()[0]
return batch
elif _pandas_api.is_data_frame(data):
return RecordBatch.from_pandas(data, schema=schema)
else:
@@ -5131,6 +5136,22 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
raise ValueError(
"The 'names' argument is not valid when passing a dictionary")
return Table.from_pydict(data, schema=schema, metadata=metadata)
elif hasattr(data, "__arrow_c_stream__"):
if schema is not None:
requested = schema.__arrow_c_schema__()
else:
requested = None
capsule = data.__arrow_c_stream__(requested)
reader = RecordBatchReader._import_from_c_capsule(capsule)
table = reader.read_all()
if schema is not None and table.schema != schema:
# __arrow_c_stream__ coerces the schema on a best-effort basis, so we might
# need to cast if the producer wasn't able to cast to the exact schema.
table = table.cast(schema)
return table
elif hasattr(data, "__arrow_c_array__"):
batch = record_batch(data, schema=schema)
return Table.from_batches([batch])
elif _pandas_api.is_data_frame(data):
if names is not None or metadata is not None:
raise ValueError(
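
The batch-cast fallback in record_batch() detours through a Table, presumably because RecordBatch offered no cast of its own at this point. A standalone sketch of that workaround:

import pyarrow as pa

batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])
target = pa.schema([pa.field("a", pa.int32())])

# Wrap the batch in a one-batch Table, cast, and unwrap again.
casted = pa.Table.from_batches([batch]).cast(target).to_batches()[0]
assert casted.schema == target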
2 changes: 1 addition & 1 deletion python/pyarrow/tests/test_cffi.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.

import ctypes
import gc

import pyarrow as pa
@@ -25,7 +26,6 @@
ffi = None

import pytest
import ctypes

try:
import pandas as pd
61 changes: 60 additions & 1 deletion python/pyarrow/tests/test_table.py
@@ -562,7 +562,7 @@ def __arrow_c_array__(self, requested_type=None):
return self.batch.__arrow_c_array__(requested_type)

data = pa.record_batch([
pa.array([1, 2, 3]), type=pa.int64()
pa.array([1, 2, 3], type=pa.int64())
], names=['a'])
wrapper = BatchWrapper(data)

@@ -581,6 +581,65 @@ def __arrow_c_array__(self, requested_type=None):
assert result == expected


def test_table_c_array_interface():
class BatchWrapper:
def __init__(self, batch):
self.batch = batch

def __arrow_c_array__(self, requested_type=None):
return self.batch.__arrow_c_array__(requested_type)

data = pa.record_batch([
pa.array([1, 2, 3], type=pa.int64())
], names=['a'])
wrapper = BatchWrapper(data)

# Can roundtrip through the wrapper.
result = pa.table(wrapper)
expected = pa.Table.from_batches([data])
assert result == expected

# Can also import with a schema that the implementer can cast to.
castable_schema = pa.schema([
pa.field('a', pa.int32())
])
result = pa.table(wrapper, schema=castable_schema)
expected = pa.table({
'a': pa.array([1, 2, 3], type=pa.int32())
})
assert result == expected


def test_table_c_stream_interface():
class StreamWrapper:
def __init__(self, batches):
self.batches = batches

def __arrow_c_stream__(self, requested_type=None):
reader = pa.RecordBatchReader.from_batches(
self.batches[0].schema, self.batches)
return reader.__arrow_c_stream__(requested_type)

data = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a'])
]
wrapper = StreamWrapper(data)

# Can roundtrip through the wrapper.
result = pa.table(wrapper)
expected = pa.Table.from_batches(data)
assert result == expected

# Passing a schema works if it matches the stream's schema
result = pa.table(wrapper, schema=data[0].schema)
assert result == expected

# If schema doesn't match, raises NotImplementedError
with pytest.raises(NotImplementedError):
pa.table(wrapper, schema=pa.schema([pa.field('a', pa.int32())]))


def test_recordbatch_itercolumns():
data = [
pa.array(range(5), type='int16'),
35 changes: 25 additions & 10 deletions python/pyarrow/types.pxi
@@ -364,9 +364,14 @@ cdef class DataType(_Weakrefable):
Unlike _export_to_c, this will not leak memory if the capsule is not used.
"""
cdef ArrowSchema* c_schema = <ArrowSchema*>malloc(sizeof(ArrowSchema))
with nogil:
check_status(ExportType(deref(self.type), c_schema))
return cpython.PyCapsule_New(c_schema, 'arrow_schema', &pycapsule_schema_deleter)
try:
with nogil:
check_status(ExportType(deref(self.type), c_schema))
except:
# Avoid memory leak in case export fails.
free(c_schema)
raise
return PyCapsule_New(c_schema, 'arrow_schema', &pycapsule_schema_deleter)

@staticmethod
def _import_from_c_capsule(schema):
@@ -383,11 +388,11 @@ cdef class DataType(_Weakrefable):
ArrowSchema* c_schema
shared_ptr[CDataType] c_type

if not cpython.PyCapsule_IsValid(schema, 'arrow_schema'):
if not PyCapsule_IsValid(schema, 'arrow_schema'):
raise TypeError(
"Not an ArrowSchema object"
)
c_schema = <ArrowSchema*> cpython.PyCapsule_GetPointer(schema, 'arrow_schema')
c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema, 'arrow_schema')

with nogil:
c_type = GetResultValue(ImportType(c_schema))
@@ -396,7 +401,7 @@


cdef void pycapsule_schema_deleter(object schema_capsule):
cdef ArrowSchema* schema = <ArrowSchema*>cpython.PyCapsule_GetPointer(
cdef ArrowSchema* schema = <ArrowSchema*>PyCapsule_GetPointer(
schema_capsule, 'arrow_schema'
)
if schema.release != NULL:
@@ -2424,8 +2429,13 @@ cdef class Field(_Weakrefable):
"""
cdef:
ArrowSchema* c_schema = <ArrowSchema*>malloc(sizeof(ArrowSchema))
with nogil:
check_status(ExportField(deref(self.field), c_schema))
try:
with nogil:
check_status(ExportField(deref(self.field), c_schema))
except:
# Avoid memory leak in case export fails.
free(c_schema)
raise
return PyCapsule_New(c_schema, 'arrow_schema', &pycapsule_schema_deleter)

@staticmethod
@@ -3246,8 +3256,13 @@ cdef class Schema(_Weakrefable):
"""
cdef:
ArrowSchema* c_schema = <ArrowSchema*>malloc(sizeof(ArrowSchema))
with nogil:
check_status(ExportSchema(deref(self.schema), c_schema))
try:
with nogil:
check_status(ExportSchema(deref(self.schema), c_schema))
except:
# Avoid memory leak in case export fails.
free(c_schema)
raise
return PyCapsule_New(c_schema, 'arrow_schema', &pycapsule_schema_deleter)

@staticmethod
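
All three exporters (DataType, Field, Schema) now share one shape: malloc, try/export, free on failure, then wrap the struct in a capsule whose deleter releases it on garbage collection. A quick round-trip sketch, assuming the exporting dunder is named __arrow_c_schema__ as elsewhere in this interface:

import pyarrow as pa

# Export a type as an ArrowSchema capsule and import it back.
t = pa.list_(pa.int32())
capsule = t.__arrow_c_schema__()
assert pa.DataType._import_from_c_capsule(capsule) == t

# Per the docstring above, dropping a capsule unused does not leak:
# the capsule deleter releases the schema for us.
del capsule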
