GenCast autoregressive rollout fails with JaxArrayWrapper / DynamicJaxprTracer __array_ufunc__ TypeError on both GPU and Colab TPU #203

@DarinaAndr

Hi, I would really appreciate some help with a TypeError I get while running the GenCast autoregressive rollout. The error involves xarray_jax.JaxArrayWrapper interacting with a DynamicJaxprTracer.

The failure occurs during the autoregressive rollout step (chunked_prediction_generator_multiple_runs) in the GenCast demo.

I can reproduce the same error on a local multi-GPU machine and on Google Colab TPU (v5e-1 TPU, 2025.07 runtime version), so I think it is related to the GenCast rollout path interacting with xarray_jax rather than to a specific backend.

The simpler GraphCast forward pass works correctly.

Execution fails with:

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(
<ufunc 'multiply'>, '__call__',
JitTracer(float32[]),
xarray_jax.JaxArrayWrapper(JitTracer(float32[1,1,181,360]))
): 'DynamicJaxprTracer', 'JaxArrayWrapper'
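
For readers unfamiliar with this message: NumPy raises it when every operand that overrides `__array_ufunc__` declines the ufunc call by returning `NotImplemented`. The sketch below only illustrates that general dispatch behaviour. `ScalarLike` and `WrapperLike` are hypothetical stand-in classes, not the real JAX tracer or `xarray_jax.JaxArrayWrapper`, and the snippet does not reproduce or diagnose the GenCast failure itself.

```python
import numpy as np


class ScalarLike:
    """Hypothetical stand-in for the left operand (NOT a real JAX tracer)."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Decline to handle the ufunc call.
        return NotImplemented


class WrapperLike:
    """Hypothetical stand-in for the right operand (NOT xarray_jax.JaxArrayWrapper)."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Also decline, so no operand is left to handle the call.
        return NotImplemented


def try_multiply(left, right):
    """Return the TypeError message from np.multiply, or None if it succeeds."""
    try:
        np.multiply(left, right)
    except TypeError as exc:
        return str(exc)
    return None


message = try_multiply(ScalarLike(), WrapperLike())
print(message)
```

If the same pattern applies here, the question becomes which of the two operands is expected to handle `multiply` and why it declines in the rollout path; that part is an assumption the traceback alone does not settle.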

To reproduce: in the GenCast demo notebook, the failure occurs in the Autoregressive rollout cell:

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices(),
):
    chunks.append(chunk)

predictions = xarray.combine_by_coords(chunks)

Setup:

Machine:

Linux cluster
2× NVIDIA GPUs

Python environment:

Python 3.11
jax 0.4.38
jaxlib 0.4.38
numpy 2.4.2
xarray 2024.11.0
pandas 2.2.3

Devices detected:

devices: [CudaDevice(id=0), CudaDevice(id=1)]
backend: gpu
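
In case it helps others compare environments, here is a minimal sketch of how a report like the one above could be collected. The helper names `collect_versions` and `describe_devices` are made up for this example; the package list mirrors the one in this issue, and `describe_devices` assumes JAX is importable and returns `None` otherwise.

```python
from importlib import metadata

# Packages listed in this issue's environment section.
PACKAGES = ("jax", "jaxlib", "numpy", "xarray", "pandas")


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def describe_devices():
    """Return the JAX backend and device list, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "backend": jax.default_backend(),
        "devices": [str(device) for device in jax.devices()],
    }


for package, version in collect_versions().items():
    print(f"{package} {version}")
```

Calling `describe_devices()` on the machine that runs the rollout gives the backend and device list to paste alongside the versions.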

Error traceback ends with:

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(
<ufunc 'multiply'>, '__call__',
JitTracer(float32[]),
xarray_jax.JaxArrayWrapper(JitTracer(float32[1,1,181,360]))
)
