Changes from all commits (73 commits)
317501e
Replace cf.rank==0 with utils.distributed.is_root
Jul 16, 2025
77de417
replace cf.rank==0 with weathergen.utils.distributed.is_root
Jul 16, 2025
6439618
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 22, 2025
8993875
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 25, 2025
f4a9d85
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 28, 2025
f8fdef4
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 29, 2025
ca89e7b
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 30, 2025
49d7a4d
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 31, 2025
f39f094
Merge branch 'ecmwf:develop' into develop
csjfwang Jul 31, 2025
ebb03ea
Merge branch 'ecmwf:develop' into develop
csjfwang Aug 25, 2025
f40737d
Merge branch 'ecmwf:develop' into develop
csjfwang Aug 28, 2025
87fa078
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 10, 2025
5dfe275
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 19, 2025
b7244d9
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 22, 2025
5be41f5
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 22, 2025
39d3965
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 23, 2025
015ec88
Merge branch 'ecmwf:develop' into develop
csjfwang Sep 24, 2025
cb1b7cc
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 1, 2025
90da4cf
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 20, 2025
f04891b
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 21, 2025
105d992
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 24, 2025
5f56073
Merge branch 'ecmwf:develop' into develop
csjfwang Oct 26, 2025
95ee18a
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 3, 2025
3c702d3
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 10, 2025
6f14a30
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 13, 2025
5e87881
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 14, 2025
0c7d305
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 24, 2025
e43ac94
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 25, 2025
5f63bcc
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
c51eb94
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
dd5acc2
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 26, 2025
f03672d
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 27, 2025
49c52e1
Merge branch 'ecmwf:develop' into develop
csjfwang Nov 28, 2025
c6356a2
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 1, 2025
36c709a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 1, 2025
765276a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 9, 2025
f3eb78a
Merge branch 'ecmwf:develop' into develop
csjfwang Dec 10, 2025
542f23e
Merge branch 'ecmwf:develop' into develop
csjfwang Jan 23, 2026
d14f61f
Merge branch 'ecmwf:develop' into develop
csjfwang Jan 30, 2026
692703b
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 10, 2026
165f498
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 15, 2026
a5adf2a
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 18, 2026
0442d5d
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 21, 2026
eb8480d
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 24, 2026
fff2626
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 26, 2026
c612dff
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 26, 2026
a2f56d3
Merge branch 'ecmwf:develop' into develop
csjfwang Feb 27, 2026
eca4792
Merge branch 'ecmwf:develop' into develop
csjfwang Mar 6, 2026
e0562f7
Merge branch 'ecmwf:develop' into develop
csjfwang Mar 18, 2026
bbe7916
spherical rope, 1st version
Mar 19, 2026
a9929c6
fix lint; change name of rope_spherical_l to rope_spherical_band
Mar 19, 2026
fb474e3
Spherical RoPE with post qk lnorm
Mar 24, 2026
295527f
Refactor and harden RoPE config flow by unifying rope_mode resolution…
Mar 24, 2026
e3e7108
Merge branch 'ecmwf:develop' into develop
csjfwang Mar 26, 2026
ef91a58
Merge branch 'develop' into spherical-rope
Mar 26, 2026
bb78f7d
Merge branch 'ecmwf:develop' into develop
csjfwang Mar 28, 2026
b4005f1
Merge branch 'ecmwf:develop' into develop
csjfwang Mar 31, 2026
401c41d
Merge branch 'ecmwf:develop' into develop
csjfwang Apr 3, 2026
dd2fabc
Merge branch 'ecmwf:develop' into develop
csjfwang Apr 24, 2026
4796c71
Merge branch 'develop' into spherical-rope
csjfwang Apr 24, 2026
9b8f8ad
remove rope_post_mod_qk_lnorm arg in the config, to make it
csjfwang Apr 24, 2026
200ceb7
clean resolve_rope_mode related code, no support rope_2D arg
csjfwang Apr 24, 2026
7eea678
clean spherical_harmonics_band_all_pixels, etc.
csjfwang Apr 24, 2026
f6e3800
clean conjugate arg and remove redundant num_complex slicing
csjfwang Apr 24, 2026
0123d7c
add comments to spherical rope core function - healpy_band_maps
csjfwang Apr 25, 2026
f4ee3e6
Merge branch 'ecmwf:develop' into develop
csjfwang Apr 25, 2026
cef6341
Merge branch 'develop' into spherical-rope
csjfwang Apr 25, 2026
f72a172
Merge branch 'ecmwf:develop' into develop
csjfwang Apr 28, 2026
d0a05dc
allow for backwards compatability the rope_2D option and add a
csjfwang May 11, 2026
ba5fe4d
Merge branch 'ecmwf:develop' into develop
csjfwang May 11, 2026
6b7766d
Merge branch 'develop' into spherical-rope
csjfwang May 11, 2026
6797363
Suppress healpy logs during spherical RoPE setup
csjfwang May 11, 2026
04f6224
remove rope option to unrelated config, remove non-standard comments
csjfwang May 12, 2026
6 changes: 5 additions & 1 deletion config/config_forecasting.yml
@@ -66,7 +66,11 @@ forecast_att_dense_rate: 1.0

healpix_level: 5

rope_2D: False
# Generalized RoPE selector.
Collaborator: It would make more sense to have a section positional_encoding in the future

rope_mode: none # one of: none, 2d, spherical
# Optional spherical harmonic band for spherical RoPE. If null, the model picks one
# conservative shared band that fits all spherical-RoPE attention modules.
rope_spherical_band: null

with_mixed_precision: True
with_flash_attention: True
2 changes: 0 additions & 2 deletions config/config_forecasting_eerie.yml
@@ -66,8 +66,6 @@ forecast_att_dense_rate: 1.0

healpix_level: 5

rope_2D: False

with_mixed_precision: True
with_flash_attention: True
compile_model: False
10 changes: 6 additions & 4 deletions config/config_jepa.yml
@@ -66,10 +66,12 @@ forecast_att_dense_rate: 1.0
with_step_conditioning: True # False

healpix_level: 5
# Use 2D RoPE instead of traditional global positional encoding
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
rope_2D: False

# Generalized RoPE selector.
rope_mode: none # one of: none, 2d, spherical
# Optional spherical harmonic band for spherical RoPE. If null, the model picks one
# conservative shared band that fits all spherical-RoPE attention modules.
rope_spherical_band: null

with_mixed_precision: True
with_flash_attention: True
9 changes: 5 additions & 4 deletions config/default_config.yml
@@ -67,10 +67,11 @@ num_register_tokens: 0

healpix_level: 5

# Use 2D RoPE instead of traditional global positional encoding
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
rope_2D: False
# Generalized RoPE selector.
rope_mode: none # one of: none, 2d, spherical
# Optional spherical harmonic band for spherical RoPE. If null, the model picks one
# conservative shared band that fits all spherical-RoPE attention modules.
rope_spherical_band: null

with_mixed_precision: True
with_flash_attention: True
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,6 +12,7 @@ requires-python = ">=3.12,<3.13"
dependencies = [
'numpy~=2.2',
'astropy_healpix~=1.1.2',
'healpy>=1.19,<2',
'zarr~=3.1.3',
'pandas~=2.2',
'tqdm',
@@ -273,4 +274,3 @@ members = [
# Explicitly not depending on 'packages/dashboard' : this causes issues when deploying
# the streamlit dashboard.
]
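The new healpy dependency backs the spherical RoPE setup; the commit history points to a core helper, healpy_band_maps, that evaluates spherical-harmonic band maps on the HEALPix grid. As an illustration only (a hypothetical sketch, not the repository's healpy_band_maps), one way to obtain per-pixel values of a single band l with healpy:

```python
# Hypothetical sketch (not the repository's healpy_band_maps): evaluate real-valued
# spherical-harmonic maps of a single band l at HEALPix pixel centres, e.g. to derive
# per-token rotation coefficients for spherical RoPE at healpix_level = 5 (nside = 32).
import numpy as np
import healpy as hp


def band_maps(nside: int, band: int) -> np.ndarray:
    """Real-valued maps spanning the band l = band, one per m = 0..band, shape (band + 1, npix)."""
    lmax = band
    maps = []
    for m in range(band + 1):
        alm = np.zeros(hp.Alm.getsize(lmax), dtype=np.complex128)
        alm[hp.Alm.getidx(lmax, band, m)] = 1.0  # unit coefficient for (l, m)
        maps.append(hp.alm2map(alm, nside))
    return np.stack(maps)


maps = band_maps(nside=2**5, band=8)  # healpix_level 5; the band value here is made up
print(maps.shape)  # (9, 12288)
```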

66 changes: 45 additions & 21 deletions src/weathergen/model/attention.py
@@ -14,13 +14,13 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from weathergen.model.norms import AdaLayerNorm, RMSNorm
from weathergen.model.positional_encoding import rotary_pos_emb_2d
from weathergen.model.positional_encoding import apply_rope

"""
Attention blocks used by WeatherGenerator.

Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D
coordinates aligned with the token order (lat, lon in radians).
Some blocks optionally apply RoPE-like positional modulation. When enabled, the caller must
provide per-token coordinates aligned with the token order (lat, lon in radians).
"""


@@ -40,7 +40,7 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHeadVarlen, self).__init__()

@@ -49,7 +49,10 @@ def __init__(
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -79,6 +82,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype

@@ -96,10 +102,12 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s)

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1)
qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 1
)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0
@@ -225,15 +233,18 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHeadLocal, self).__init__()

self.num_heads = num_heads
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -263,6 +274,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype
assert with_flash, "Only flash attention supported."
@@ -288,10 +302,12 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3])
vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3])

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1)
qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 1
)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2)

@@ -540,7 +556,7 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHead, self).__init__()

@@ -549,7 +565,10 @@ def __init__(
self.softcap = softcap
self.dropout_rate = dropout_rate
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -579,6 +598,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype
if with_flash:
@@ -599,10 +621,12 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s).to(self.dtype)

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2)
qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 2
)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0
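Across the three attention blocks, the diff replaces the with_2d_rope flag with rope_mode and routes queries and keys through apply_rope(qs, ks, coords, rope_mode, unsqueeze_dim); for rope_mode == "spherical" the queries and keys are additionally re-normalized after the modulation. A minimal standalone sketch of that pattern in plain PyTorch, with a simple 1D rotation standing in for the package's apply_rope (all names and shapes here are assumed):

```python
# Standalone sketch of the rope_mode / post-RoPE QK-norm pattern; the rotation is a
# plain 1D RoPE stand-in, not the package's apply_rope.
import torch


def toy_rope(q, k, pos, base=10000.0):
    """Rotate query/key pairs by per-token angles derived from a 1D position."""
    d = q.shape[-1] // 2
    inv_freq = base ** (-torch.arange(d, dtype=q.dtype) / d)
    ang = pos[..., None] * inv_freq                                # (seq, d/2)
    cos, sin = ang.cos()[..., None, :], ang.sin()[..., None, :]   # broadcast over heads

    def rot(x):
        x1, x2 = x[..., :d], x[..., d:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    return rot(q), rot(k)


seq, heads, dim_head = 16, 4, 32
q = torch.randn(seq, heads, dim_head)
k = torch.randn(seq, heads, dim_head)
pos = torch.linspace(0.0, 1.0, seq)

rope_mode = "spherical"                     # "none", "2d" or "spherical" in the config
post_norm_q = torch.nn.LayerNorm(dim_head)  # only instantiated for spherical mode
post_norm_k = torch.nn.LayerNorm(dim_head)

if rope_mode != "none":
    q, k = toy_rope(q, k, pos)
if rope_mode == "spherical":                # re-normalize q/k after the modulation
    q, k = post_norm_q(q), post_norm_k(k)
```

The extra post-modulation LayerNorm presumably keeps the q/k scales stable when the spherical modulation is not norm-preserving.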
25 changes: 22 additions & 3 deletions src/weathergen/model/encoder.py
@@ -133,7 +133,11 @@ def forward(self, model_params, batch):
tokens_global = checkpoint(
self.ae_global_engine,
tokens_global,
coords=model_params.rope_coords,
coords=(
model_params.rope_spherical_coeffs.unbind(dim=-1)
if model_params.rope_spherical_coeffs is not None
else model_params.rope_coords
),
use_reentrant=False,
)

@@ -221,6 +225,8 @@ def aggregation_engine_unmasked(
tokens_global_register_class,
tokens_lens,
rope_cell_coords=None,
rope_cell_coeffs=None,
rope_extra_coeffs=None,
):
"""
Aggregation engine on the global latents of unmasked cells
@@ -251,8 +257,19 @@
)

# Build packed coords matching the interleaved token order
if rope_cell_coords is not None:
num_extra = self.num_class_tokens + self.num_register_tokens
num_extra = self.num_class_tokens + self.num_register_tokens
if rope_cell_coeffs is not None:
extra_real, extra_imag = rope_extra_coeffs.unbind(dim=-1)
cell_real, cell_imag = rope_cell_coeffs.unbind(dim=-1)
packed_real = []
packed_imag = []
for mask_b in cell_mask.flatten(0, 1):
packed_real.append(extra_real)
packed_imag.append(extra_imag)
packed_real.append(cell_real[mask_b])
packed_imag.append(cell_imag[mask_b])
packed_coords = (torch.cat(packed_real, dim=0), torch.cat(packed_imag, dim=0))
elif rope_cell_coords is not None:
zero_coords = torch.zeros(
num_extra, 2, device=rope_cell_coords.device, dtype=rope_cell_coords.dtype
)
@@ -316,6 +333,8 @@ def assimilate_local(
tokens_global_register_class,
batch.tokens_lens,
rope_cell_coords=model_params.rope_cell_coords,
rope_cell_coeffs=model_params.rope_spherical_cell_coeffs,
rope_extra_coeffs=model_params.rope_spherical_extra_coeffs,
)

# final processing
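The new packing branch in aggregation_engine_unmasked interleaves the spherical RoPE coefficients in the same order as the tokens: for every batch element, the class/register coefficients come first, followed by the coefficients of that element's unmasked cells. A toy, self-contained illustration of that interleaving (all names and shapes are made up):

```python
# Toy illustration of packing (real, imag) RoPE coefficients to match the interleaved
# token order: per batch element, class/register tokens first, then unmasked cells.
import torch

num_cells, num_extra = 6, 2
cell_coeffs = torch.randn(num_cells, 4, 2)   # (cells, coeff_dim, real/imag)
extra_coeffs = torch.zeros(num_extra, 4, 2)  # class + register tokens
cell_mask = torch.tensor([[1, 0, 1, 1, 0, 1],
                          [0, 1, 1, 0, 1, 0]], dtype=torch.bool)  # (batch, cells)

extra_real, extra_imag = extra_coeffs.unbind(dim=-1)
cell_real, cell_imag = cell_coeffs.unbind(dim=-1)
packed_real, packed_imag = [], []
for mask_b in cell_mask:                     # one group per batch element
    packed_real += [extra_real, cell_real[mask_b]]
    packed_imag += [extra_imag, cell_imag[mask_b]]
packed_coords = (torch.cat(packed_real, dim=0), torch.cat(packed_imag, dim=0))
print(packed_coords[0].shape)  # torch.Size([11, 4]): 2x2 extra tokens + 4 + 3 cell tokens
```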
14 changes: 9 additions & 5 deletions src/weathergen/model/engines.py
@@ -29,6 +29,7 @@
StreamEmbedTransformer,
)
from weathergen.model.layers import MLP
from weathergen.model.positional_encoding import get_rope_mode
from weathergen.model.utils import ActivationFactory
from weathergen.utils.utils import get_dtype

@@ -389,6 +390,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
super(QueryAggregationEngine, self).__init__()
self.cf = cf
self.num_healpix_cells = num_healpix_cells
rope_mode = get_rope_mode(self.cf)

self.ae_aggregation_blocks = torch.nn.ModuleList()

@@ -409,7 +411,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type),
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.get("rope_2D", False),
rope_mode=rope_mode,
)
)
else:
@@ -465,6 +467,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
super(GlobalAssimilationEngine, self).__init__()
self.cf = cf
self.num_healpix_cells = num_healpix_cells
rope_mode = get_rope_mode(self.cf)

self.ae_global_blocks = torch.nn.ModuleList()

@@ -485,7 +488,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type),
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.get("rope_2D", False),
rope_mode=rope_mode,
)
)
else:
@@ -502,7 +505,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type),
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.get("rope_2D", False),
rope_mode=rope_mode,
)
)
# MLP block
@@ -553,6 +556,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int =
super(ForecastingEngine, self).__init__()
self.cf = cf
self.num_healpix_cells = num_healpix_cells
rope_mode = get_rope_mode(self.cf)
self.fe_blocks = torch.nn.ModuleList()

global_rate = int(1 / self.cf.forecast_att_dense_rate)
@@ -572,7 +576,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int =
dim_aux=dim_aux,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.get("rope_2D", False),
rope_mode=rope_mode,
)
)
else:
@@ -590,7 +594,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int =
dim_aux=dim_aux,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
with_2d_rope=self.cf.get("rope_2D", False),
rope_mode=rope_mode,
)
)
# Add MLP block
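All engines now obtain the positional-encoding mode from get_rope_mode(cf) rather than reading rope_2D directly; per the commit history, the legacy rope_2D flag remains supported for backwards compatibility. The resolver itself is not part of this diff; a plausible sketch of such a resolution, under the assumption that it maps the old boolean onto the new selector (hypothetical, the actual get_rope_mode may differ):

```python
# Hypothetical sketch of rope_mode resolution with legacy rope_2D support; the actual
# get_rope_mode in weathergen.model.positional_encoding may differ.
VALID_ROPE_MODES = ("none", "2d", "spherical")


def resolve_rope_mode(cf) -> str:
    mode = cf.get("rope_mode", None)
    if mode is None:
        # Legacy path: the old boolean flag maps onto the new selector.
        mode = "2d" if cf.get("rope_2D", False) else "none"
    mode = str(mode).lower()
    if mode not in VALID_ROPE_MODES:
        raise ValueError(f"rope_mode must be one of {VALID_ROPE_MODES}, got {mode!r}")
    return mode
```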