diff --git a/docs/source/python/dlpack.rst b/docs/source/python/dlpack.rst index 024c2800e1107..9f0d3b58aa6e5 100644 --- a/docs/source/python/dlpack.rst +++ b/docs/source/python/dlpack.rst @@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol (``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can thus be consumed by libraries implementing ``from_dlpack``. -Example -------- +Examples +-------- -Convert a PyArrow CPU array to NumPy array: +Convert a PyArrow CPU array into a NumPy array: .. code-block:: @@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array: >>> np.from_dlpack(array) array([2, 0, 2, 4]) -Convert a PyArrow CPU array to PyTorch tensor: +Convert a PyArrow CPU array into a PyTorch tensor: .. code-block:: >>> import torch >>> torch.from_dlpack(array) tensor([2, 0, 2, 4]) + +Convert a PyArrow CPU array into a JAX array: + +.. code-block:: + + >>> import jax + >>> jax.numpy.from_dlpack(array) + Array([2, 0, 2, 4], dtype=int32) + >>> jax.dlpack.from_dlpack(array) + Array([2, 0, 2, 4], dtype=int32)