Skip to content

Commit

Permalink
add python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
joellubi committed Aug 1, 2024
1 parent 5250e44 commit 86210e8
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def print_entry(label, value):
dictionary,
run_end_encoded,
fixed_shape_tensor,
bool8,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand All @@ -182,7 +183,7 @@ def print_entry(label, value):
TimestampType, Time32Type, Time64Type, DurationType,
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, FixedShapeTensorType,
RunEndEncodedType, FixedShapeTensorType, Bool8Type,
PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
Expand Down Expand Up @@ -216,7 +217,7 @@ def print_entry(label, value):
Time32Array, Time64Array, DurationArray,
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray, FixedShapeTensorArray,
RunEndEncodedArray, FixedShapeTensorArray, Bool8Array,
scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
Expand All @@ -233,7 +234,7 @@ def print_entry(label, value):
StringScalar, LargeStringScalar, StringViewScalar,
FixedSizeBinaryScalar, DictionaryScalar,
MapScalar, StructScalar, UnionScalar,
RunEndEncodedScalar, ExtensionScalar)
RunEndEncodedScalar, ExtensionScalar, Bool8Scalar)

# Buffers, allocation
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,
Expand Down
24 changes: 24 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -4447,6 +4447,30 @@ cdef class FixedShapeTensorArray(ExtensionArray):
FixedSizeListArray.from_arrays(values, shape[1:].prod())
)

cdef class Bool8Array(ExtensionArray):
"""
Concrete class for bool8 extension arrays.
Examples
--------
Define the extension type for an bool8 array
>>> import pyarrow as pa
>>> bool8_type = pa.bool8()
Create an extension array
>>> arr = [-1, 0, 1, 2, None]
>>> storage = pa.array(arr, pa.int8())
>>> pa.ExtensionArray.from_storage(bool8_type, storage)
<pyarrow.lib.Bool8Array object at ...>
[
-1,
0,
1,
2,
null
]
"""

def to_numpy(self):
return self.storage.to_numpy().view(np.bool_)

