diff --git a/config/config_era5_georing.yml b/config/config_era5_georing.yml new file mode 100644 index 000000000..7ef76620e --- /dev/null +++ b/config/config_era5_georing.yml @@ -0,0 +1,289 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 8 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. 
+# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 6 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" +qk_norm_type: null # if null, defaults to norm_type + +##################################### + +streams_directory: "./config/streams/era5_georing/" +streams: ??? 
+ +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 16 + metrics: 16 + checkpoint: 256 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2014-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 01:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 1024 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.975 + beta2 : 0.9875 + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + num_steps: 3 + offset: 1 + policy: "fixed" + + +# 
validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 1 + shuffle: True + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +test_config: + + samples_per_mini_epoch: 1 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. 
+# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/streams/era5_georing/era5.yml b/config/streams/era5_georing/era5.yml new file mode 100644 index 000000000..186b8d862 --- /dev/null +++ b/config/streams/era5_georing/era5.yml @@ -0,0 +1,36 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_georing/geos.yml b/config/streams/era5_georing/geos.yml new file mode 100644 index 000000000..e6cc89442 --- /dev/null +++ b/config/streams/era5_georing/geos.yml @@ -0,0 +1,77 @@ + +METEOSAT_SEVIRI : + type : obs + stream_id : 2 + # filenames : ['observations-od-ai-0001-201311-202505-msg-combined-seviri-o256-v1.zarr'] + filenames : ['observations-file-2014-2024-seviri-h512-v5.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +GOES_ABI : + type : obs + stream_id : 3 + # filenames : ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] + filenames : ['observations-file-2017-2024-abi-goes16-IR-h512-v2.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +HIMAWARI_AHI : + type : obs + stream_id : 4 + # filenames : 
['observations-file-2015-2022-himawari8-IR-o256-v1.zarr'] + filenames : ['observations-file-2015-2022-himawari8-IR-h512-v1.zarr'] + loss_weight : 1.0 + token_size : 128 + tokenize_spacetime : True + max_num_targets: 65536 + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/packages/performance/README.md b/packages/performance/README.md new file mode 100644 index 000000000..af1554d94 --- /dev/null +++ b/packages/performance/README.md @@ -0,0 +1,43 @@ +# WeatherGenerator Performance Analysis Tools + +This package contains tools for extracting and analyzing scaling performance data from WeatherGenerator training runs. + +## Installation + +Install the optional performance tools: + +```bash +uv sync --extra performance +``` + +## Scripts + +### extract_scaling_data.py + +Extracts scaling metrics from WeatherGenerator training runs and writes parquet output. + +```bash +extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet +``` + +### generate_scaling_plots.py + +Generates scaling plots and tables from parquet/NDJSON data using named columns from the input files. + +```bash +generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log +``` + +## Suggested workflow + +1. Extract the scaling data into a parquet file (on your HPC). +2. Copy the parquet file to your local machine. +3. Generate plots from the parquet file. + +Example: + +```bash +extract_scaling_data --run-ids RUN_ID1 RUN_ID2 --output scaling.parquet +scp user@remote:/path/to/scaling.parquet . 
+generate_scaling_plots standard --input scaling.parquet --type strong --y-scale log +``` diff --git a/packages/performance/pyproject.toml b/packages/performance/pyproject.toml new file mode 100644 index 000000000..933323e3f --- /dev/null +++ b/packages/performance/pyproject.toml @@ -0,0 +1,35 @@ +[project] +name = "weathergen-performance" +version = "0.1.0" +description = "Performance analysis tools for WeatherGenerator" +readme = "README.md" + +requires-python = ">=3.12,<3.13" +dependencies = [ + "polars~=1.25.2", + "pandas~=2.2", + "matplotlib", + "pyarrow>=23.0.1", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/performance"] + +[project.scripts] +extract_scaling_data = "performance.extract_scaling_data:main" +generate_scaling_plots = "performance.generate_scaling_plots:main" + +[tool.ruff.lint] +select = [ + "E", + "F", + "UP", + "B", + "SIM", + "I", + "N" +] diff --git a/packages/performance/src/performance/__init__.py b/packages/performance/src/performance/__init__.py new file mode 100644 index 000000000..7105b9746 --- /dev/null +++ b/packages/performance/src/performance/__init__.py @@ -0,0 +1 @@ +"""WeatherGenerator performance analysis tools.""" diff --git a/packages/performance/src/performance/extract_scaling_data.py b/packages/performance/src/performance/extract_scaling_data.py new file mode 100644 index 000000000..7af725ef6 --- /dev/null +++ b/packages/performance/src/performance/extract_scaling_data.py @@ -0,0 +1,311 @@ +#!/usr/bin/env uv run python +"""Extract strong scaling data from WeatherGenerator runs. + +Outputs parquet with: +- run_id, num_nodes, training_time +- overall_time_seconds, loss_avg_mean +""" + +import argparse +import re +import sys +from pathlib import Path + +import pandas as pd + + +def extract_num_nodes_from_output(run_id: str, logs_base_dir: Path) -> int | None: + """Extract num_nodes from output.*.txt file in the run directory. 
+ + Looks for 'nNodes' pattern in output files. + """ + run_log_dir = logs_base_dir / run_id + if not run_log_dir.exists(): + return None + + # Look for output.*.txt files + output_files = list(run_log_dir.glob("output.*.txt")) + if not output_files: + # Fallback to err files if no output files found + output_files = list(run_log_dir.glob("weathergen.*.err")) + + for output_file in output_files: + try: + content = output_file.read_text() + # Look for nNodes pattern: "nNodes 128" (space-separated, as in NCCL logs) + match = re.search(r"nNodes\s+(\d+)", content, re.IGNORECASE) + if match: + return int(match.group(1)) + except Exception: + continue + + return None + + +def extract_metrics_from_run_id(run_id: str, shared_work_dir: Path) -> dict | None: + """Extract metrics from NDJSON file with startup and training lines. + + Format: + - Line 1: startup_time_seconds + - Line 2+: loss_avg_mean, LossPhysical.loss_avg, etc. + """ + metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" + if not metrics_path.exists(): + return None + try: + df = pd.read_json(metrics_path, lines=True) + if len(df) == 0: + return None + + # Extract startup_time from first row (startup line) + startup_time = None + if "startup_time_seconds" in df.columns: + val = df["startup_time_seconds"].dropna() + startup_time = val.iloc[0] if len(val) > 0 else None + + # Extract loss_avg_mean from last non-NaN training row + loss_avg_mean = None + if "loss_avg_mean" in df.columns: + val = df["loss_avg_mean"].dropna() + loss_avg_mean = val.iloc[-1] if len(val) > 0 else None + + # Extract training time for mini-epoch from last non-NaN row + overall_training_time = None + if "training_time_after_mini_epoch_seconds" in df.columns: + val = df["training_time_after_mini_epoch_seconds"].dropna() + overall_training_time = val.iloc[-1] if len(val) > 0 else None + + return { + "startup_time_seconds": startup_time, + "training_time": overall_training_time, + "loss_avg_mean": 
loss_avg_mean, + } + except Exception: + return None + + +def extract_detailed_metrics( + run_id: str, shared_work_dir: Path, num_nodes: int | None = None +) -> list[pd.DataFrame]: + """Extract detailed metrics pairing timing rows with preceding loss rows. + + For each row containing elapsed_training_time_seconds, pair it with the + preceding row containing loss metrics. Returns a list of DataFrames. + """ + metrics_path = shared_work_dir / "results" / run_id / f"{run_id}_train_metrics.json" + if not metrics_path.exists(): + return [] + + try: + df = pd.read_json(metrics_path, lines=True) + if len(df) == 0: + return [] + + # Find rows with elapsed_training_time_seconds (timing rows) + if "elapsed_training_time_seconds" not in df.columns: + return [] + timing_indices = df.index[df["elapsed_training_time_seconds"].notna()].tolist() + + if not timing_indices: + return [] + + # Find rows with loss data + if "loss_avg_mean" not in df.columns: + return [] + loss_indices = set(df.index[df["loss_avg_mean"].notna()].tolist()) + + timing_cols = [ + "elapsed_training_time_seconds", + "total_num_samples", + "average_samples_per_second", + ] + + detailed_records = [] + + for timing_idx in timing_indices: + # Find the last loss row before this timing row + loss_rows_before = [i for i in loss_indices if i < timing_idx] + if not loss_rows_before: + continue + + last_loss_idx = max(loss_rows_before) + + # Build record dict from loss row + timing row + record = {"run_id": run_id} + if num_nodes is not None: + record["num_nodes"] = num_nodes + + record["loss_avg_mean"] = df.at[last_loss_idx, "loss_avg_mean"] + + for col in timing_cols: + if col in df.columns: + record[col] = df.at[timing_idx, col] + + detailed_records.append(pd.DataFrame([record])) + + return detailed_records + + except Exception as e: + print(f"Error extracting detailed metrics for {run_id}: {e}") + import traceback + + traceback.print_exc() + return [] + + +def parse_run_ids(run_ids_str: list[str]) -> 
list[tuple[int | None, str]]: + """Parse run-ids argument which can be: + + 1. A list of run-ids (old format): + ["run1", "run2"] -> [(None, "run1"), (None, "run2")] + 2. A dict mapping num_nodes to run-ids (new format): + "{1: run1, 4: run2}" -> [(1, "run1"), (4, "run2")] + + Returns list of (num_nodes, run_id) tuples. + """ + if len(run_ids_str) == 1: + # Check if it looks like a dict: "{key: value, ...}" + stripped = run_ids_str[0].strip() + if stripped.startswith("{") and stripped.endswith("}"): + # Parse as dict format: {num_nodes: run_id, ...} + import ast + + try: + parsed = ast.literal_eval(stripped) + if isinstance(parsed, dict): + # Convert string keys to int if needed + result = [] + for k, v in parsed.items(): + key = int(k) if isinstance(k, str) and k.isdigit() else k + result.append((key, str(v))) + return result + except (ValueError, SyntaxError): + pass + + # Single run-id or comma-separated list + run_ids = [r.strip() for r in run_ids_str[0].split(",") if r.strip()] + return [(None, run_id) for run_id in run_ids] + + # Multiple arguments - treat as list of run-ids + return [(None, run_id) for run_id in run_ids_str] + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Extract strong scaling data from WeatherGenerator runs. " + "Run-ids can be provided as a list (--run-ids run1 run2) or as a dict " + "mapping num_nodes to run-ids (--run-ids '{1: run1, 4: run2}'). " + "If num_nodes is not provided in the dict, it will be extracted " + "from output.*.txt files." + ) + ) + parser.add_argument( + "--run-ids", + nargs="+", + help=( + "Run-ids to process. 
Can be: (1) list: run1 run2 run3, or " + "(2) dict: '{1: run1, 4: run2, 8: run3}'" + ), + ) + parser.add_argument( + "--logs-base-dir", + type=Path, + default=Path("logs"), + help="Base directory for run logs (default: logs relative to current dir)", + ) + parser.add_argument( + "--shared-work-dir", + type=Path, + default=Path("/e/scratch/weatherai/shared_work"), + help="Base directory for shared work/results", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("scaling_data.parquet"), + help="Output parquet file path", + ) + + args = parser.parse_args() + + run_id_mapping = parse_run_ids(args.run_ids) + if not run_id_mapping: + sys.exit("Error: No run-ids provided") + + results = [] + all_detailed_records = [] + + for num_nodes, run_id in run_id_mapping: + # If num_nodes not provided, extract from output.*.txt file + if num_nodes is None: + num_nodes = extract_num_nodes_from_output(run_id, args.logs_base_dir) + + metrics = extract_metrics_from_run_id(run_id, args.shared_work_dir) + if metrics is None: + continue + + row = { + "run_id": run_id, + "num_nodes": num_nodes, + "startup_time_seconds": metrics.get("startup_time_seconds"), + "training_time": metrics.get("training_time"), + "loss_avg_mean": metrics.get("loss_avg_mean"), + } + results.append(row) + + detailed_records = extract_detailed_metrics( + run_id, args.shared_work_dir, num_nodes + ) + if detailed_records: + all_detailed_records.extend(detailed_records) + print( + f"Extracted {len(detailed_records)} detailed metric entries " + f"for {run_id} ({num_nodes} nodes)" + ) + + if not results: + sys.exit("No data extracted") + + df = pd.DataFrame(results) + if "num_nodes" in df.columns: + df = df.sort_values("num_nodes").reset_index(drop=True) + + args.output.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(args.output, index=False) + df.to_csv(args.output.with_suffix(".csv"), index=False) + + # Write detailed metrics if any were collected + if all_detailed_records: + detailed_df = 
pd.concat(all_detailed_records, ignore_index=True) + + desired_cols = [ + "run_id", + "num_nodes", + "elapsed_training_time_seconds", + "total_num_samples", + "average_samples_per_second", + "loss_avg_mean", + ] + available_cols = [c for c in desired_cols if c in detailed_df.columns] + detailed_df = detailed_df[available_cols] + + output_stem = args.output.stem + detailed_output = args.output.with_name( + f"{output_stem}_detailed{args.output.suffix}" + ) + + detailed_df.to_parquet(detailed_output, index=False) + detailed_df.to_csv(detailed_output.with_suffix(".csv"), index=False) + + print("\nSummary:") + print(f" - Extracted {len(results)} run summaries to {args.output}") + if all_detailed_records: + print( + f" - Extracted {len(all_detailed_records)} detailed metric entries " + f"to {detailed_output}" + ) + + +if __name__ == "__main__": + main() diff --git a/packages/performance/src/performance/generate_scaling_plots.py b/packages/performance/src/performance/generate_scaling_plots.py new file mode 100644 index 000000000..b33073b82 --- /dev/null +++ b/packages/performance/src/performance/generate_scaling_plots.py @@ -0,0 +1,999 @@ +#!/usr/bin/env uv run python +"""Generate scaling plots from parquet/ndjson data using matplotlib only. 
+ +Entry points: +- standard: plots run-level metrics vs num_nodes +- detailed: plots sample-level metrics vs total_num_samples +- combined: generates a comparison table from separate strong and weak + scaling input files + +Usage: + # Single scaling type (original behavior) + python -m performance.generate_scaling_plots standard --type strong \ + --input strong_data.parquet + + # Combined table from single file with both types + python -m performance.generate_scaling_plots standard \ + --type strong,weak --input data.parquet + + # Combined table from separate strong and weak input files (new) + python -m performance.generate_scaling_plots combined \ + --strong-input strong_data.parquet \ + --weak-input weak_data.parquet + + # Loss plot + python -m performance.generate_scaling_plots loss --type strong \ + --input data.parquet + + # Detailed scaling plot + python -m performance.generate_scaling_plots detailed \ + --input detailed_data.parquet +""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import polars as pl + +SCRIPT_DIR = Path(__file__).resolve().parent +VALID_IMAGE_SUFFIXES = {".png", ".pdf", ".svg", ".jpg", ".jpeg"} +PALETTE = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", +] + + +def resolve_input_path(path: Path) -> Path: + """Resolve relative input paths against cwd first, then the script directory.""" + if path.is_absolute(): + return path + + cwd_candidate = Path.cwd() / path + if cwd_candidate.exists(): + return cwd_candidate + + script_candidate = SCRIPT_DIR / path + if script_candidate.exists(): + return script_candidate + + return cwd_candidate + + +def resolve_output_path(path: Path) -> Path: + """Ensure the output path uses a supported image suffix.""" + if path.suffix.lower() in VALID_IMAGE_SUFFIXES: + return path + return path.with_suffix(".png") + + +def read_table(path: Path) -> pl.DataFrame: + """Read parquet or ndjson 
automatically.""" + try: + print("Read as parquet") + return pl.read_parquet(path) + except Exception: + print("Read as NDJSON") + return pl.read_ndjson(path) + + +def color_map_for_nodes(node_counts: list) -> dict: + return {node: PALETTE[i % len(PALETTE)] for i, node in enumerate(node_counts)} + + +def save_figure(fig: plt.Figure, output_path: Path) -> None: + output_path = resolve_output_path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {output_path}") + + +def generate_scaling_table( + df: pl.DataFrame, + input_path: Path, + show_run_ids: bool = False, + scaling_types: list[str] = None, +) -> None: + """Generate a PNG table image with scaling metrics from the parquet file. + + Columns: num_nodes, training_time, ideal_time, efficiency (optionally run_id) + If scaling_types has multiple types, generates a combined table with + columns per type. + """ + # Check if required columns exist + if "num_nodes" not in df.columns or "training_time" not in df.columns: + print("Warning: Required columns (num_nodes, training_time) not found in data") + return + + # Filter out rows with null values in required columns + df_filtered = df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + if len(df_filtered) == 0: + print("No valid data for scaling table") + return + + # Get the 1-node training time for ideal time calculation + one_node_data = df_filtered.filter(pl.col("num_nodes") == 1) + if one_node_data.height == 0: + print("Warning: No 1-node data found for ideal time calculation") + return + + t1 = one_node_data["training_time"].item() + + # Determine scaling types to include + if scaling_types is None or len(scaling_types) == 0: + # Derive scaling type from input filename + input_name_lower = input_path.name.lower() + if "weak" in input_name_lower: + scaling_types = ["weak"] + elif "strong" in 
input_name_lower: + scaling_types = ["strong"] + else: + scaling_types = ["strong"] # Default to strong + + # Build table data with proper formatting + has_run_id = "run_id" in df_filtered.columns + + # Check if we're generating a combined table (multiple types) + is_combined = len(scaling_types) > 1 + + if is_combined: + # Combined table: columns per type + col_names = ["# Nodes"] + for stype in scaling_types: + col_names.extend( + [ + f"{stype.capitalize()} Training Time (seconds)", + f"{stype.capitalize()} Efficiency", + ] + ) + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + else: + # Single type table (original format) + scaling_type = scaling_types[0].capitalize() + col_names = [ + "# Nodes", + "Training Time (seconds)", + "Ideal Time (seconds)", + "Efficiency", + ] + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + + table_rows = [] + for row in df_filtered.iter_rows(named=True): + num_nodes = row["num_nodes"] + training_time = row["training_time"] + + if is_combined: + # Combined table: add metrics for each type + row_data = {} + if show_run_ids and has_run_id: + row_data["run_id"] = str(row.get("run_id", "")) + row_data["# Nodes"] = str(num_nodes) + + for stype in scaling_types: + time_col = f"{stype.capitalize()} Training Time (seconds)" + eff_col = f"{stype.capitalize()} Efficiency" + row_data[time_col] = f"{training_time:.2f}" + if num_nodes == 1: + row_data[eff_col] = "-" + else: + if stype == "strong": + # Strong scaling: ideal time = t1 / num_nodes + ideal_val = t1 / num_nodes + efficiency_val = ideal_val / training_time + else: + # Weak scaling: ideal time = t1 (same work per node) + ideal_val = t1 + efficiency_val = min(1.0, t1 / training_time) + + row_data[eff_col] = f"{efficiency_val:.2f}" + else: + # Single type table (original format) + scaling_type = scaling_types[0] + row_data = {} + if show_run_ids and has_run_id: + row_data["run_id"] = str(row.get("run_id", "")) + row_data["# Nodes"] = str(num_nodes) + 
row_data["Training Time (seconds)"] = f"{training_time:.2f}" + if num_nodes == 1: + row_data["Ideal Time (seconds)"] = "-" + row_data["Efficiency"] = "-" + else: + if scaling_type == "strong": + # Strong scaling: ideal time = t1 / num_nodes + ideal_val = t1 / num_nodes + efficiency_val = ideal_val / training_time + else: + # Weak scaling: ideal time = t1 (same work per node) + ideal_val = t1 + efficiency_val = min(1.0, t1 / training_time) + + row_data["Ideal Time (seconds)"] = f"{ideal_val:.2f}" + row_data["Efficiency"] = f"{efficiency_val:.2f}" + + table_rows.append(row_data) + + # Generate output filename: input_stem_table.csv + output_path = input_path.with_name(input_path.stem + "_table.csv") + + # Build DataFrame for CSV output from named columns. + df_table = pl.DataFrame( + [{col: row.get(col, "-") for col in col_names} for row in table_rows] + ) + + # Write to CSV + df_table.write_csv(output_path) + print(f"Saved scaling table: {output_path}") + + +def generate_combined_scaling_table( + strong_df: pl.DataFrame, + weak_df: pl.DataFrame, + strong_path: Path, + weak_path: Path, + output_path: Path, + show_run_ids: bool = False, +) -> None: + """Generate a combined table comparing strong and weak scaling from two + separate input files. + + Rows: num_nodes + Columns: # Nodes, Strong Training Time, Strong Efficiency, + Weak Training Time, Weak Efficiency + + Also generates a PNG visualization of the table. 
+ """ + # Validate required columns + for name, df in [("strong", strong_df), ("weak", weak_df)]: + if "num_nodes" not in df.columns or "training_time" not in df.columns: + print( + f"Warning: Required columns (num_nodes, training_time) not found " + f"in {name} data" + ) + return + + # Filter and sort both datasets + strong_filtered = strong_df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + weak_filtered = weak_df.filter( + pl.col("num_nodes").is_not_null() & pl.col("training_time").is_not_null() + ).sort("num_nodes") + + if len(strong_filtered) == 0 or len(weak_filtered) == 0: + print("No valid data for combined scaling table") + return + + # Get 1-node training times for efficiency calculation + strong_one_node = strong_filtered.filter(pl.col("num_nodes") == 1) + weak_one_node = weak_filtered.filter(pl.col("num_nodes") == 1) + + if strong_one_node.height == 0 or weak_one_node.height == 0: + print("Warning: No 1-node data found for efficiency calculation") + return + + t1_strong = strong_one_node["training_time"].item() + t1_weak = weak_one_node["training_time"].item() + + # Check for run_id in either dataset + has_run_id = ( + "run_id" in strong_filtered.columns or "run_id" in weak_filtered.columns + ) + + # Build column names + col_names = [ + "# Nodes", + "Strong Training Time (seconds)", + "Strong Efficiency", + "Weak Training Time (seconds)", + "Weak Efficiency", + ] + if show_run_ids and has_run_id: + col_names.insert(0, "run_id") + + # Get all unique num_nodes from both datasets + all_nodes = sorted( + set(strong_filtered["num_nodes"].to_list()) + | set(weak_filtered["num_nodes"].to_list()) + ) + + # Create lookup dictionaries for easy access + strong_lookup = { + row["num_nodes"]: row["training_time"] + for row in strong_filtered.iter_rows(named=True) + } + weak_lookup = { + row["num_nodes"]: row["training_time"] + for row in weak_filtered.iter_rows(named=True) + } + strong_run_id_lookup = ( + 
{ + row["num_nodes"]: row["run_id"] + for row in strong_filtered.iter_rows(named=True) + } + if "run_id" in strong_filtered.columns + else {} + ) + weak_run_id_lookup = ( + {row["num_nodes"]: row["run_id"] for row in weak_filtered.iter_rows(named=True)} + if "run_id" in weak_filtered.columns + else {} + ) + + table_rows = [] + for num_nodes in all_nodes: + row_data = {} + + # Get run_id if available + if show_run_ids and has_run_id: + run_id = str( + strong_run_id_lookup.get( + num_nodes, weak_run_id_lookup.get(num_nodes, "") + ) + ) + row_data["run_id"] = run_id + + # Add num_nodes + row_data["# Nodes"] = str(num_nodes) + + # Strong scaling metrics + if num_nodes in strong_lookup: + training_time_strong = strong_lookup[num_nodes] + row_data["Strong Training Time (seconds)"] = f"{training_time_strong:.2f}" + if num_nodes == 1: + row_data["Strong Efficiency"] = "-" + else: + ideal_strong = t1_strong / num_nodes + row_data["Strong Efficiency"] = ( + f"{ideal_strong / training_time_strong:.2f}" + ) + else: + row_data["Strong Training Time (seconds)"] = "-" + row_data["Strong Efficiency"] = "-" + + # Weak scaling metrics + if num_nodes in weak_lookup: + training_time_weak = weak_lookup[num_nodes] + row_data["Weak Training Time (seconds)"] = f"{training_time_weak:.2f}" + if num_nodes == 1: + row_data["Weak Efficiency"] = "-" + else: + ideal_weak = t1_weak # Weak scaling: ideal is same as 1-node time + row_data["Weak Efficiency"] = ( + f"{min(1.0, ideal_weak / training_time_weak):.2f}" + ) + else: + row_data["Weak Training Time (seconds)"] = "-" + row_data["Weak Efficiency"] = "-" + + table_rows.append(row_data) + + # Ensure output path has .csv suffix + if output_path.suffix.lower() != ".csv": + output_path = output_path.with_suffix(".csv") + + # Build DataFrame for CSV output from named columns. 
def _one_node_training_time(df: pl.DataFrame) -> float | None:
    """Return the ``training_time`` of the 1-node run in *df*, or ``None``.

    Rows whose ``training_time`` is null are ignored. When several 1-node rows
    remain, the first one (in input order) is used and a warning is printed,
    rather than failing on an ambiguous scalar extraction.
    """
    if "num_nodes" not in df.columns or "training_time" not in df.columns:
        return None
    baselines = df.filter(
        (pl.col("num_nodes") == 1) & pl.col("training_time").is_not_null()
    )["training_time"].to_list()
    if not baselines:
        return None
    if len(baselines) > 1:
        print(
            f"Warning: Found {len(baselines)} 1-node runs; using the first one "
            "as scaling baseline"
        )
    return baselines[0]


def _optimal_scaling_values(
    nodes: list, t1: float, scaling_type: str, y_metric: str
) -> list[float]:
    """Return the ideal-scaling reference value for every entry of *nodes*.

    Raises:
        ValueError: if *scaling_type* is neither ``"strong"`` nor ``"weak"``
            and *y_metric* is not ``"efficiency"``.
    """
    if y_metric == "efficiency":
        # Perfect efficiency is 1.0 regardless of the scaling type.
        return [1.0 for _ in nodes]
    if scaling_type == "weak":
        # Weak scaling: ideal run time stays at the 1-node time, i.e. a
        # normalized throughput of 1.0.
        ideal = 1.0 if y_metric == "normalized_throughput" else t1
        return [ideal for _ in nodes]
    if scaling_type == "strong":
        if y_metric == "normalized_throughput":
            # Linear speedup: n nodes are n times faster.
            return [float(n) for n in nodes]
        return [t1 / n for n in nodes]
    raise ValueError(f"Invalid scaling type: {scaling_type}")


def plot_standard_scaling(
    df: pl.DataFrame,
    output_path: Path,
    scaling_type: str,
    metrics: list[str],
    x_scale: str,
    y_scale: str,
    y_metric: str,
    show_run_ids: bool = False,
) -> None:
    """Plot run-level scaling data vs num_nodes.

    One subplot is drawn for every entry of *metrics* that is a column of *df*
    with at least one non-null value. For the ``training_time`` metric the
    plotted quantity depends on *y_metric*, and a dashed "optimal scaling"
    reference with per-point ratio annotations is added when a 1-node
    baseline exists.

    Args:
        df: Run-level table; must contain ``num_nodes``.
        output_path: Destination handed to ``save_figure``.
        scaling_type: ``"strong"`` or ``"weak"``; selects the efficiency
            formula and the optimal-scaling reference.
        metrics: Candidate column names to plot.
        x_scale: Matplotlib scale name applied to the x axis.
        y_scale: ``"log"`` switches the y axis to log scale; any other value
            keeps the default scale.
        y_metric: ``"time"``, ``"normalized_throughput"`` (T1 / T) or
            ``"efficiency"``.
        show_run_ids: Annotate each point with its ``run_id`` when that
            column is available.

    Raises:
        ValueError: if *scaling_type* is invalid and an optimal-scaling
            reference other than efficiency has to be drawn.
    """
    metric_labels = {
        "training_time": "Training Time (seconds)",
        "loss_avg_mean": "Average Loss",
        "normalized_throughput": "Speedup",
        "efficiency": "Scaling Efficiency",
    }

    valid_metrics = [
        m
        for m in metrics
        if m in df.columns and df.filter(pl.col(m).is_not_null()).height > 0
    ]
    if not valid_metrics:
        print("No valid metrics to plot")
        return
    if "num_nodes" not in df.columns:
        print("Warning: Required column (num_nodes) not found in data")
        return

    # Only label points when the column really exists; warn once otherwise
    # instead of failing in the middle of plotting.
    label_runs = show_run_ids and "run_id" in df.columns
    if show_run_ids and not label_runs:
        print("Warning: run_id column not found; run_id labels are skipped")

    # The 1-node baseline is independent of the subplot, so resolve it once.
    t1 = _one_node_training_time(df) if "training_time" in valid_metrics else None

    fig, axes = plt.subplots(
        len(valid_metrics), 1, figsize=(12, 6 * len(valid_metrics)), squeeze=False
    )

    for idx, metric in enumerate(valid_metrics):
        ax = axes[idx][0]
        # Rows without a node count cannot be placed on the x axis.
        df_plot = df.filter(
            pl.col(metric).is_not_null() & pl.col("num_nodes").is_not_null()
        ).sort("num_nodes")

        # training_time may be shown as a derived quantity (speedup/efficiency).
        derived = metric == "training_time" and y_metric in (
            "normalized_throughput",
            "efficiency",
        )
        if derived:
            if t1 is None:
                if y_metric == "normalized_throughput":
                    print(
                        "Warning: No 1-node data found for normalized throughput "
                        "calculation"
                    )
                else:
                    print("Warning: No 1-node data found for efficiency calculation")
                continue
            if y_metric == "normalized_throughput":
                # Normalized throughput: T1 / T
                derived_expr = t1 / pl.col("training_time")
            elif scaling_type == "strong":
                # Strong scaling: efficiency = (t1 / num_nodes) / training_time
                derived_expr = (t1 / pl.col("num_nodes")) / pl.col("training_time")
            else:
                # Weak scaling: efficiency = min(1.0, t1 / training_time)
                derived_expr = pl.min_horizontal(
                    pl.lit(1.0), t1 / pl.col("training_time")
                )
            df_plot = df_plot.with_columns(derived_expr.alias(y_metric))
            plot_y = df_plot[y_metric]
        else:
            plot_y = df_plot[metric]

        nodes = df_plot["num_nodes"].to_list()
        y_values = plot_y.to_list()

        ax.plot(
            df_plot["num_nodes"],
            plot_y,
            "o-",
            color="steelblue",
            markersize=8,
        )

        if label_runs:
            for x, y, label in zip(
                nodes, y_values, df_plot["run_id"].to_list(), strict=False
            ):
                ax.text(x, y, str(label), ha="center", va="bottom", fontsize=8)

        if (
            metric == "training_time"
            and y_metric in ("time", "normalized_throughput", "efficiency")
            and t1 is not None
        ):
            optimal_y = _optimal_scaling_values(nodes, t1, scaling_type, y_metric)
            ax.plot(nodes, optimal_y, "r--", linewidth=1, label="Optimal scaling")

            # Show the gap to the optimum per point: a dotted vertical line
            # plus the measured/optimal ratio, both in plotted units.
            for x, y, y_opt in zip(nodes, y_values, optimal_y, strict=False):
                if y_opt == 0:
                    continue
                factor = y / y_opt
                ax.vlines(
                    x,
                    y_opt,
                    y,
                    colors="gray",
                    linestyles=":",
                    linewidth=1,
                    alpha=0.7,
                )
                y_mid = (y + y_opt) / 2
                ax.annotate(
                    f"{factor:.2f}",
                    xy=(x, y_mid),
                    xytext=(4, 0),
                    textcoords="offset points",
                    fontsize=14,
                    fontweight="bold",
                    color="dimgray",
                    va="center",
                )
            ax.legend()

        ax.set_xscale(x_scale)
        if y_scale == "log":
            ax.set_yscale("log")
        ax.set_xlabel("Number of Nodes", fontsize=16)
        # Derived plots are labelled by the derived quantity, not the raw metric.
        y_key = y_metric if derived else metric
        ax.set_ylabel(metric_labels.get(y_key, y_key), fontsize=16)
        ax.tick_params(axis="both", which="major", labelsize=14)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_figure(fig, output_path)
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with its four sub-commands.

    The chosen sub-command is stored in ``mode`` and is mandatory:

    * ``standard`` - run-level metrics vs node count (``--type`` required).
    * ``loss`` - loss metrics vs node count (``--type`` required).
    * ``combined`` - strong/weak comparison table from two input files.
    * ``detailed`` - sample-level metrics.
    """

    def add_run_level_io(sub: argparse.ArgumentParser) -> None:
        # Options shared verbatim by the run-level sub-commands.
        sub.add_argument(
            "--type",
            required=True,
            help=(
                "Scaling type(s): 'strong', 'weak', or 'strong,weak' "
                "for combined table"
            ),
        )
        sub.add_argument(
            "--input",
            type=Path,
            default=Path("scaling_data.parquet"),
            help="Input parquet/ndjson file",
        )
        sub.add_argument(
            "--output", type=Path, default=None, help="Output image path"
        )

    def add_axis_scales(sub: argparse.ArgumentParser, y_default: str) -> None:
        # Only the y-axis default differs between sub-commands.
        sub.add_argument(
            "--y-scale",
            choices=["linear", "log"],
            default=y_default,
            help="Y-axis scale",
        )
        sub.add_argument(
            "--x-scale",
            choices=["linear", "log"],
            default="log",
            help="X-axis scale",
        )

    def add_run_id_flag(sub: argparse.ArgumentParser, help_text: str) -> None:
        sub.add_argument("--show-run-ids", action="store_true", help=help_text)

    root = argparse.ArgumentParser(
        description="Generate scaling plots from parquet or NDJSON data"
    )
    commands = root.add_subparsers(dest="mode", required=True)

    # standard: run-level metrics, linear y axis by default
    standard_cmd = commands.add_parser(
        "standard", help="Plot run-level scaling metrics vs num_nodes"
    )
    add_run_level_io(standard_cmd)
    add_axis_scales(standard_cmd, "linear")
    standard_cmd.add_argument(
        "--y-metric",
        choices=["time", "normalized_throughput", "efficiency"],
        default="normalized_throughput",
        help=(
            "Y-axis metric: 'time' for time-to-solution, "
            "'normalized_throughput' for T1/T, or 'efficiency' for scaling efficiency"
        ),
    )
    add_run_id_flag(
        standard_cmd, "Show run_id labels on the plot and in the output table"
    )

    # loss: same inputs as standard, log y axis by default, no --y-metric
    loss_cmd = commands.add_parser(
        "loss", help="Plot loss metrics vs num_nodes (separate from throughput)"
    )
    add_run_level_io(loss_cmd)
    add_axis_scales(loss_cmd, "log")
    add_run_id_flag(loss_cmd, "Show run_id labels on the plot and in the output table")

    # combined: two mandatory inputs, table output
    combined_cmd = commands.add_parser(
        "combined",
        help=(
            "Generate combined table comparing strong and weak scaling "
            "from separate input files"
        ),
    )
    for flag, label in (("--strong-input", "strong"), ("--weak-input", "weak")):
        combined_cmd.add_argument(
            flag,
            type=Path,
            required=True,
            help=f"Input parquet/ndjson file for {label} scaling",
        )
    combined_cmd.add_argument(
        "--output", type=Path, default=None, help="Output table path (CSV)"
    )
    add_run_id_flag(combined_cmd, "Show run_id labels in the output table")

    # detailed: sample-level data with its own default input file
    detailed_cmd = commands.add_parser(
        "detailed", help="Plot sample-level detailed scaling metrics"
    )
    detailed_cmd.add_argument(
        "--input",
        type=Path,
        default=Path("scaling_data_detailed.parquet"),
        help="Input detailed parquet/ndjson file",
    )
    detailed_cmd.add_argument(
        "--output", type=Path, default=None, help="Output image path"
    )
    add_axis_scales(detailed_cmd, "log")

    return root
Use 'strong', " + "'weak', or 'strong,weak'" + ) + return + + # Standard mode: only plot training_time with normalized throughput or time + metrics_to_plot = ["training_time"] + # Use the first type for plotting (or strong if combined) + plot_type = scaling_types[0] + plot_standard_scaling( + df, + output_path, + plot_type, + metrics_to_plot, + args.x_scale, + args.y_scale, + args.y_metric, + args.show_run_ids, + ) + # Generate scaling table + generate_scaling_table( + df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types + ) + return + + if args.mode == "loss": + input_path = resolve_input_path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return + + output_path = args.output or input_path.with_suffix(".loss.png") + + print(f"Loading data from: {input_path}") + try: + df = read_table(input_path) + except Exception as e: + print("Error: Could not read input file as parquet or NDJSON") + print(str(e)) + return + print(f"Loaded {len(df)} rows") + + # Parse scaling types from --type argument + scaling_types = [t.strip().lower() for t in args.type.split(",")] + for stype in scaling_types: + if stype not in ("strong", "weak"): + print( + f"Error: Invalid scaling type '{stype}'. 
Use 'strong', " + "'weak', or 'strong,weak'" + ) + return + + # Loss mode: only plot loss_avg_mean + metrics_to_plot = ["loss_avg_mean"] + # Use the first type for plotting (or strong if combined) + plot_type = scaling_types[0] + plot_standard_scaling( + df, + output_path, + plot_type, + metrics_to_plot, + args.x_scale, + args.y_scale, + "time", + args.show_run_ids, + ) + # Generate scaling table + generate_scaling_table( + df, input_path, show_run_ids=args.show_run_ids, scaling_types=scaling_types + ) + return + + if args.mode == "detailed": + input_path = resolve_input_path(args.input) + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + return + + output_path = args.output or input_path.with_suffix(".png") + + print(f"Loading detailed data from: {input_path}") + try: + df = read_table(input_path) + except Exception as e: + print("Error: Could not read detailed file as parquet or NDJSON") + print(str(e)) + return + print(f"Loaded {len(df)} detailed rows") + plot_detailed_scaling(df, output_path, args.x_scale, args.y_scale) + return + + if args.mode == "combined": + strong_path = resolve_input_path(args.strong_input) + weak_path = resolve_input_path(args.weak_input) + + if not strong_path.exists(): + print(f"Error: Strong scaling input file not found: {strong_path}") + return + if not weak_path.exists(): + print(f"Error: Weak scaling input file not found: {weak_path}") + return + + # Determine output path + if args.output: + output_path = args.output + if output_path.suffix.lower() not in VALID_IMAGE_SUFFIXES: + output_path = output_path.with_suffix(".csv") + else: + # Default output: strong_input_stem_combined_table.csv + output_path = strong_path.with_name( + strong_path.stem + "_combined_table.csv" + ) + + print(f"Loading strong scaling data from: {strong_path}") + try: + strong_df = read_table(strong_path) + except Exception as e: + print("Error: Could not read strong scaling input file") + print(str(e)) + return + 
print(f"Loaded {len(strong_df)} strong scaling rows") + + print(f"Loading weak scaling data from: {weak_path}") + try: + weak_df = read_table(weak_path) + except Exception as e: + print("Error: Could not read weak scaling input file") + print(str(e)) + return + print(f"Loaded {len(weak_df)} weak scaling rows") + + # Generate combined table + generate_combined_scaling_table( + strong_df, + weak_df, + strong_path, + weak_path, + output_path, + show_run_ids=args.show_run_ids, + ) + return + + raise ValueError(f"Unknown mode: {args.mode}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 00103cb8c..24ea62a73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,10 @@ dev = [ # aarch64: gpu [project.optional-dependencies] +performance = [ + "weathergen-performance", +] + cpu = [ 'torch==2.6.0', ] @@ -228,6 +232,7 @@ weathergen-common = { workspace = true } weathergen-evaluate = { workspace = true } weathergen-metrics = { workspace = true } weathergen-readers-extra = { workspace = true } +weathergen-performance = { workspace = true } flash-attn = [ @@ -272,5 +277,6 @@ members = [ "packages/readers_extra", # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. + "packages/performance", ] diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 7995b5864..0f6a5aee9 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -161,6 +161,7 @@ def run_train(args): Note: All model configurations are set in the function body. 
""" + t_start = time.time() cli_overwrite = config.from_cli_arglist(args.options) @@ -188,7 +189,7 @@ def run_train(args): trainer = Trainer(cf.train_logging) try: - trainer.run(cf, devices) + trainer.run(cf, devices, t_start=t_start) except Exception: extype, value, tb = sys.exc_info() traceback.print_exc() diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..9c8185fc2 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -53,9 +53,9 @@ def __init__( logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}") # ensure that steps_decay has a reasonable value if self.n_steps_decay < int(0.2 * lr_steps): - self.n_steps_warmup = int(0.1 * lr_steps) - self.n_steps_cooldown = int(0.05 * lr_steps) - self.n_steps_decay = lr_steps - self.n_steps_warmup - self.n_steps_cooldown + self.n_steps_warmup = max(2, int(0.1 * lr_steps)) + self.n_steps_cooldown = max(1, int(0.05 * lr_steps)) + self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown) s = ( "cf.lr_steps_warmup and cf.lr_steps_cooldown", f" were larger than cf.lr_steps={lr_steps}", diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 210fc16a4..eda83335e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -86,6 +86,7 @@ def __init__(self, train_logging: Config): self.batch_size_test_per_gpu = -1 self.collapse_monitor: CollapseMonitor | None = None self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker() + self.t_training_start: float = 0 def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -241,7 +242,9 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") - def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): + def run( + 
self, cf, devices, run_id_contd=None, mini_epoch_contd=None, t_start: float | None = None + ): # general initalization self.init(cf, devices) cf = self.cf @@ -321,7 +324,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) + beta2 = max(0.9, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2)) eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) self.optimizer = torch.optim.AdamW( @@ -363,7 +366,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: len_per_rank = ( - len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu) + max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu mini_epoch_base = int( self.cf.general.istep @@ -380,7 +383,14 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # run validation before training if requested self.validate_before_training() + # Log startup time + if is_root() and t_start is not None: + startup_time = time.time() - t_start + self.train_logger.log_metrics("train", {"startup_time_seconds": startup_time}) + logger.info(f"Startup time: {startup_time:.2f} seconds") + # training loop + self.t_training_start = time.time() for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): logger.info(f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: train.") @@ -558,6 +568,19 @@ def train(self, mini_epoch): self.dataset.advance() + if is_root(): + total_training_time = time.time() - self.t_training_start + self.train_logger.log_metrics( + "train", + { + "completed_mini_epoch": mini_epoch, + 
"training_time_after_mini_epoch_seconds": total_training_time, + }, + ) + logger.info( + f"Training time after mini epoch {mini_epoch}: {total_training_time} seconds" + ) + def validate(self, mini_epoch, mode_cfg, batch_size): """ Perform validation / test computation as specified by mode_cfg @@ -748,6 +771,7 @@ def _log(self, stage: Stage): self.train_logger.add_logs(stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: + elapsed_time = time.time() - self.t_training_start self.train_logger.add_logs( stage, samples, @@ -755,6 +779,7 @@ def _log(self, stage: Stage): stddev_all, avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), + elapsed_training_time_seconds=elapsed_time, ) loss_calculator.loss_hist = [] diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 5f1550e42..941605ff0 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -102,9 +102,10 @@ def add_logs( stddev_all: dict, avg_loss: list[float] = None, lr: float = None, + elapsed_training_time_seconds: float | None = None, ) -> None: """ - Log training or validation data + Log training or validation data. """ metrics: dict[str, float] = dict(num_samples=samples) @@ -112,6 +113,13 @@ def add_logs( metrics["loss_avg_mean"] = np.nanmean(avg_loss) metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) + if elapsed_training_time_seconds is not None: + metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds + metrics["average_samples_per_second"] = ( + samples / elapsed_training_time_seconds + if elapsed_training_time_seconds > 0 + else 0 + ) for key, value in losses_all.items(): metrics[key] = np.nanmean(value)