From dad527bffe7e298b3f19dfef7d62ac264def4ef3 Mon Sep 17 00:00:00 2001 From: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:47:33 +0530 Subject: [PATCH] fix(jax): include in_out_argnames and stage argnames in FFI registry cache keys Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com> --- warp/_src/jax_experimental/ffi.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index e9fe408f47..ad2214627d 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -1203,6 +1203,7 @@ def jax_kernel( hashable_launch_dims = launch_dims if not enable_backward: + hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None key = ( kernel.func, kernel.sig, @@ -1210,6 +1211,7 @@ def jax_kernel( vmap_method, hashable_launch_dims, hashable_output_dims, + hashable_in_out, module_preload_mode, has_side_effect, ) @@ -1548,12 +1550,18 @@ def jax_callable( hashable_output_dims = output_dims # Note: we don't include graph_cache_max in the key, it is applied below. + hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None + hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None + hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None key = ( func, num_outputs, graph_mode, vmap_method, hashable_output_dims, + hashable_in_out, + hashable_stage_in, + hashable_stage_out, module_preload_mode, has_side_effect, )