Skip to content

Commit

Permalink
Align in_axis in functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Giulero committed Jun 27, 2024
1 parent 18f07e6 commit 9d37316
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/adam/pytorch/computation_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class KinDynComputationsBatch:
"""This is a small class that retrieves robot quantities using Jax for Floating Base systems.
These functions are vmapped and jit compiled and passed to jax2torch to convert them to torch functions.
These functions are vmapped and jit compiled and passed to jax2torch to convert them to PyTorch functions.
"""

def __init__(
Expand Down Expand Up @@ -233,7 +233,7 @@ def fun(base_transform, joint_positions):
frame, base_transform, joint_positions
).array

vmapped_fun = jax.vmap(fun)
vmapped_fun = jax.vmap(fun, in_axes=(0, 0))
jit_vmapped_fun = jax.jit(vmapped_fun)
self.funcs[f"forward_kinematics_{frame}"] = jax2torch(jit_vmapped_fun)
return self.funcs[f"forward_kinematics_{frame}"]
Expand Down Expand Up @@ -269,7 +269,7 @@ def jacobian_fun(self, frame: str):
def fun(base_transform, joint_positions):
return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array

vmapped_fun = jax.vmap(fun)
vmapped_fun = jax.vmap(fun, in_axes=(0, 0))
jit_vmapped_fun = jax.jit(vmapped_fun)
self.funcs[f"jacobian_{frame}"] = jax2torch(jit_vmapped_fun)
return self.funcs[f"jacobian_{frame}"]
Expand Down Expand Up @@ -398,7 +398,7 @@ def fun(base_transform, joint_positions):
self.g,
).array.squeeze()

vmapped_fun = jax.vmap(fun)
vmapped_fun = jax.vmap(fun, in_axes=(0, 0))
jit_vmapped_fun = jax.jit(vmapped_fun)
self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun)
return self.funcs["gravity_term"]
Expand Down Expand Up @@ -430,7 +430,7 @@ def CoM_position_fun(self):
def fun(base_transform, joint_positions):
return self.rbdalgos.CoM_position(base_transform, joint_positions).array

vmapped_fun = jax.vmap(fun)
vmapped_fun = jax.vmap(fun, in_axes=(0, 0))
jit_vmapped_fun = jax.jit(vmapped_fun)
self.funcs["CoM_position"] = jax2torch(jit_vmapped_fun)
return self.funcs["CoM_position"]
Expand Down

0 comments on commit 9d37316

Please sign in to comment.