Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mujoco_warp/_src/jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def test_jax(self, xml):
self.skipTest("JAX not installed")

from jax import numpy as jp
from jax.sharding import Mesh, PartitionSpec as PS
from warp.jax_experimental.ffi import jax_callable

if jax.default_backend() != "gpu":
self.skipTest("JAX default backend is not GPU")

NWORLDS = 2
NDEVICES = len(jax.devices())
NCONTACTS = 16
UNROLL_LENGTH = 1

Expand Down Expand Up @@ -93,7 +95,24 @@ def step(carry, _):
jax_qvel = jp.zeros((NWORLDS, m.nv))

jax_unroll_fn = jax.jit(unroll).lower(jax_qpos, jax_qvel).compile()
jax_unroll_fn(jax_qpos, jax_qvel)
res = jax_unroll_fn(jax_qpos, jax_qvel)
print(res)

# Test shard_map support
mesh = Mesh(jax.devices(), ["devices"])
sharded_unroll_fn = jax.jit(
jax.shard_map(
unroll,
mesh=mesh,
in_specs=(PS("devices"), PS("devices")),
out_specs=(PS("devices"), PS("devices"))))
sharded_qpos = jp.tile(jp.array(m.qpos0.numpy()), (NDEVICES * NWORLDS, 1))
sharded_qvel = jp.zeros((NDEVICES * NWORLDS, m.nv))

compiled_sharded_unroll_fn = sharded_unroll_fn.lower(sharded_qpos, sharded_qvel).compile()
res = compiled_sharded_unroll_fn(sharded_qpos, sharded_qvel)
print(res)



if __name__ == "__main__":
Expand Down