Skip to content

Commit

Permalink
Add HAS_NUMPY in order to avoid 'numpy' in sys.modules every time
Browse files Browse the repository at this point in the history
  • Loading branch information
raulcd committed Jun 12, 2024
1 parent 32fb1e7 commit 3ac349e
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
4 changes: 4 additions & 0 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def parse_git(root, **kwargs):
io_thread_count, set_io_thread_count)


# Expose whether NUMPY is available or not
from pyarrow.lib import HAS_NUMPY


def show_versions():
"""
Print various version information, to help with error reporting.
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from cython.operator cimport dereference as deref

from collections import namedtuple

from pyarrow.lib import frombytes, tobytes, ArrowInvalid
from pyarrow.lib import frombytes, tobytes, ArrowInvalid, HAS_NUMPY
from pyarrow.lib cimport *
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
Expand Down Expand Up @@ -476,7 +476,7 @@ cdef class MetaFunction(Function):

cdef _pack_compute_args(object values, vector[CDatum]* out):
for val in values:
if "numpy" in sys.modules and isinstance(val, (list, np.ndarray)):
if HAS_NUMPY and isinstance(val, (list, np.ndarray)):
val = lib.asarray(val)

if isinstance(val, Array):
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

import datetime
import decimal as _pydecimal
HAS_NUMPY=False
try:
import numpy as np
HAS_NUMPY=True
except ImportError:
np = None
import os
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def test_array_slice():
res.validate()
expected = arr.to_pylist()[start:stop]
assert res.to_pylist() == expected
if "numpy" in sys.modules:
if pa.HAS_NUMPY:
assert res.to_numpy().tolist() == expected


Expand Down

0 comments on commit 3ac349e

Please sign in to comment.