cdef dict _array_classes = {
_Type_NA: NullArray,
Expand Down
11 changes: 11 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2882,6 +2882,17 @@ cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extens
" arrow::extension::FixedShapeTensorArray"(CExtensionArray):
const CResult[shared_ptr[CTensor]] ToTensor() const

cdef extern from "arrow/extension/bool8.h" namespace "arrow::extension" nogil:
cdef cppclass CBool8Type \
" arrow::extension::Bool8Type"(CExtensionType):

@staticmethod
CResult[shared_ptr[CDataType]] Make()

cdef cppclass CBool8Array \
" arrow::extension::Bool8Array"(CExtensionArray):
pass

cdef extern from "arrow/util/compression.h" namespace "arrow" nogil:
cdef enum CCompressionType" arrow::Compression::type":
CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED"
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ cdef class FixedShapeTensorType(BaseExtensionType):
cdef:
const CFixedShapeTensorType* tensor_ext_type

cdef class Bool8Type(BaseExtensionType):
cdef:
const CBool8Type* bool8_ext_type

cdef class PyExtensionType(ExtensionType):
pass
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ cdef api object pyarrow_wrap_data_type(
return cpy_ext_type.GetInstance()
elif ext_type.extension_name() == b"arrow.fixed_shape_tensor":
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
elif ext_type.extension_name() == b"arrow.bool8":
out = Bool8Type.__new__(Bool8Type)
else:
out = BaseExtensionType.__new__(BaseExtensionType)
else:
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,10 @@ cdef class FixedShapeTensorScalar(ExtensionScalar):
ctensor = GetResultValue(c_type.MakeTensor(scalar))
return pyarrow_wrap_tensor(ctensor)

cdef class Bool8Scalar(ExtensionScalar):
"""
Concrete class for bool8 extension scalar.
"""

cdef dict _scalar_classes = {
_Type_BOOL: BooleanScalar,
Expand Down
58 changes: 58 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,3 +1661,61 @@ def test_legacy_int_type():
batch = ipc_read_batch(buf)
assert isinstance(batch.column(0).type, LegacyIntType)
assert batch.column(0) == ext_arr

def test_bool8_type(pickle_module):
bool8_type = pa.bool8()
storage_type = pa.int8()
assert bool8_type.extension_name == "arrow.bool8"
assert bool8_type.storage_type == storage_type
assert str(bool8_type) == "extension<arrow.bool8>"

assert bool8_type == bool8_type
assert bool8_type == pa.bool8()
assert bool8_type != storage_type

# Pickle roundtrip
result = pickle_module.loads(pickle_module.dumps(bool8_type))
assert result == bool8_type

# IPC roundtrip
bool8_arr_class = bool8_type.__arrow_ext_class__()
storage = pa.array([-1, 0, 1, 2, None], storage_type)
arr = pa.ExtensionArray.from_storage(bool8_type, storage)
assert isinstance(arr, bool8_arr_class)

with registered_extension_type(bool8_type):
buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"]))
batch = ipc_read_batch(buf)

assert batch.column(0).type.extension_name == "arrow.bool8"
assert isinstance(batch.column(0), bool8_arr_class)

# cast storage -> extension type
result = storage.cast(bool8_type)
assert result == arr

# cast extension type -> storage type
inner = arr.cast(storage_type)
assert inner == storage

# cast extension type -> arrow boolean type
bool_type = pa.bool_()
arrow_bool_arr = pa.array([True, False, True, True, None], bool_type)
cast_bool_arr = arr.cast(bool_type)
assert cast_bool_arr == arrow_bool_arr

# cast arrow boolean type -> extension type, expecting canonical values
cast_bool8_arr = arrow_bool_arr.cast(bool8_type)
canonical_storage = pa.array([1, 0, 1, 1, None], storage_type)
canonical_bool8_arr = pa.ExtensionArray.from_storage(bool8_type, canonical_storage)
assert cast_bool8_arr == canonical_bool8_arr

# zero-copy convert to numpy if non-null
with pytest.raises(pa.ArrowInvalid, match="Needed to copy 1 chunks with 1 nulls, but zero_copy_only was True"):
arr.to_numpy()

arr_np_bool = np.array([True, False, True, True], dtype=np.bool_)
arr_no_nulls = pa.ExtensionArray.from_storage(bool8_type, pa.array([-1, 0, 1, 2], storage_type))
arr_to_np = arr_no_nulls.to_numpy()
assert np.array_equal(arr_to_np, arr_np_bool)
assert arr_to_np.ctypes.data == arr_no_nulls.buffers()[1].address # zero-copy
3 changes: 3 additions & 0 deletions python/pyarrow/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def test_set_timezone_db_path_non_windows():
pa.ProxyMemoryPool,
pa.Device,
pa.MemoryManager,
pa.Bool8Array,
pa.Bool8Scalar,
pa.Bool8Type,
])
def test_extension_type_constructor_errors(klass):
# ARROW-2638: prevent calling extension class constructors directly
Expand Down
61 changes: 61 additions & 0 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,32 @@ cdef class FixedShapeTensorType(BaseExtensionType):
def __arrow_ext_scalar_class__(self):
return FixedShapeTensorScalar

cdef class Bool8Type(BaseExtensionType):
"""
Concrete class for bool8 extension type.
Bool8 is an alternate representation for boolean
arrays using 8 bits instead of 1 bit per value. The underlying
storage type is int8.
Examples
--------
Create an instance of bool8 extension type:
>>> import pyarrow as pa
>>> pa.bool8()
Bool8Type(extension<arrow.bool8>)
"""

cdef void init(self, const shared_ptr[CDataType]& type) except *:
BaseExtensionType.init(self, type)
self.bool8_ext_type = <const CBool8Type*> type.get()

def __arrow_ext_class__(self):
return Bool8Array

def __reduce__(self):
return bool8, ()

def __arrow_ext_scalar_class__(self):
return Bool8Scalar

_py_extension_type_auto_load = False

Expand Down Expand Up @@ -5206,6 +5232,41 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N

return out

def bool8():
"""
Create instance of bool8 extension type.
Examples
--------
Create an instance of bool8 extension type:
>>> import pyarrow as pa
>>> type = pa.bool8()
>>> type
Bool8Type(extension<arrow.bool8>)
Inspect the data type:
>>> type.storage_type
DataType(int8)
Create a table with a bool8 array:
>>> arr = [-1, 0, 1, 2, None]
>>> storage = pa.array(arr, pa.int8())
>>> other = pa.ExtensionArray.from_storage(type, storage)
>>> pa.table([other], names=["unknown_col"])
pyarrow.Table
unknown_col: extension<arrow.bool8>
----
unknown_col: [[True, False, True, True, null]]
Returns
-------
type : Bool8Type
"""

cdef Bool8Type out = Bool8Type.__new__(Bool8Type)

with nogil:
c_type = GetResultValue(CBool8Type.Make())

out.init(c_type)

return out

cdef dict _type_aliases = {
'null': null,
Expand Down

0 comments on commit 86210e8

Please sign in to comment.