Skip to content

Commit

Permalink
[Docs]: Add PyArrow to JAX example to the docs
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecrv committed Sep 25, 2024
1 parent c557fe5 commit 47dcebb
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions docs/source/python/dlpack.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ PyArrow implements the second part of the protocol
(``__dlpack__(self, stream=None)`` and ``__dlpack_device__``) and can
thus be consumed by libraries implementing ``from_dlpack``.

Example
-------
Examples
--------

Convert a PyArrow CPU array to NumPy array:
Convert a PyArrow CPU array into a NumPy array:

.. code-block::
Expand All @@ -84,10 +84,20 @@ Convert a PyArrow CPU array to NumPy array:
>>> np.from_dlpack(array)
array([2, 0, 2, 4])
Convert a PyArrow CPU array to PyTorch tensor:
Convert a PyArrow CPU array into a PyTorch tensor:

.. code-block::
>>> import torch
>>> torch.from_dlpack(array)
tensor([2, 0, 2, 4])
Convert a PyArrow CPU array into a JAX array:

.. code-block::
>>> import jax
>>> jax.numpy.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)
>>> jax.dlpack.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)

0 comments on commit 47dcebb

Please sign in to comment.