Bug report: mse_ens loss function has the old signature #2250

@wael-mika

Description

What happened?

mse_ens still has the pre-refactor signature (target, ens, mu, stddev) and returns a scalar. When the loss framework executes loss, loss_chs = loss_fct(...), Python tries to unpack the scalar into two values, which raises TypeError: iteration over a 0-d tensor.

0: TypeError: iteration over a 0-d tensor
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/run_train.py", line 160, in run_continue
3:     trainer.run(cf, devices, args.from_run_id, args.mini_epoch)
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/train/trainer.py", line 377, in run
3:     self.train(mini_epoch)
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/train/trainer.py", line 458, in train
3:     loss = self.loss_calculator.compute_loss(
3:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/train/loss_calculator.py", line 94, in compute_loss
3:     loss_values = calculator.compute_loss(
3:                   ^^^^^^^^^^^^^^^^^^^^^^^^
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/train/loss_modules/loss_module_physical.py", line 304, in compute_loss
3:     loss_lfct, loss_lfct_chs = self._loss_per_loss_function(
3:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/src/weathergen/train/loss_modules/loss_module_physical.py", line 148, in _loss_per_loss_function
3:     loss, loss_chs = loss_fct(
3:     ^^^^^^^^^^^^^^
3:   File "/e/scratch/weatherai/slurm/slurm_weathergen_x92xyt41_dir/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 1154, in __iter__
3:     raise TypeError("iteration over a 0-d tensor")
3: TypeError: iteration over a 0-d tensor
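
The failure can be reproduced outside the training loop. The sketch below is illustrative only: mse_ens_old is a stand-in that mimics the pre-refactor behaviour described above (one scalar returned), mse_ens_two_values is a hypothetical variant that returns the (loss, loss_chs) pair the caller unpacks, and the tensor shapes are assumptions rather than the shapes used in WeatherGenerator.

```python
import torch


def mse_ens_old(target, ens, mu, stddev):
    """Stand-in for the pre-refactor style: returns a single 0-d tensor."""
    # Mean squared error of every ensemble member against the target.
    return torch.mean((ens - target.unsqueeze(0)) ** 2)


def mse_ens_two_values(target, ens, mu, stddev):
    """Hypothetical variant returning (loss, loss_chs), as the caller expects."""
    # Per-channel error: average over ensemble members (dim 0) and points (dim 1).
    loss_chs = torch.mean((ens - target.unsqueeze(0)) ** 2, dim=(0, 1))
    # Scalar loss: average of the per-channel values.
    return loss_chs.mean(), loss_chs


# Assumed shapes for illustration: 4 ensemble members, 8 points, 3 channels.
target = torch.randn(8, 3)
ens = torch.randn(4, 8, 3)
mu, stddev = ens.mean(dim=0), ens.std(dim=0)

old_error = None
try:
    # Same unpacking pattern as in _loss_per_loss_function.
    loss, loss_chs = mse_ens_old(target, ens, mu, stddev)
except TypeError as err:
    old_error = str(err)
    print(f"old style fails: {err}")

# Returning two values satisfies the unpacking.
loss, loss_chs = mse_ens_two_values(target, ens, mu, stddev)
print(loss.shape, loss_chs.shape)
```

Unpacking calls the tensor's __iter__, which is where the TypeError in the traceback originates, so any loss function that returns a single 0-d tensor will fail at the same line.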

What are the steps to reproduce the bug?

Set ens_size in the stream config to any value higher than 2. In the same stream config, specify the per-stream loss:

loss_fcts:
  "mse_ens":
    weight: 1.0

Then run training normally.

Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.

No response

Labels: bug
