Incompatibilities with bfloat16 after update to numpy 2 #9568

Open
5 tasks
alvarosg opened this issue Oct 2, 2024 · 1 comment
Labels
array API standard Support for the Python array API standard upstream issue

alvarosg commented Oct 2, 2024

What happened?

Computing max or isnull on a DataArray with bfloat16 values raises:
TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

This worked fine before updating numpy to version 2. The main difference in the code seems to be that with numpy < 2, xarray uses its own implementation of isdtype, while with numpy >= 2 it relies on np.isdtype. Deleting np.isdtype before importing xarray (import numpy as np; del np.isdtype) makes the problem go away, which supports this.
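
As a rough illustration of the version-dependent selection described here, the sketch below picks np.isdtype when the installed numpy provides it and otherwise uses a subclass-based fallback. The names _KIND_TO_ABSTRACT and fallback_isdtype are hypothetical, and this is not xarray's actual code; it only assumes numpy is installed.

```python
import numpy as np

# Hypothetical mapping from array API kind strings to numpy abstract scalar types.
_KIND_TO_ABSTRACT = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}


def fallback_isdtype(dtype, kind):
    """Subclass-based check, roughly what a pre-numpy-2 fallback might do."""
    kinds = kind if isinstance(kind, tuple) else (kind,)
    return any(np.issubdtype(dtype, _KIND_TO_ABSTRACT[k]) for k in kinds)


# Prefer np.isdtype when it exists (numpy >= 2), else use the fallback.
isdtype = getattr(np, "isdtype", fallback_isdtype)

print(isdtype(np.dtype("float64"), ("real floating", "complex floating")))
```

Because the choice is made once at import time in this sketch, removing np.isdtype before the selection runs would switch it to the fallback, which is consistent with the del np.isdtype observation.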

What did you expect to happen?

I expected the computation to be successful, just as prior to numpy 2.

Minimal Complete Verifiable Example

import numpy as np
# del np.isdtype  # Uncommenting this line fixes it.

import xarray
import ml_dtypes

da = xarray.DataArray(np.zeros([2], dtype=ml_dtypes.bfloat16), dims=("dim",))
da.isnull() # Or da.max("dim")

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

TypeError                                 Traceback (most recent call last)
Cell In[1], line 5
      3 import numpy as np
      4 da = xarray.DataArray(np.zeros([2], dtype=jnp.bfloat16), dims=("dim",))
----> 5 da.isnull()

File ~/dev/xarray/xarray/core/common.py:1293, in DataWithCoords.isnull(self, keep_attrs)
   1290 if keep_attrs is None:
   1291     keep_attrs = _get_keep_attrs(default=False)
-> 1293 return apply_ufunc(
   1294     duck_array_ops.isnull,
   1295     self,
   1296     dask="allowed",
   1297     keep_attrs=keep_attrs,
   1298 )

File ~/dev/xarray/xarray/core/computation.py:1278, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
   1276 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1277 elif any(isinstance(a, DataArray) for a in args):
-> 1278     return apply_dataarray_vfunc(
   1279         variables_vfunc,
   1280         *args,
   1281         signature=signature,
   1282         join=join,
   1283         exclude_dims=exclude_dims,
   1284         keep_attrs=keep_attrs,
   1285     )
   1286 # feed Variables directly through apply_variable_ufunc
   1287 elif any(isinstance(a, Variable) for a in args):

File ~/dev/xarray/xarray/core/computation.py:320, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    315 result_coords, result_indexes = build_output_coords_and_indexes(
    316     args, signature, exclude_dims, combine_attrs=keep_attrs
    317 )
    319 data_vars = [getattr(a, "variable", a) for a in args]
--> 320 result_var = func(*data_vars)
    322 out: tuple[DataArray, ...] | DataArray
    323 if signature.num_outputs > 1:

File ~/dev/xarray/xarray/core/computation.py:831, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    826     if vectorize:
    827         func = _vectorize(
    828             func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
    829         )
--> 831 result_data = func(*input_data)
    833 if signature.num_outputs == 1:
    834     result_data = (result_data,)

File ~/dev/xarray/xarray/core/duck_array_ops.py:144, in isnull(data)
    139 if dtypes.is_datetime_like(scalar_type):
    140     # datetime types use NaT for null
    141     # note: must check timedelta64 before integers, because currently
    142     # timedelta64 inherits from np.integer
    143     return isnat(data)
--> 144 elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
    145     # float types use NaN for null
    146     xp = get_array_namespace(data)
    147     return xp.isnan(data)

File ~/dev/xarray/xarray/core/dtypes.py:208, in isdtype(dtype, kind, xp)
    205     raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
    207 if isinstance(dtype, np.dtype):
--> 208     return npcompat.isdtype(dtype, kind)
    209 elif is_extension_array_dtype(dtype):
    210     # we never want to match pandas extension array dtypes
    211     return False

File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/numpy/_core/numerictypes.py:425, in isdtype(dtype, kind)
    423     dtype = _preprocess_dtype(dtype)
    424 except _PreprocessDTypeError:
--> 425     raise TypeError(
    426         "dtype argument must be a NumPy dtype, "
    427         f"but it is a {type(dtype)}."
    428     ) from None
    430 input_kinds = kind if isinstance(kind, tuple) else (kind,)
    432 processed_kinds = set()

TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

Anything else we need to know?

Here's a different reproducer showing the inconsistency between np.isdtype and npcompat.isdtype:

import importlib

import ml_dtypes
import numpy as np
from xarray.core import npcompat

# With numpy >= 2, npcompat.isdtype is np.isdtype, which rejects the bfloat16 dtype.
try:
    npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')
except Exception as e:
    print(e)  # Same TypeError message as in the log output above.

# Temporarily hide np.isdtype so that reloading npcompat selects xarray's own implementation.
numpy_isdtype = np.isdtype
del np.isdtype
importlib.reload(npcompat)
np.isdtype = numpy_isdtype

npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # No error, but returns False.

Environment

In [5]: xarray.show_versions()

INSTALLED VERSIONS

commit: 03d3e0b
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2

xarray: 2024.7.1.dev73+g781877cb
pandas: 2.2.2
numpy: 2.1.1
scipy: 1.13.1
netCDF4: 1.6.5
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.8.2
distributed: 2024.5.2
matplotlib: 3.9.0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: 0.11.1
setuptools: 70.0.0
pip: 24.0
conda: 24.7.1
pytest: 8.2.2
mypy: 1.10.0
IPython: 8.25.0

alvarosg added the "bug" and "needs triage" labels on Oct 2, 2024
keewis added the "upstream issue" and "array API standard" labels and removed the "bug" and "needs triage" labels on Oct 5, 2024
keewis (Collaborator) commented Oct 5, 2024

The difference here is that npcompat.isdtype translates the kind string to a numpy.dtype superclass and then uses isinstance to perform the check, while np.isdtype explicitly raises if it receives anything other than np.dtype subclasses or the string categories.
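
A user-side workaround along these lines might try np.isdtype first and fall back to a subclass check when it raises. This is a minimal sketch under assumptions: tolerant_isdtype and _FLOAT_KINDS are hypothetical names, only the two floating kinds are handled, and it is not a fix proposed by anyone in this thread. For a dtype such as bfloat16 the fallback would be expected to return False, matching what the second reproducer above reports.

```python
import numpy as np

# Hypothetical, deliberately partial mapping: only the kinds used by isnull.
_FLOAT_KINDS = {
    "real floating": np.floating,
    "complex floating": np.complexfloating,
}


def tolerant_isdtype(dtype, kind):
    """Try np.isdtype; on TypeError fall back to a subclass-based check."""
    np_isdtype = getattr(np, "isdtype", None)  # absent on numpy < 2
    if np_isdtype is not None:
        try:
            return bool(np_isdtype(dtype, kind))
        except TypeError:
            # np.isdtype refused the dtype (for example a third-party dtype).
            # Note: this also swallows TypeErrors caused by a malformed kind.
            pass
    kinds = kind if isinstance(kind, tuple) else (kind,)
    return any(np.issubdtype(dtype, _FLOAT_KINDS[k]) for k in kinds)


print(tolerant_isdtype(np.dtype("float32"), ("real floating", "complex floating")))
```

The trade-off in this sketch is that an unrecognised dtype is silently treated as "not floating" instead of raising, which restores the pre-numpy-2 behaviour but may hide NaN values in such arrays.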

I don't think we can do a lot here (correct me if I'm wrong, @shoyer), so it might make more sense to take this up with the numpy devs.

cc @rgommers, @seberg for awareness
