Commit
Merge pull request #93 from ami-iit/jax2torch
Allow pytorch batching using jax2torch
Giulero authored Jun 27, 2024
2 parents c7c5f1f + 9d37316 commit c6dc157
Showing 6 changed files with 702 additions and 18 deletions.
82 changes: 65 additions & 17 deletions README.md
@@ -17,18 +17,20 @@
**adam** is based on Roy Featherstone's Rigid Body Dynamics Algorithms.

### Table of contents
- [🐍 Dependencies](#-dependencies)
- [💾 Installation](#-installation)
  - [🐍 Installation with pip](#-installation-with-pip)
  - [📦 Installation with conda](#-installation-with-conda)
    - [Installation from conda-forge package](#installation-from-conda-forge-package)
  - [🔨 Installation from repo](#-installation-from-repo)
- [🚀 Usage](#-usage)
  - [Jax interface](#jax-interface)
  - [CasADi interface](#casadi-interface)
  - [PyTorch interface](#pytorch-interface)
  - [PyTorch Batched interface](#pytorch-batched-interface)
- [🦸‍♂️ Contributing](#️-contributing)
- [Todo](#todo)

## 🐍 Dependencies

@@ -41,6 +43,7 @@ Other requisites are:
- `casadi`
- `pytorch`
- `numpy`
- `jax2torch`

They will be installed in the installation step!

@@ -114,6 +117,9 @@ mamba create -n adamenv -c conda-forge adam-robotics

If you want to use `jax` or `pytorch`, just install the corresponding package as well.

> [!NOTE]
> Check also the conda JAX installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#conda-community-supported)

### 🔨 Installation from repo

Install in a conda environment the required dependencies:
@@ -133,13 +139,13 @@ Install in a conda environment the required dependencies:
- **PyTorch** interface dependencies:

```bash
mamba create -n adamenv -c conda-forge pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch
```

- **ALL** interfaces dependencies:

```bash
mamba create -n adamenv -c conda-forge jax casadi pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch
```

Activate the environment, clone the repo and install the library:
@@ -154,10 +160,13 @@ pip install --no-deps .
## 🚀 Usage

The following are small snippets of the use of **adam**. More examples are arriving!
Have a look also at the `tests` folder.

### Jax interface

> [!NOTE]
> Check also the Jax installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#)

```python
import adam
from adam.jax import KinDynComputations
@@ -205,11 +214,14 @@ jitted_vmapped_frame_fk = jit(vmapped_frame_fk)
# and called on a batch of data
joints_batch = jnp.tile(joints, (1024, 1))
w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1))
w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch)
```

> [!NOTE]
> The first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster!
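The same `jit` + `vmap` pattern can be sketched on a toy function that needs no robot model; the function `f` below is a hypothetical stand-in for the single-sample forward-kinematics function above:

```python
import jax
import jax.numpy as jnp

# toy stand-in for a single-sample function: maps one (4, 4)
# transform and one joint vector to a scalar
def f(w_H_b, joints):
    return jnp.trace(w_H_b) + jnp.sum(joints)

# vmap over the leading batch axis of both arguments, then jit
batched_f = jax.jit(jax.vmap(f, in_axes=(0, 0)))

# build a small batch by tiling a single sample, as in the snippet above
w_H_b_batch = jnp.tile(jnp.eye(4), (8, 1, 1))
joints_batch = jnp.ones((8, 3))

out = batched_f(w_H_b_batch, joints_batch)  # first call triggers compilation
print(out.shape)  # (8,) — one scalar per sample
```

Subsequent calls with the same input shapes reuse the compiled function, which is where the speed-up comes from.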

### CasADi interface

```python
@@ -251,7 +263,6 @@ joints = cs.MX.sym('joints', len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))
```

### PyTorch interface
@@ -284,6 +295,43 @@ M = kinDyn.mass_matrix(w_H_b, joints)
print(M)
```

### PyTorch Batched interface

> [!NOTE]
> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster!

```python
import adam
from adam.pytorch import KinDynComputationsBatch
import icub_models
import numpy as np
import torch
# if you want to use icub-models to retrieve the urdf
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]
kinDyn = KinDynComputationsBatch(model_path, joints_name_list)
# choose the representation you want; e.g., the body fixed representation
kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# or, if you want to use the mixed representation (that is the default)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)
w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))
num_samples = 1024
w_H_b_batch = torch.tensor(np.tile(w_H_b, (num_samples, 1, 1)), dtype=torch.float32)
joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float32)
M = kinDyn.mass_matrix(w_H_b_batch, joints_batch)
w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch)
```
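The batched quantities are plain torch tensors with a leading sample axis. A minimal sketch of the input construction and the resulting shapes, under the same assumptions as above (1024 samples, 23 joints):

```python
import numpy as np
import torch

num_samples = 1024
n_joints = 23  # length of joints_name_list above

# single-sample inputs: one base pose and one joint configuration
w_H_b = np.eye(4)
joints = np.ones(n_joints)

# tile along a new leading batch axis and convert to torch tensors
w_H_b_batch = torch.tensor(np.tile(w_H_b, (num_samples, 1, 1)), dtype=torch.float32)
joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float32)

print(w_H_b_batch.shape)   # torch.Size([1024, 4, 4])
print(joints_batch.shape)  # torch.Size([1024, 23])
```

Every method of `KinDynComputationsBatch` then returns one result per sample along that leading axis.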

## 🦸‍♂️ Contributing

**adam** is an open-source project. Contributions are very welcome!
3 changes: 2 additions & 1 deletion ci_env.yml
@@ -16,6 +16,7 @@ dependencies:
- pytest-repeat
- icub-models
- idyntree >=11.0.0
- gitpython
- jax
- pytorch
- jax2torch
6 changes: 6 additions & 0 deletions setup.cfg
@@ -44,6 +44,9 @@ casadi =
casadi
pytorch =
torch
jax
jaxlib
jax2torch
test =
jax
jaxlib
@@ -54,13 +57,16 @@ test =
icub-models
black
gitpython
jax2torch
conversions =
idyntree
all =
jax
jaxlib
casadi
torch
jax2torch
[tool:pytest]
addopts = --capture=no --verbose
1 change: 1 addition & 0 deletions src/adam/pytorch/__init__.py
@@ -3,4 +3,5 @@
# GNU Lesser General Public License v2.1 or any later version.

from .computations import KinDynComputations
from .computation_batch import KinDynComputationsBatch
from .torch_like import TorchLike