Changes from all commits

72 commits
644273a
feat: add timing metrics (startup, training, overall)
florianscheidl Apr 17, 2026
92120b8
docs: add agent structure with skills, tasks, and docs
florianscheidl Apr 17, 2026
df69dcb
docs: add skills review cycle for periodic compactification
florianscheidl Apr 17, 2026
7f0648f
configs
florianscheidl Apr 17, 2026
8fe45a0
Merge branch 'feature/timing-metrics' into ekfs/scaling-plots-20260417
florianscheidl Apr 17, 2026
dd55fb0
Remove hermes tool tracking for now
florianscheidl Apr 17, 2026
09b6e82
Try duration metrics
florianscheidl Apr 17, 2026
da3c29b
Update metrics, store after each mini-epoch
florianscheidl Apr 17, 2026
fc9a111
Refactor configs/streams
florianscheidl Apr 20, 2026
cfc4c62
Extract scaling data
florianscheidl Apr 20, 2026
82b503a
Script to generate scaling plots
florianscheidl Apr 20, 2026
70053b1
Script update
florianscheidl Apr 20, 2026
0c2df97
Repeat data in mini epoch
florianscheidl Apr 20, 2026
2c79d28
corrected time window length
florianscheidl Apr 20, 2026
6374986
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 20, 2026
b5d70f6
Lower to 512 samples per mini epoch
florianscheidl Apr 20, 2026
f46828c
Updated extraction script
florianscheidl Apr 20, 2026
89ac519
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 20, 2026
7cad6b5
Log time more often
florianscheidl Apr 21, 2026
30ac102
Fix training start scope
florianscheidl Apr 21, 2026
5e7f63e
Minimal validation
florianscheidl Apr 22, 2026
2be95c6
Increase samples_per_mini_epoch to 1024
florianscheidl Apr 22, 2026
93b203b
Final training duration and terminal/metric logging
florianscheidl Apr 23, 2026
2b708e3
log metrics after mini-epoch
florianscheidl Apr 23, 2026
0d8407d
Log metrics after mini-epoch, change schema
florianscheidl Apr 23, 2026
422fc60
MEtric typo
florianscheidl Apr 23, 2026
f63cba9
Logging refactor
florianscheidl Apr 23, 2026
b596c14
Update extraction script
florianscheidl Apr 23, 2026
42ba646
NNode extraction
florianscheidl Apr 23, 2026
c9fa64d
Logs path
florianscheidl Apr 23, 2026
ccfbc64
Wait until all training complete and wait with validation until logs …
florianscheidl Apr 23, 2026
c0f96b7
Log seconds rather than hours
florianscheidl Apr 23, 2026
701eb00
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 23, 2026
e6475e9
Measure dataset advancement time
florianscheidl Apr 23, 2026
6fd001f
LR scheduler lower bounds
florianscheidl Apr 23, 2026
313cec6
At least two warmup steps
florianscheidl Apr 23, 2026
aa4d399
Len per rank at least 1 to avoid zero division error
florianscheidl Apr 24, 2026
7956c52
Write csv for easier viewing
florianscheidl Apr 24, 2026
cf659e1
Extraction and plotting
florianscheidl Apr 24, 2026
177df79
Remove parent dir creation
florianscheidl Apr 24, 2026
8a4bc56
more detailed extraction script
florianscheidl Apr 24, 2026
bca6d3d
Remove overall time logging
florianscheidl Apr 24, 2026
9436811
Cleanup trainer
florianscheidl Apr 24, 2026
21c1575
Metrics extraction and plot generation scripts
florianscheidl Apr 24, 2026
e67616a
Add efficiency factor in plot
florianscheidl Apr 24, 2026
b1e4ea4
RM checkpoint and log metrics at last iteration
florianscheidl Apr 27, 2026
c89fe20
Detailed metrics
florianscheidl Apr 27, 2026
c14b749
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 27, 2026
b0bc6c2
Remove barrier and extra logging on last batch
florianscheidl Apr 27, 2026
133ee4c
trainer code cleanup
florianscheidl Apr 27, 2026
ec665da
Lower bound beta2 in adam
florianscheidl Apr 27, 2026
2088311
update script for scaling plots, loss as separate entry point
florianscheidl Apr 28, 2026
5bd88d9
specify nodes in scaling data script
florianscheidl Apr 29, 2026
7e1ae1c
Update extract scaling data
florianscheidl Apr 29, 2026
b42432c
Add pyarrow
florianscheidl Apr 30, 2026
6d5683b
Update script for scaling plots
florianscheidl Apr 30, 2026
0396290
Update to generating scaling plots
florianscheidl May 4, 2026
86093d7
Merge branch 'develop' into ekfs/scaling-plots-20260417
florianscheidl May 4, 2026
6800262
Move scaling scripts to package
florianscheidl May 4, 2026
c5af276
init refactor
florianscheidl May 4, 2026
e760c13
Setup and linting
florianscheidl May 4, 2026
556106e
Updated plot generation script
florianscheidl May 4, 2026
b93813d
Update readme
florianscheidl May 4, 2026
b2fe866
Fewer diffs
florianscheidl May 4, 2026
c574514
no gitignore changes
florianscheidl May 4, 2026
dad5462
Refactor logging and move time for mini epoch logging outside loop
florianscheidl May 4, 2026
4f11519
Formatting and style fixes
florianscheidl May 4, 2026
b02b38f
Update config
florianscheidl May 4, 2026
55d8219
Avoid duplicate metrics
florianscheidl May 4, 2026
904713d
Fix lint issues
florianscheidl May 4, 2026
9ecd544
t_training in __init__
florianscheidl May 4, 2026
9f02dc1
Renamed metric
florianscheidl May 8, 2026
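Most of the commits above add wall-clock instrumentation: startup, training, and overall durations, logged in seconds after each mini-epoch. A minimal sketch of the idea — class and method names here are hypothetical, not the code these commits add:

```python
import time


class TimingMetrics:
    """Track startup, training, and overall wall-clock time in seconds."""

    def __init__(self) -> None:
        # both reference points live in __init__ (cf. commit 9ecd544,
        # "t_training in __init__"); t_training is set once training starts
        self.t_start = time.perf_counter()
        self.t_training: float | None = None

    def mark_training_start(self) -> None:
        # called when the training loop proper begins (cf. commit 30ac102,
        # "Fix training start scope")
        self.t_training = time.perf_counter()

    def snapshot(self) -> dict[str, float]:
        # durations are reported in seconds rather than hours (cf. c0f96b7)
        now = time.perf_counter()
        t_train = self.t_training if self.t_training is not None else now
        return {
            "time_startup_s": t_train - self.t_start,
            "time_training_s": now - t_train,
            "time_overall_s": now - self.t_start,
        }
```

A trainer would call mark_training_start() once setup completes and merge snapshot() into the metrics it writes after each mini-epoch (cf. da3c29b and 2b708e3).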
289 changes: 289 additions & 0 deletions config/config_era5_georing.yml
@@ -0,0 +1,289 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 8
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 4
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 0
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 0
num_register_tokens: 0

fe_num_blocks: 6
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

healpix_level: 5

# Use 2D RoPE instead of traditional global positional encoding
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
rope_2D: False

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

freeze_modules: ""
load_chkpt: {}

norm_type: "LayerNorm"
qk_norm_type: null # if null, defaults to norm_type

#####################################

streams_directory: "./config/streams/era5_georing/"
streams: ???

# type of zarr_store
zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore

general:

  # mutable parameters
  istep: 0
  rank: ???
  world_size: ???

  # local_rank,
  # with_ddp,
  # data_path_*,
  # model_path,
  # run_path,
  # path_shared_

  multiprocessing_method: "fork"

  desc: ""
  run_id: ???
  run_history: []

  # logging frequency in the training loop (in number of batches)
  train_logging:
    terminal: 16
    metrics: 16
    checkpoint: 256
    log_grad_norms: False

# parameters for data loading
data_loading :

  num_workers: 12
  rng_seed: ???
  repeat_data_in_mini_epoch : True

  # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with
  # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error.
  # If this happens, you can disable the flag, but performance will drop on GH200.
  memory_pinning: True


# config for training
training_config:

  # training_mode: "masking", "student_teacher", "latent_loss"
  training_mode: ["masking"]

  num_mini_epochs: 64
  samples_per_mini_epoch: 4096
  shuffle: True

  start_date: 2014-01-01T00:00
  end_date: 2022-12-31T00:00

  time_window_step: 01:00:00
  time_window_len: 06:00:00

  learning_rate_scheduling :
    lr_start: 1e-6
    lr_max: 5e-5
    lr_final_decay: 1e-6
    lr_final: 0.0
    num_steps_warmup: 1024
    num_steps_cooldown: 512
    policy_warmup: "cosine"
    policy_decay: "constant"
    policy_cooldown: "linear"
    parallel_scaling_policy: "sqrt"

  optimizer:
    grad_clip: 1.0
    weight_decay: 0.1
    log_grad_norms: False
    adamw :
      # parameters are scaled by the number of DDP workers
      beta1 : 0.975
      beta2 : 0.9875
      eps : 2e-08

  losses : {
    "physical": {
      type: LossPhysical,
      loss_fcts: { "mse": { }, },
    },
  }

  model_input: {
    "forecasting" : {
      # masking strategy: "random", "healpix", "forecast"
      masking_strategy: "forecast",
    },
  }

  forecast :
    time_step: 06:00:00
    num_steps: 3
    # number of steps offset applied to the first target window; if set to zero
    # and num_steps=0 then one is training an auto-encoder
    offset: 1
    policy: "fixed"


# validation config; full validation config is merge of training and validation config
validation_config:

  samples_per_mini_epoch: 1
  shuffle: True

  start_date: 2023-10-01T00:00
  end_date: 2023-12-31T00:00

  # whether to track the exponential moving average of weights for validation
  validate_with_ema:
    enabled : True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3

  # parameters for validation samples that are written to disk
  output : {
    # number of samples that are written
    num_samples: 0,
    # write samples in normalized model space
    normalized_samples: False,
    # output streams to write; default all
    streams: null,
  }

  # run validation before training starts (mainly for model development)
  validate_before_training: False


# test config; full test config is merge of validation and test config;
# test config is used by default when running inference
test_config:

  samples_per_mini_epoch: 1
  shuffle: False

  start_date: 2023-10-01T00:00
  end_date: 2023-12-31T00:00

  # whether to track the exponential moving average of weights for validation
  validate_with_ema:
    enabled : True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3

  # parameters for validation samples that are written to disk
  output : {
    # number of samples that are written
    num_samples: 0,
    # write samples in normalized model space
    normalized_samples: False,
    # output streams to write; default all
    streams: null,
  }

  # run validation before training starts (mainly for model development)
  validate_before_training: False


# Tags for experiment tracking.
# These tags will be logged in MLFlow along with completed runs for train, eval, val.
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans); NO lists or dictionaries
# - tags should not duplicate existing config entries
# - try to reuse existing tags where possible; MLFlow does not like having too many unique tags
# - do not use long strings in values (less than 20 characters is a good rule of thumb; we may enforce this in the future)
wgtags:
  # The name of the organization of the person running the experiment.
  # This may be autofilled in the future. Expected values are lowercase strings,
  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience".
  org: null
  # The Github issue corresponding to this run (a number such as 1234).
  # Github issues are the central point when running experiments and contain
  # links to hedgedocs, code branches, pull requests etc.
  # It is recommended to associate a run with a Github issue.
  issue: null
  # The name of the experiment. This is a distinctive codename for the experiment
  # campaign being run. This is expected to be the primary tag for comparing
  # experiments in MLFlow, along with the issue number.
  # Expected values are lowercase strings with no spaces, just underscores,
  # e.g. "rollout_ablation_grid".
  exp: null
  # *** Experiment-specific tags ***
  # All extra tags (including lists, dictionaries, etc.) are treated
  # as strings by MLFlow, so treat all extra tags as simple string key: value pairs.
  grid: null
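The learning_rate_scheduling block above, together with commits 313cec6 ("At least two warmup steps") and 6fd001f ("LR scheduler lower bounds"), implies a warmup/constant/cooldown shape with sqrt scaling of the peak rate by worker count. A hedged sketch of that reading — not the repository's scheduler, and lr_final_decay is ignored here because policy_decay is "constant":

```python
import math


def lr_at_step(step: int, total_steps: int, world_size: int = 1,
               lr_start: float = 1e-6, lr_max: float = 5e-5,
               lr_final: float = 0.0, num_steps_warmup: int = 1024,
               num_steps_cooldown: int = 512) -> float:
    """One plausible reading of learning_rate_scheduling above."""
    # parallel_scaling_policy: "sqrt" -- scale the peak LR with worker count
    lr_peak = lr_max * math.sqrt(world_size)
    # commit 313cec6 enforces at least two warmup steps
    num_steps_warmup = max(num_steps_warmup, 2)
    if step < num_steps_warmup:
        # policy_warmup: "cosine" -- half-cosine ramp from lr_start to lr_peak
        frac = step / num_steps_warmup
        return lr_start + (lr_peak - lr_start) * 0.5 * (1.0 - math.cos(math.pi * frac))
    cooldown_start = max(total_steps - num_steps_cooldown, num_steps_warmup)
    if step < cooldown_start:
        # policy_decay: "constant" -- hold the peak until the cooldown begins
        return lr_peak
    # policy_cooldown: "linear" -- ramp down to lr_final at the last step
    frac = min((step - cooldown_start) / num_steps_cooldown, 1.0)
    return lr_peak + (lr_final - lr_peak) * frac
```

With the defaults shown and world_size = 16, for example, the peak learning rate would be 5e-5 * sqrt(16) = 2e-4.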
36 changes: 36 additions & 0 deletions config/streams/era5_georing/era5.yml
@@ -0,0 +1,36 @@
# (C) Copyright 2024 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

ERA5 :
  type : anemoi
  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
  stream_id : 0
  source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
  target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
  loss_weight : 1.
  location_weight : cosine_latitude
  token_size : 8
  tokenize_spacetime : True
  max_num_targets: -1
  embed :
    net : transformer
    num_tokens : 1
    num_heads : 8
    dim_embed : 256
    num_blocks : 2
  embed_target_coords :
    net : linear
    dim_embed : 512
  target_readout :
    num_layers : 2
    num_heads : 4
    # sampling_rate : 0.2
  pred_head :
    ens_size : 1
    num_layers : 1
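The main config sets streams_directory: "./config/streams/era5_georing/" and leaves streams: ??? to be resolved at runtime, so per-stream files like this era5.yml are presumably collected from that directory. A minimal sketch of such a loader, assuming PyYAML and the layout shown; the package's actual resolution logic may differ:

```python
from pathlib import Path

import yaml  # PyYAML


def load_streams(streams_directory: str = "./config/streams/era5_georing/") -> dict:
    """Merge every per-stream YAML file into one {stream_name: spec} dict."""
    streams: dict = {}
    for path in sorted(Path(streams_directory).glob("*.yml")):
        with open(path) as f:
            # era5.yml above parses to {"ERA5": {"type": "anemoi", ...}}
            spec = yaml.safe_load(f)
        streams.update(spec)
    return streams


# e.g. load_streams()["ERA5"]["token_size"] == 8
```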