diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index e52e0d242bee5..25b6adf606989 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -73,6 +73,10 @@ def parse_git(root, **kwargs): io_thread_count, set_io_thread_count) +# Expose whether NUMPY is available or not +from pyarrow.lib import HAS_NUMPY + + def show_versions(): """ Print various version information, to help with error reporting. diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0e052a91add75..b17b77901425d 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -24,7 +24,7 @@ from cython.operator cimport dereference as deref from collections import namedtuple -from pyarrow.lib import frombytes, tobytes, ArrowInvalid +from pyarrow.lib import frombytes, tobytes, ArrowInvalid, HAS_NUMPY from pyarrow.lib cimport * from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * @@ -476,7 +476,7 @@ cdef class MetaFunction(Function): cdef _pack_compute_args(object values, vector[CDatum]* out): for val in values: - if "numpy" in sys.modules and isinstance(val, (list, np.ndarray)): + if HAS_NUMPY and isinstance(val, (list, np.ndarray)): val = lib.asarray(val) if isinstance(val, Array): diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index 5c25631cb9bb9..4f2d39bbdda94 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -21,8 +21,10 @@ import datetime import decimal as _pydecimal +HAS_NUMPY=False try: import numpy as np + HAS_NUMPY=True except ImportError: np = None import os diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index eb5a488ade660..49a8402f88690 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -486,7 +486,7 @@ def test_array_slice(): res.validate() expected = arr.to_pylist()[start:stop] assert res.to_pylist() == expected - if "numpy" in sys.modules: + if pa.HAS_NUMPY: assert res.to_numpy().tolist() == expected