From 3c4d1aef1a38c0da3a5ed739b462ddcde9dee5e2 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Fri, 13 Aug 2021 09:05:55 -0700 Subject: [PATCH 01/63] Add initial MoE benchmark --- README.moe.md | 131 +++++ fairseq/checkpoint_utils.py | 195 +++++-- fairseq/criterions/__init__.py | 2 + fairseq/criterions/fairseq_criterion.py | 137 +++++ fairseq/criterions/moe_cross_entropy.py | 60 ++ fairseq/data/data_utils.py | 42 +- fairseq/data/dictionary.py | 11 +- fairseq/data/fairseq_dataset.py | 5 + fairseq/data/language_pair_dataset.py | 11 +- fairseq/data/monolingual_dataset.py | 35 +- fairseq/data/pad_dataset.py | 13 +- fairseq/data/plasma_utils.py | 134 ++++- fairseq/data/token_block_dataset.py | 67 ++- fairseq/dataclass/configs.py | 143 +++-- fairseq/dataclass/utils.py | 3 +- .../fully_sharded_data_parallel.py | 111 +++- .../legacy_distributed_data_parallel.py | 11 + fairseq/distributed/utils.py | 166 ++++-- fairseq/fb_pathhandlers.py | 546 ++++++++++++++++++ fairseq/file_io.py | 126 ++-- fairseq/hub_utils.py | 129 ++++- fairseq/logging/metrics.py | 35 ++ fairseq/logging/progress_bar.py | 8 + fairseq/models/fairseq_model.py | 26 +- fairseq/models/transformer.py | 308 +++++++--- fairseq/models/transformer_lm.py | 287 ++++++++- fairseq/modules/checkpoint_activations.py | 228 +------- fairseq/modules/fairseq_dropout.py | 5 +- fairseq/modules/fused_bias_gelu.py | 63 ++ fairseq/modules/moe/__init__.py | 8 + fairseq/modules/moe/moe_layer.py | 228 ++++++++ fairseq/modules/moe/top1gate.py | 149 +++++ fairseq/modules/moe/top2gate.py | 249 ++++++++ fairseq/modules/multihead_attention.py | 50 +- fairseq/modules/transformer_layer.py | 318 ++++++++-- fairseq/moe_checkpoint_utils.py | 183 ++++++ fairseq/optim/adafactor.py | 44 +- fairseq/optim/adam.py | 47 +- fairseq/optim/cpu_adam.py | 28 +- fairseq/optim/fp16_optimizer.py | 4 + fairseq/optim/fused_adam.py | 52 +- fairseq/tasks/fairseq_task.py | 29 +- fairseq/tasks/language_modeling.py | 36 +- fairseq/trainer.py | 328 ++++++++--- 
fairseq/utils.py | 71 ++- fairseq_cli/eval_lm.py | 239 +++++--- fairseq_cli/generate.py | 44 +- fairseq_cli/train.py | 106 +++- 48 files changed, 4335 insertions(+), 916 deletions(-) create mode 100644 README.moe.md create mode 100644 fairseq/criterions/moe_cross_entropy.py create mode 100644 fairseq/fb_pathhandlers.py create mode 100644 fairseq/modules/fused_bias_gelu.py create mode 100644 fairseq/modules/moe/__init__.py create mode 100644 fairseq/modules/moe/moe_layer.py create mode 100644 fairseq/modules/moe/top1gate.py create mode 100644 fairseq/modules/moe/top2gate.py create mode 100644 fairseq/moe_checkpoint_utils.py diff --git a/README.moe.md b/README.moe.md new file mode 100644 index 0000000000..cb3bf500ef --- /dev/null +++ b/README.moe.md @@ -0,0 +1,131 @@ +# Training MoE language models + +## Dependencies + +Follow the fairseq installation instructions: +https://github.com/pytorch/fairseq/#requirements-and-installation + +The following package versions are recommended: + +apex: +```bash +pip install -v --no-cache-dir --global-option="--cpp_ext" \ + --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" \ + --global-option="--xentropy" --global-option="--fast_multihead_attn" \ + git+git://github.com/NVIDIA/apex.git@e2083df5eb96643c61613b9df48dd4eea6b07690 +``` + +fairscale: +```bash +pip install fairscale==0.4.0 +``` + +hydra: +```bash +pip install hydra-core==1.0.7 omegaconf==2.0.6 +``` + +megatron (must be installed from source to get fused kernels): +```bash +git clone --depth=1 --branch v2.4 https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +pip install -e . +``` + +## Single-node training + +The following command will benchmark an MoE language model using synthetic data +on 8 GPUs. The model has 8 experts (one per GPU) and 4.1B parameters total. 
+ +```bash +# set NUM_EXPERTS based on # of GPUs and desired # experts per GPU +# generally it's recommended to have a single expert per GPU +NUM_EXPERTS=8 +TOKENS_PER_SAMPLE=2048 +python fairseq_cli/train.py \ + --ddp-backend fully_sharded --memory-efficient-fp16 --checkpoint-activations \ + --task dummy_lm --tokens-per-sample $TOKENS_PER_SAMPLE \ + --arch transformer_lm_gpt --share-decoder-input-output-embed \ + --decoder-layers 24 --decoder-embed-dim 2048 --decoder-ffn-embed-dim 8192 \ + --decoder-attention-heads 32 \ + --moe-expert-count $NUM_EXPERTS --moe-freq 2 \ + --moe-gating-use-fp32 --moe-second-expert-policy all \ + --moe-normalize-expert-grad sqrt_world_size \ + --moe-eval-capacity-token-fraction -1.0 \ + --max-sentences-valid 1 --num-workers-valid 0 \ + --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ + --optimizer adam --fp16-adam-stats --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.0005 --warmup-updates 750 \ + --dropout 0.1 --attention-dropout 0.1 \ + --batch-size 4 --update-freq 1 \ + --max-update 250 --disable-validation \ + --log-format json --log-interval 10 +``` + +The total parameter count can be inferred from the logs: +``` +(...) +2021-08-13 14:54:20 | INFO | fairseq_cli.train | num. non-expert model params: 908,423,168 (num. trained: 908,423,168) +2021-08-13 14:54:20 | INFO | fairseq_cli.train | num. expert model params: 402,776,064 (num. trained: 402,776,064) +(...) +``` +The expert params are distinct on each GPU, so the total parameter count is `908M + 8 * 403M = 4.1B` + +**Sample output on 8 x V100:** +``` +2021-08-13 14:58:39 | INFO | fairseq.modules.fused_bias_gelu | Done with compiling and loading fused kernels. 
+2021-08-13 14:58:44 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 64.0 +2021-08-13 14:58:49 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 32.0 +2021-08-13 14:58:53 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 16.0 +2021-08-13 14:59:32 | INFO | train_inner | {"epoch": 1, "update": 0.004, "loss": "20.714", "moe_gate_loss": "16.7217", "overflow_expert1": "20.84", "overflow_expert2": "53.493", "entropy_gating": "1.943", "expert1_balance_top": "66.521", "expert1_balance_bottom": "2.528", "unused_expert1_count": "0.12", "expert2_balance_top": "50.142", "expert2_balance_bottom": "5.417", "unused_expert2_count": "0.052", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "20.472", "ppl": "1.45489e+06", "wps": "16606.5", "ups": "0.25", "wpb": "65536", "bsz": "32", "num_updates": "10", "lr": "7.33333e-06", "gnorm": "30.01", "loss_scale": "16", "train_wall": "62", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "19", "cuda_gb_free": "21.4", "wall": "68"} +2021-08-13 15:00:12 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "16.194", "moe_gate_loss": "15.642", "overflow_expert1": "16.52", "overflow_expert2": "59.608", "entropy_gating": "1.983", "expert1_balance_top": "63.168", "expert1_balance_bottom": "1.88", "unused_expert1_count": "0.564", "expert2_balance_top": "49.929", "expert2_balance_bottom": "3.712", "unused_expert2_count": "0.368", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "15.969", "ppl": "64132.9", "wps": "16591.9", "ups": "0.25", "wpb": "65536", "bsz": "32", "num_updates": "20", "lr": "1.4e-05", "gnorm": "1.82", "loss_scale": "16", "train_wall": "39", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "19", "cuda_gb_free": "21.4", "wall": "107"} +2021-08-13 15:00:52 | INFO | train_inner | {"epoch": 1, "update": 0.011, 
"loss": "15.132", "moe_gate_loss": "13.3857", "overflow_expert1": "5.742", "overflow_expert2": "45.276", "entropy_gating": "2.023", "expert1_balance_top": "49.599", "expert1_balance_bottom": "5.064", "unused_expert1_count": "0.423", "expert2_balance_top": "40.013", "expert2_balance_bottom": "7.728", "unused_expert2_count": "0.32", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "14.939", "ppl": "31410.7", "wps": "16562", "ups": "0.25", "wpb": "65536", "bsz": "32", "num_updates": "30", "lr": "2.06667e-05", "gnorm": "1.397", "loss_scale": "16", "train_wall": "40", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "19", "cuda_gb_free": "21.4", "wall": "147"} +``` + +**Sample output on 8 x A100:** +``` +2021-08-13 14:58:39 | INFO | fairseq.modules.fused_bias_gelu | Done with compiling and loading fused kernels. +2021-08-13 22:10:38 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 64.0 +2021-08-13 22:10:40 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 32.0 +2021-08-13 22:10:43 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 16.0 +2021-08-13 22:11:02 | INFO | train_inner | {"epoch": 1, "update": 0.004, "loss": "20.703", "moe_gate_loss": "16.7792", "overflow_expert1": "21.27", "overflow_expert2": "52.991", "entropy_gating": "1.943", "expert1_balance_top": "66.899", "expert1_balance_bottom": "2.586", "unused_expert1_count": "0.13", "expert2_balance_top": "50.174", "expert2_balance_bottom": "5.421", "unused_expert2_count": "0.066", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "20.461", "ppl": "1.44332e+06", "wps": "34799.2", "ups": "0.53", "wpb": "65536", "bsz": "32", "num_updates": "10", "lr": "7.33333e-06", "gnorm": "29.972", "loss_scale": "16", "train_wall": "49", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "20.6", "cuda_gb_free": "29.2", 
"wall": "68"} +2021-08-13 22:11:21 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "16.195", "moe_gate_loss": "15.6466", "overflow_expert1": "16.589", "overflow_expert2": "59.311", "entropy_gating": "1.984", "expert1_balance_top": "63.15", "expert1_balance_bottom": "1.885", "unused_expert1_count": "0.548", "expert2_balance_top": "49.952", "expert2_balance_bottom": "3.785", "unused_expert2_count": "0.349", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "15.969", "ppl": "64151.2", "wps": "34973.5", "ups": "0.53", "wpb": "65536", "bsz": "32", "num_updates": "20", "lr": "1.4e-05", "gnorm": "1.822", "loss_scale": "16", "train_wall": "19", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "20.6", "cuda_gb_free": "29.2", "wall": "87"} +2021-08-13 22:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "15.131", "moe_gate_loss": "13.3747", "overflow_expert1": "5.894", "overflow_expert2": "44.769", "entropy_gating": "2.024", "expert1_balance_top": "49.877", "expert1_balance_bottom": "5.076", "unused_expert1_count": "0.41", "expert2_balance_top": "39.921", "expert2_balance_bottom": "7.812", "unused_expert2_count": "0.343", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "14.938", "ppl": "31389.2", "wps": "35046.4", "ups": "0.53", "wpb": "65536", "bsz": "32", "num_updates": "30", "lr": "2.06667e-05", "gnorm": "1.396", "loss_scale": "16", "train_wall": "19", "cuda_gb_allocated": "10.4", "cuda_gb_reserved": "20.6", "cuda_gb_free": "29.2", "wall": "105"} +``` + +## Larger model on multiple nodes + +The following command will train an MoE model with 142B parameters on 64 A100s. 
+ +```bash +# salloc command might look different +salloc --gpus-per-node 8 --ntasks-per-node 8 --cpus-per-task 12 --nodes 8 --mem-per-gpu 128G + +# set NUM_EXPERTS based on # of GPUs and desired # experts per GPU +# generally it's recommended to have a single expert per GPU +NUM_EXPERTS=64 +TOKENS_PER_SAMPLE=1024 + +# launch the job (adjust port and --cpu-bind if needed) +DISTRIBUTED_PORT=12345 +srun --cpu-bind=mask_cpu:000000ffffff000000ffffff,000000ffffff000000ffffff,000000ffffff000000ffffff,000000ffffff000000ffffff,ffffff000000ffffff000000,ffffff000000ffffff000000,ffffff000000ffffff000000,ffffff000000ffffff000000 \ + python fairseq_cli/train.py \ + --distributed-port $DISTRIBUTED_PORT \ + --ddp-backend fully_sharded --memory-efficient-fp16 --checkpoint-activations \ + --task dummy_lm --tokens-per-sample $TOKENS_PER_SAMPLE \ + --arch transformer_lm_gpt --share-decoder-input-output-embed \ + --decoder-layers 32 --decoder-embed-dim 4096 --decoder-ffn-embed-dim 16384 \ + --decoder-attention-heads 32 \ + --moe-expert-count $NUM_EXPERTS --moe-freq 2 \ + --moe-gating-use-fp32 --moe-second-expert-policy all \ + --moe-normalize-expert-grad sqrt_world_size \ + --moe-eval-capacity-token-fraction -1.0 \ + --max-sentences-valid 1 --num-workers-valid 0 \ + --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \ + --optimizer adam --fp16-adam-stats --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.0005 --warmup-updates 750 \ + --dropout 0.1 --attention-dropout 0.1 \ + --batch-size 12 --update-freq 1 \ + --max-update 250 --disable-validation \ + --log-format json --log-interval 10 +``` diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 5a98dad2aa..52e32cf7f8 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -6,6 +6,7 @@ import ast import collections import contextlib +import functools import logging import os import re @@ -14,20 +15,24 @@ from typing import Any, Dict, Optional, Union 
import torch +from fairseq.data.multilingual.multilingual_utils import get_lang_tok from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, overwrite_args_by_name, ) -from fairseq.file_io import PathManager +from fairseq.distributed import utils as dist_utils +from fairseq.file_io import PathManager, torch_load_cpu from fairseq.models import FairseqDecoder, FairseqEncoder -from omegaconf import Container, DictConfig, open_dict, OmegaConf - +from fairseq import moe_checkpoint_utils +from omegaconf import DictConfig, open_dict, OmegaConf logger = logging.getLogger(__name__) -def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): +def save_checkpoint( + cfg: CheckpointConfig, trainer, epoch_itr, val_loss, training_finished=False, async_callback_fn=None, +): from fairseq import meters # only one worker should attempt to create the required dir @@ -42,7 +47,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): if cfg.no_save: return - trainer.consolidate_optimizer() + trainer.consolidate_optimizer() # TODO(SS): we dont need if no_save_optimizer_state if not trainer.should_save_checkpoint_on_current_rank: return @@ -69,11 +74,19 @@ def is_better(a, b): and cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0 ) - checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( - not hasattr(save_checkpoint, "best") - or is_better(val_loss, save_checkpoint.best) + checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = ( + val_loss is not None + and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + and not cfg.no_best_checkpoints ) - if val_loss is not None and cfg.keep_best_checkpoints > 0: + if ( + val_loss is not None + and cfg.keep_best_checkpoints > 0 + and not cfg.no_best_checkpoints + ): checkpoint_conds[ 
"checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) ] = not hasattr(save_checkpoint, "best") or is_better( @@ -91,19 +104,32 @@ def is_better(a, b): os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: - trainer.save_checkpoint(checkpoints[0], extra_state) + if PathManager.islink(checkpoints[0]): + PathManager.rm(checkpoints[0]) + if trainer.is_moe and trainer.is_data_parallel_master: + shared = re.sub("rank-[0-9]+", "shared", checkpoints[0]) + if PathManager.islink(shared): + PathManager.rm(shared) + + trainer.save_checkpoint( + checkpoints[0], extra_state, training_finished=training_finished, async_callback_fn=async_callback_fn + ) + + def copy_or_symlink(src, dest): + if cfg.symlink_best_and_last_checkpoints: + PathManager.symlink(src, dest) + elif cfg.write_checkpoints_asynchronously: + pass # TODO[ioPath]: Need to implement a delayed asynchronous file copying/moving feature. + else: + assert PathManager.copy(src, dest, overwrite=True), f"Failed to copy {src} to {dest}" + for cp in checkpoints[1:]: - if cfg.write_checkpoints_asynchronously: - # TODO[ioPath]: Need to implement a delayed asynchronous - # file copying/moving feature. - logger.warning( - f"ioPath is not copying {checkpoints[0]} to {cp} " - "since async write mode is on." 
+ copy_or_symlink(src=checkpoints[0], dest=cp) + if (trainer.is_moe or trainer.is_base_moe) and not trainer.is_fsdp and trainer.is_data_parallel_master: + copy_or_symlink( + src=re.sub("rank-[0-9]+", "shared", checkpoints[0]), + dest=re.sub("rank-[0-9]+", "shared", cp), ) - else: - assert PathManager.copy( - checkpoints[0], cp, overwrite=True - ), f"Failed to copy {checkpoints[0]} to {cp}" write_timer.stop() logger.info( @@ -112,33 +138,42 @@ def is_better(a, b): ) ) + delete_old_checkpoint_files(cfg, end_of_epoch, trainer.is_moe or trainer.is_base_moe, suffix, trainer.is_data_parallel_master) + + +def delete_old_checkpoint_files(cfg: DictConfig, end_of_epoch: bool, is_moe: bool, suffix: str, is_data_parallel_master: bool): if not end_of_epoch and cfg.keep_interval_updates > 0: - # remove old checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" - ) - for old_chk in checkpoints[cfg.keep_interval_updates :]: - if os.path.lexists(old_chk): - os.remove(old_chk) + suffixes = [suffix] + if is_moe and is_data_parallel_master: + suffixes.append("-shared") + # remove old checkpoints; checkpoints are sorted in descending order + for one_suffix in suffixes: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(one_suffix) + ) + for old_chk in checkpoints[cfg.keep_interval_updates:]: + if os.path.lexists(old_chk): + os.remove(old_chk) if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") - for old_chk in checkpoints[cfg.keep_last_epochs :]: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk) - if cfg.keep_best_checkpoints > 0: # only keep the best N checkpoints according to 
validation metric checkpoints = checkpoint_paths( cfg.save_dir, - pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - cfg.best_checkpoint_metric + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix ), ) if not cfg.maximize_best_checkpoint_metric: checkpoints = checkpoints[::-1] - for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + for old_chk in checkpoints[cfg.keep_best_checkpoints:]: if os.path.lexists(old_chk): os.remove(old_chk) @@ -234,7 +269,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): return extra_state, epoch_itr -def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): +def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False, is_moe=False): """Loads a checkpoint to CPU (with upgrading for backward compatibility). If doing single-GPU training or if the checkpoint is only being loaded by at @@ -267,8 +302,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): torch.distributed.barrier() local_path = PathManager.get_local_path(path) - with open(local_path, "rb") as f: - state = torch.load(f, map_location=torch.device("cpu")) + # path to checkpoint...-shared.pt + shared_path = re.sub('rank-[0-9]+', 'shared', local_path) + if is_moe and os.path.exists(shared_path): + expert_state = moe_checkpoint_utils.load_expert_state(local_path) # Possibly merge experts + shared_state = torch_load_cpu(shared_path) + state = moe_checkpoint_utils.merge_expert_and_shared_state(expert_state, shared_state) + else: + state = torch_load_cpu(local_path) if "args" in state and state["args"] is not None and arg_overrides is not None: args = state["args"] @@ -304,6 +345,7 @@ def load_model_ensemble( suffix="", num_shards=1, state=None, + is_moe=False, ): """Loads an ensemble of models. 
@@ -324,10 +366,56 @@ def load_model_ensemble( suffix, num_shards, state, + is_moe=is_moe, ) return ensemble, args +def upgrade_state_for_langs_difference(state, model_config, task): + """Accounts for the difference in dictionaries due to language tokens + to allow ensembling between multilingual and bilingual models""" + + lang_count_diff = len(task.langs) - len(model_config.langs) + assert lang_count_diff >= 0, "Removing langs from ensemble components not yet supported!" + + if model_config.encoder_langtok is not None: + orig_embed_tokens = state["model"]["encoder.embed_tokens.weight"] + upgraded_embed_tokens = torch.zeros( + (orig_embed_tokens.shape[0] + lang_count_diff, orig_embed_tokens.shape[1]), + dtype=orig_embed_tokens.dtype, + device=orig_embed_tokens.device, + ) + + first_lang_tok = task.source_dictionary.index(get_lang_tok(task.langs[0], "multilingual")) + # language tokens appear at the end of the dictionary + upgraded_embed_tokens[: first_lang_tok, :] = orig_embed_tokens[: first_lang_tok, :] + for i, lang in enumerate(model_config.langs): + lang_tok = task.source_dictionary.index(get_lang_tok(lang, "multilingual")) + upgraded_embed_tokens[lang_tok, :] = orig_embed_tokens[first_lang_tok + i, :] + + state["model"]["encoder.embed_tokens.weight"] = upgraded_embed_tokens + del orig_embed_tokens + + if model_config.decoder_langtok: + for weight_name in ("decoder.embed_tokens.weight", "decoder.output_projection.weight"): + orig_weights = state["model"][weight_name] + upgraded_weights = torch.zeros( + (orig_weights.shape[0] + lang_count_diff, orig_weights.shape[1]), + dtype=orig_weights.dtype, + device=orig_weights.device, + ) + + first_lang_tok = task.target_dictionary.index(get_lang_tok(task.langs[0], "multilingual")) + # language tokens appear at the end of the dictionary + upgraded_weights[: first_lang_tok, :] = orig_weights[: first_lang_tok, :] + for i, lang in enumerate(model_config.langs): + lang_tok = task.target_dictionary.index(get_lang_tok(lang, 
"multilingual")) + upgraded_weights[lang_tok, :] = orig_weights[first_lang_tok + i, :] + + state["model"][weight_name] = upgraded_weights + del orig_weights + + def load_model_ensemble_and_task( filenames, arg_overrides: Optional[Dict[str, Any]] = None, @@ -336,7 +424,9 @@ def load_model_ensemble_and_task( suffix="", num_shards=1, state=None, + is_moe=False, ): + logger.info("load_model_ensemble_and_task is_moe={}".format(is_moe)) assert state is None or len(filenames) == 1 from fairseq import tasks @@ -346,6 +436,7 @@ def load_model_ensemble_and_task( ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] cfg = None + for filename in filenames: orig_filename = filename assert num_shards > 0 @@ -358,7 +449,7 @@ def load_model_ensemble_and_task( if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) if state is None: - state = load_checkpoint_to_cpu(filename, arg_overrides) + state = load_checkpoint_to_cpu(filename, arg_overrides, is_moe=is_moe) if "args" in state and state["args"] is not None: cfg = convert_namespace_to_omegaconf(state["args"]) elif "cfg" in state and state["cfg"] is not None: @@ -377,6 +468,13 @@ def load_model_ensemble_and_task( # build model for ensemble model = task.build_model(cfg.model) + if ( + hasattr(cfg.model, "langs") + and hasattr(task, "langs") + and cfg.model.langs != task.langs + ): + upgrade_state_for_langs_difference(state, cfg.model, task) + model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) # reset state so it gets loaded for the next model in ensemble @@ -405,9 +503,14 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(obj, filename, async_write: bool = False): +def torch_persistent_save(obj, filename: str, async_write: bool = False, async_callback_fn=None): + assert async_callback_fn is None or async_write, 'async_callback_fn 
requires async_write=True (--save-async)' + if async_write and async_callback_fn is not None: + callback = functools.partial(async_callback_fn, filename) + else: + callback = None if async_write: - with PathManager.opena(filename, "wb") as f: + with PathManager.opena(filename, "wb", callback_after_file_close=callback) as f: _torch_persistent_save(obj, f) else: if PathManager.supports_rename(filename): @@ -421,16 +524,16 @@ def torch_persistent_save(obj, filename, async_write: bool = False): _torch_persistent_save(obj, f) -def _torch_persistent_save(obj, f): +def _torch_persistent_save(obj, f, num_retries=3): if isinstance(f, str): with PathManager.open(f, "wb") as h: torch_persistent_save(obj, h) return - for i in range(3): + for i in range(num_retries): try: return torch.save(obj, f) except Exception: - if i == 2: + if i == num_retries - 1: logger.error(traceback.format_exc()) @@ -691,7 +794,8 @@ def load_pretrained_component_from_model( def verify_checkpoint_directory(save_dir: str) -> None: if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) - temp_file_path = os.path.join(save_dir, "dummy") + rank = dist_utils.get_global_rank() + temp_file_path = os.path.join(save_dir, f"dummy{rank}") try: with open(temp_file_path, "w"): pass @@ -701,4 +805,7 @@ def verify_checkpoint_directory(save_dir: str) -> None: ) raise e else: - os.remove(temp_file_path) + try: + os.remove(temp_file_path) + except FileNotFoundError: + pass diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 8cc6c0f043..e2181bd539 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -11,6 +11,8 @@ from fairseq.criterions.fairseq_criterion import ( # noqa FairseqCriterion, LegacyFairseqCriterion, + MoECriterion, + MoECriterionConfig, ) from omegaconf import DictConfig diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index ff4beb0250..d6a6823aea 100644 --- 
a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -4,13 +4,19 @@ # LICENSE file in the root directory of this source tree. import inspect +import math +from dataclasses import dataclass, field from typing import Any, Dict, List +import torch +import torch.nn.functional as F from fairseq import metrics, utils from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass from torch.nn.modules.loss import _Loss +from omegaconf import II +from fairseq.modules.moe import MOELayer class FairseqCriterion(_Loss): def __init__(self, task): @@ -118,3 +124,134 @@ def __init__(self, args, task): def build_criterion(cls, args, task): """Construct a criterion from command-line args.""" return cls(args, task) + + +@dataclass +class MoECriterionConfig(FairseqDataclass): + moe_gate_loss_wt: float = field( + default=1.0, + metadata={ + "help": "Weight associated with MoE gate loss" + "in the weighted sum of gate loss and cross entropy loss" + } + ) + moe_gate_loss_combine_method: str = field( + default="average", + metadata={ + "help": "Method of combining the gate loss from each MoE layers" + "('sum', 'average')" + } + ) + moe_gate_loss_transform: str = field( + default="none", + metadata={ + "help": "Transformation to apply to the gate loss ('none', 'neg_log')" + } + ) + sentence_avg: bool = II("optimization.sentence_avg") + +class MoECriterion(FairseqCriterion): + + moe_logging_keys = [ + "overflow_expert1", # average % of overflowed tokens from 1st expert + "overflow_expert2", # average % of overflowed tokens from 2nd expert + "entropy_gating", # average entropy of the gating distribution + "expert1_balance_top", # average cumulative % of tokens processed by the most used 20% 1st experts + "expert1_balance_bottom", # average cumulative % of tokens processed by the least used 20% 1st experts + "unused_expert1_count", # average number of 1st experts which process no tokens + 
"expert2_balance_top", # average cumulative % of tokens processed by the most used 20% 2nd experts + "expert2_balance_bottom", # average cumulative % of tokens processed by the least used 20% 2nd experts + "unused_expert2_count", # average number of 2nd experts which process no tokens + "all_to_all_cpu_time_ms", # CPU time spent in all to all calls in milliseconds + "all_to_all_cuda_time_ms", # CUDA ttime spent in all to all calls in milliseconds + ] + def __init__(self, task, moe_gate_loss_wt, moe_gate_loss_combine_method, moe_gate_loss_transform, sentence_avg): + super().__init__(task) + self.gate_loss_weight = moe_gate_loss_wt + self.gate_loss_combine_method = moe_gate_loss_combine_method + self.gate_loss_transform = moe_gate_loss_transform + self.sentence_avg = sentence_avg + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + loss, inner_loss, moe_loss, moe_metadata, sample_size, logging_output = self.compute_loss(model, sample, reduce=reduce) + + logging_output["loss"] = loss.data + logging_output["moe_loss"] = moe_loss.data + logging_output.update(moe_metadata) + + return loss, sample_size, logging_output + + def compute_loss(self, model, sample, reduce=True): + net_output, inner_loss, sample_size, logging_output = self.compute_inner_loss(model, sample) + gate_loss = 0.0 + gate_count = 0 + for l_aux in net_output[1]["l_aux"]: + if l_aux is not None: + gate_loss += l_aux + gate_count += 1 + if self.gate_loss_combine_method == "average": + gate_loss = gate_loss / gate_count + if self.gate_loss_transform == "neg_log": + gate_loss = - torch.log(gate_loss) + gate_loss = sample_size * gate_loss + loss = inner_loss + self.gate_loss_weight * gate_loss + return loss, inner_loss, gate_loss, self.get_moe_metadata(model), sample_size, logging_output + 
+ def compute_inner_loss(self, model, sample): + """Compute the non-MoE portion of the loss. Default is cross-entropy""" + raise NotImplementedError + + def get_moe_metadata(self, model): + moe_logging_output = {} + for key in MoECriterion.moe_logging_keys: + total_val = 0 + count = 0 + for _, module in model.named_modules(): + if isinstance(module, MOELayer): + total_val += module.metadata[key] if key in module.metadata else 0 + count += 1 + moe_logging_output[key] = total_val / count + moe_logging_output["batch_count"] = 1 + return moe_logging_output + + @staticmethod + def reduce_moe_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + moe_loss_sum = sum(log.get("moe_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "moe_gate_loss", moe_loss_sum / sample_size, sample_size, round=8 + ) + batch_count = sum(log.get("batch_count", 0) for log in logging_outputs) + for key in MoECriterion.moe_logging_keys: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar( + key, val / batch_count, batch_count, round=3 + ) + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + raise NotImplementedError + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return True diff --git a/fairseq/criterions/moe_cross_entropy.py b/fairseq/criterions/moe_cross_entropy.py new file mode 100644 index 0000000000..3e6e95871b --- /dev/null +++ b/fairseq/criterions/moe_cross_entropy.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig + + +@register_criterion("moe_cross_entropy", dataclass=MoECriterionConfig) +class MoECrossEntropyCriterion(MoECriterion): + + def compute_inner_loss(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1) + nll_loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + logging_output = { + "inner_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return net_output, nll_loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs) + + loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "inner_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + 
if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["inner_loss"].avg) + ) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 6f7561afbe..88b6aa53eb 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -10,7 +10,7 @@ import contextlib import itertools import logging -import os +import re import warnings from typing import Optional, Tuple @@ -18,7 +18,8 @@ import torch from fairseq.file_io import PathManager - +from fairseq import utils +import os logger = logging.getLogger(__name__) @@ -41,13 +42,16 @@ def collate_tokens( move_eos_to_beginning=False, pad_to_length=None, pad_to_multiple=1, + pad_to_bsz=None, ): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) size = size if pad_to_length is None else max(size, pad_to_length) if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) - res = values[0].new(len(values), size).fill_(pad_idx) + + batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz) + res = values[0].new(batch_size, size).fill_(pad_idx) def copy_tensor(src, dst): assert dst.numel() == src.numel() @@ -65,7 +69,6 @@ def copy_tensor(src, dst): copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) return res - def load_indexed_dataset( path, dictionary=None, dataset_impl=None, combine=False, default="cached" ): @@ -345,7 +348,6 @@ def batch_by_size( max_sentences, bsz_mult, ) - else: fixed_shapes = np.array(fixed_shapes, dtype=np.int64) sort_order = np.lexsort( @@ -524,3 +526,33 @@ def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: def lengths_to_mask(lens: torch.LongTensor) -> 
torch.BoolTensor: return ~lengths_to_padding_mask(lens) + + +def _find_extra_valid_paths(dataset_path: str) -> set: + paths = utils.split_paths(dataset_path) + all_valid_paths = set() + for sub_dir in paths: + contents = PathManager.ls(sub_dir) + valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None] + all_valid_paths |= {os.path.basename(p) for p in valid_paths} + # Remove .bin, .idx etc + roots = {os.path.splitext(p)[0] for p in all_valid_paths} + return roots + + +def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: + """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored.""" + if ( + train_cfg.dataset.ignore_unused_valid_subsets + or train_cfg.dataset.combine_valid_subsets + or train_cfg.dataset.disable_validation + or getattr(train_cfg.task, "data", None) is None + ): + return + other_paths = _find_extra_valid_paths(train_cfg.task.data) + specified_subsets = train_cfg.dataset.valid_subset.split(",") + ignored_paths = [p for p in other_paths if p not in specified_subsets] + if ignored_paths: + advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them." + msg = f"Valid paths {ignored_paths} will be ignored. 
{advice}" + raise ValueError(msg) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 8d219e20ef..da3427b37a 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -302,20 +302,19 @@ def encode_line( words = line_tokenizer(line) if reverse_order: words = list(reversed(words)) - nwords = len(words) - ids = torch.IntTensor(nwords + 1 if append_eos else nwords) + ids = [] - for i, word in enumerate(words): + for word in words: if add_if_not_exist: idx = self.add_symbol(word) else: idx = self.index(word) if consumer is not None: consumer(word, idx) - ids[i] = idx + ids.append(idx) if append_eos: - ids[nwords] = self.eos_index - return ids + ids.append(self.eos_index) + return torch.tensor(ids, dtype=torch.int32) @staticmethod def _add_file_to_dictionary_single_worker( diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 23e6992dba..050379028b 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -56,6 +56,11 @@ def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" raise NotImplementedError + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + raise NotImplementedError def num_tokens_vec(self, indices): """Return the number of tokens for a set of positions defined by indices. 
diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index ff3e14bf14..47757a68e3 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -95,7 +95,11 @@ def compute_alignment_weights(alignments): ntokens = tgt_lengths.sum().item() if samples[0].get("prev_output_tokens", None) is not None: - prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) + prev_output_tokens = merge( + "prev_output_tokens", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, + ) elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step @@ -223,6 +227,7 @@ def __init__( src_lang_id=None, tgt_lang_id=None, pad_to_multiple=1, + fixed_pad_length=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -294,6 +299,7 @@ def __init__( else: self.buckets = None self.pad_to_multiple = pad_to_multiple + self.fixed_pad_length = fixed_pad_length def get_batch_shapes(self): return self.buckets @@ -338,6 +344,7 @@ def __getitem__(self, index): def __len__(self): return len(self.src) + # Note: self.fixed_pad_length overrides pad_to_length def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. @@ -381,7 +388,7 @@ def collater(self, samples, pad_to_length=None): left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, - pad_to_length=pad_to_length, + pad_to_length=self.fixed_pad_length or pad_to_length, pad_to_multiple=self.pad_to_multiple, ) if self.src_lang_id is not None or self.tgt_lang_id is not None: diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index bf7aa86f6c..c8d80461a4 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -9,7 +9,7 @@ from . 
import FairseqDataset, data_utils -def collate(samples, pad_idx, eos_idx): +def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None): if len(samples) == 0: return {} @@ -23,6 +23,8 @@ def merge(key, is_list=False): pad_idx, eos_idx, left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) ) return res @@ -32,6 +34,8 @@ def merge(key, is_list=False): pad_idx, eos_idx, left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) src_tokens = merge("source") @@ -75,6 +79,10 @@ def __init__( shuffle=False, targets=None, add_bos_token=False, + fixed_pad_length=None, + pad_to_bsz=None, + src_lang_idx=None, + tgt_lang_idx=None, ): self.dataset = dataset self.sizes = np.array(sizes) @@ -83,6 +91,10 @@ def __init__( self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle self.add_bos_token = add_bos_token + self.fixed_pad_length = fixed_pad_length + self.pad_to_bsz = pad_to_bsz + self.src_lang_idx = src_lang_idx + self.tgt_lang_idx = tgt_lang_idx assert targets is None or all( t in {"self", "future", "past"} for t in targets @@ -160,11 +172,20 @@ def _make_source_target(self, source, future_target, past_target): def _maybe_add_bos(self, source, target): if self.add_bos_token: - source = torch.cat([source.new([self.vocab.bos()]), source]) + # src_lang_idx and tgt_lang_idx are passed in for multilingual LM, with the + # first token being an lang_id token. + bos = self.src_lang_idx or self.vocab.bos() + source = torch.cat([source.new([bos]), source]) if target is not None: - target = torch.cat([target.new([self.tgt_vocab.bos()]), target]) + tgt_bos = self.tgt_lang_idx or self.tgt_vocab.bos() + target = torch.cat([target.new([tgt_bos]), target]) return source, target + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. 
+ This value is used to enforce ``--max-tokens`` during batching.""" + return self.sizes[indices] + def _filter_vocab(self, target): if len(self.tgt_vocab) != len(self.vocab): @@ -200,7 +221,13 @@ def collater(self, samples): target sentence of shape `(bsz, tgt_len)`. Padding will appear on the right. """ - return collate(samples, self.vocab.pad(), self.vocab.eos()) + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.fixed_pad_length, + self.pad_to_bsz, + ) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to diff --git a/fairseq/data/pad_dataset.py b/fairseq/data/pad_dataset.py index 8075bba6a9..f27321c5ae 100644 --- a/fairseq/data/pad_dataset.py +++ b/fairseq/data/pad_dataset.py @@ -9,20 +9,21 @@ class PadDataset(BaseWrapperDataset): - def __init__(self, dataset, pad_idx, left_pad): + def __init__(self, dataset, pad_idx, left_pad, pad_length=None): super().__init__(dataset) self.pad_idx = pad_idx self.left_pad = left_pad + self.pad_length = pad_length def collater(self, samples): - return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad) + return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_length=self.pad_length) class LeftPadDataset(PadDataset): - def __init__(self, dataset, pad_idx): - super().__init__(dataset, pad_idx, left_pad=True) + def __init__(self, dataset, pad_idx, pad_length=None): + super().__init__(dataset, pad_idx, left_pad=True, pad_length=pad_length) class RightPadDataset(PadDataset): - def __init__(self, dataset, pad_idx): - super().__init__(dataset, pad_idx, left_pad=False) + def __init__(self, dataset, pad_idx, pad_length=None): + super().__init__(dataset, pad_idx, left_pad=False, pad_length=pad_length) diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index f4bb6472d7..b9fab3b739 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -3,11 +3,23 @@ # This source code is 
licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + import subprocess +import json import tempfile +import hashlib +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False -class PlasmaArray(object): +class PlasmaArray: """ Wrapper around numpy arrays that automatically moves the data to shared memory upon serialization. This is particularly helpful when passing numpy @@ -31,12 +43,7 @@ def __init__(self, array): @property def plasma(self): if self._plasma is None and not self.disable: - try: - import pyarrow.plasma as plasma - - self._plasma = plasma - except ImportError: - self._plasma = None + self._plasma = plasma return self._plasma def start_server(self): @@ -47,13 +54,7 @@ def start_server(self): self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name self._server = subprocess.Popen( - [ - "plasma_store", - "-m", - str(int(1.05 * self.array.nbytes)), - "-s", - self.path, - ] + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] ) @property @@ -64,6 +65,7 @@ def client(self): return self._client def __getstate__(self): + """Called on pickle load""" if self.plasma is None: return self.__dict__ if self.object_id is None: @@ -78,6 +80,7 @@ def __getstate__(self): return state def __setstate__(self, state): + """Called on pickle save""" self.__dict__.update(state) if self.plasma is None: return @@ -89,3 +92,106 @@ def __del__(self): self._server = None self._server_tmp.close() self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. 
This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return 
self._n + + +GB100 = (1024 ** 3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index ce0a0d1114..d2c65fd7e0 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -7,6 +7,7 @@ import torch from fairseq.data import FairseqDataset, plasma_utils from fairseq.data.indexed_dataset import best_fitting_int_dtype +from typing import Tuple class TokenBlockDataset(FairseqDataset): @@ -42,7 +43,46 @@ def __init__( break_mode=None, include_targets=False, document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + 
block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" try: from fairseq.data.token_block_utils_fast import ( _get_slice_indices_fast, @@ -54,15 +94,6 @@ def __init__( "or `python setup.py build_ext --inplace`" ) - super().__init__() - self.dataset = dataset - self.pad = pad - self.eos = eos - self.include_targets = include_targets - - assert len(dataset) == len(sizes) - assert len(dataset) > 0 - if isinstance(sizes, list): sizes = np.array(sizes, dtype=np.int64) else: @@ -79,7 +110,7 @@ def __init__( slice_indices = _get_slice_indices_fast( sizes, str(break_mode), block_size, document_sep_len ) - self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + _sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices if break_mode == "eos": @@ -99,15 +130,12 @@ def __init__( sizes, slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 - slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max()) - - self._slice_indices = plasma_utils.PlasmaArray( - slice_indices.astype(slice_indices_dtype) - ) - self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype)) - self._block_to_dataset_index = plasma_utils.PlasmaArray( - block_to_dataset_index.astype(slice_indices_dtype) - ) + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) 
+ return _sizes, block_to_dataset_index, slice_indices @property def slice_indices(self): @@ -131,7 +159,6 @@ def __getitem__(self, index): buffer = torch.cat( [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] ) - slice_s, slice_e = self.slice_indices[index] length = slice_e - slice_s s, e = start_offset, start_offset + length diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 5d6aee157a..d555b86cf9 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -33,6 +33,9 @@ class FairseqDataclass: def name(): return None + def positional_args(self): + return ["data"] + def _get_all_attributes(self) -> List[str]: return [k for k in self.__dataclass_fields__.keys()] @@ -95,6 +98,9 @@ class CommonConfig(FairseqDataclass): log_format: Optional[LOG_FORMAT_CHOICES] = field( default=None, metadata={"help": "log format to use"} ) + log_file: Optional[str] = field( + default=None, metadata={"help": "log file to copy metrics to."} + ) tensorboard_logdir: Optional[str] = field( default=None, metadata={ @@ -104,15 +110,10 @@ class CommonConfig(FairseqDataclass): ) wandb_project: Optional[str] = field( default=None, - metadata={ - "help": "Weights and Biases project name to use for logging" - }, + metadata={"help": "Weights and Biases project name to use for logging"}, ) azureml_logging: Optional[bool] = field( - default=False, - metadata={ - "help": "Log scalars to AzureML context" - }, + default=False, metadata={"help": "Log scalars to AzureML context"}, ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} @@ -192,6 +193,18 @@ class CommonConfig(FairseqDataclass): "main method can return a value (useful for sweeps)" }, ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. 
Paths outside /tmp tend to fail." + }, + ) + log_nvidia_smi: bool = field( + default=False, metadata={"help": "log output from nvidia-smi during training"} + ) @dataclass @@ -263,7 +276,7 @@ class DistributedTrainingConfig(FairseqDataclass): metadata={ "help": "kill the job if no progress is made in N seconds; " "set to -1 to disable" - } + }, ) broadcast_buffers: bool = field( default=False, @@ -360,16 +373,16 @@ class DistributedTrainingConfig(FairseqDataclass): tpu: bool = II("common.tpu") # configuration for --ddp-backend=fully_sharded no_reshard_after_forward: bool = field( - default=False, - metadata={"help": "don't reshard parameters after forward pass"}, + default=False, metadata={"help": "don't reshard parameters after forward pass"}, ) fp32_reduce_scatter: bool = field( - default=False, - metadata={"help": "reduce-scatter grads in FP32"}, + default=False, metadata={"help": "reduce-scatter grads in FP32"}, ) cpu_offload: bool = field( - default=False, - metadata={"help": "offload FP32 params to CPU"} + default=False, metadata={"help": "offload FP32 params to CPU"} + ) + use_sharded_state: Optional[bool] = field( + default=False, metadata={"help": "load and save local state dict"} ) @@ -378,6 +391,9 @@ class DatasetConfig(FairseqDataclass): num_workers: int = field( default=1, metadata={"help": "how many subprocesses to use for data loading"} ) + num_workers_valid: int = field( + default=0, metadata={"help": "how many subprocesses to use for data loading during validation"} + ) skip_invalid_size_inputs_valid_test: bool = field( default=False, metadata={"help": "ignore too long or too short lines in valid and test set"}, @@ -418,6 +434,19 @@ class DatasetConfig(FairseqDataclass): " (e.g. train, valid, test)" }, ) + combine_valid_subsets: Optional[bool] = field( + default=None, + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. 
train, valid, test)", + "argparse_alias": "--combine-val", + }, + ) + ignore_unused_valid_subsets: Optional[bool] = field( + default=False, + metadata={"help": "do not raise error if valid subsets are ignored"}, + ) + validate_interval: int = field( default=1, metadata={"help": "validate every N epochs"} ) @@ -447,6 +476,8 @@ class DatasetConfig(FairseqDataclass): "argparse_alias": "--max-sentences-valid", }, ) + max_valid_steps: Optional[int] = field(default=None, metadata={'help': 'How many batches to evaluate', + "argparse_alias": "--nval"}) curriculum: int = field( default=0, metadata={"help": "don't shuffle batches for first N epochs"} ) @@ -580,10 +611,21 @@ class CheckpointConfig(FairseqDataclass): no_last_checkpoints: bool = field( default=False, metadata={"help": "don't store last checkpoints"} ) + no_best_checkpoints: bool = field( + default=False, metadata={"help": "don't store best checkpoints"} + ) no_save_optimizer_state: bool = field( default=False, metadata={"help": "don't save optimizer-state as part of checkpoint"}, ) + no_save_optimizer_state_on_training_finished: bool = field( + default=False, + metadata={"help": "don't save optimizer-state as part of checkpoint when training is done"}, + ) + symlink_best_and_last_checkpoints: bool = field( + default=False, + metadata={"help": "Symlink best and last checkpoints instead of copying", "argparse_alias": "--symlink"}, + ) best_checkpoint_metric: str = field( default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'} ) @@ -632,6 +674,16 @@ class CheckpointConfig(FairseqDataclass): "argparse_alias": "--save-async", }, ) + s3_upload_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Upload checkpoints asynchronously in a separate " + "thread to S3. NOTE: This feature is currently being tested." 
+ ), + "argparse_alias": "--s3-dir", + }, + ) model_parallel_size: int = II("common.model_parallel_size") @@ -665,12 +717,10 @@ class FairseqBMUFConfig(FairseqDataclass): @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( - default=5, - metadata={"help": "beam size"}, + default=5, metadata={"help": "beam size"}, ) nbest: int = field( - default=1, - metadata={"help": "number of hypotheses to output"}, + default=1, metadata={"help": "number of hypotheses to output"}, ) max_len_a: float = field( default=0, @@ -685,24 +735,19 @@ class GenerationConfig(FairseqDataclass): }, ) min_len: int = field( - default=1, - metadata={"help": "minimum generation length"}, + default=1, metadata={"help": "minimum generation length"}, ) match_source_len: bool = field( - default=False, - metadata={"help": "generations should match the source length"}, + default=False, metadata={"help": "generations should match the source length"}, ) unnormalized: bool = field( - default=False, - metadata={"help": "compare unnormalized hypothesis scores"}, + default=False, metadata={"help": "compare unnormalized hypothesis scores"}, ) no_early_stop: bool = field( - default=False, - metadata={"help": "deprecated"}, + default=False, metadata={"help": "deprecated"}, ) no_beamable_mm: bool = field( - default=False, - metadata={"help": "don't use BeamableMM in attention layers"}, + default=False, metadata={"help": "don't use BeamableMM in attention layers"}, ) lenpen: float = field( default=1, @@ -724,12 +769,10 @@ class GenerationConfig(FairseqDataclass): }, ) sacrebleu: bool = field( - default=False, - metadata={"help": "score with sacrebleu"}, + default=False, metadata={"help": "score with sacrebleu"}, ) score_reference: bool = field( - default=False, - metadata={"help": "just score the reference translation"}, + default=False, metadata={"help": "just score the reference translation"}, ) prefix_size: int = field( default=0, @@ -763,12 +806,10 @@ class 
GenerationConfig(FairseqDataclass): }, ) temperature: float = field( - default=1.0, - metadata={"help": "temperature for generation"}, + default=1.0, metadata={"help": "temperature for generation"}, ) diverse_beam_groups: int = field( - default=-1, - metadata={"help": "number of groups for Diverse Beam Search"}, + default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, ) diverse_beam_strength: float = field( default=0.5, @@ -787,16 +828,13 @@ class GenerationConfig(FairseqDataclass): }, ) print_step: bool = field( - default=False, - metadata={"help": "print steps"}, + default=False, metadata={"help": "print steps"}, ) lm_path: Optional[str] = field( - default=None, - metadata={"help": "path to lm checkpoint for lm fusion"}, + default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, ) lm_weight: float = field( - default=0.0, - metadata={"help": "weight for lm probs for lm fusion"}, + default=0.0, metadata={"help": "weight for lm probs for lm fusion"}, ) # arguments for iterative refinement generator @@ -805,8 +843,7 @@ class GenerationConfig(FairseqDataclass): metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, ) iter_decode_max_iter: int = field( - default=10, - metadata={"help": "maximum iterations for iterative refinement."}, + default=10, metadata={"help": "maximum iterations for iterative refinement."}, ) iter_decode_force_max_iter: bool = field( default=False, @@ -833,8 +870,7 @@ class GenerationConfig(FairseqDataclass): }, ) retain_dropout: bool = field( - default=False, - metadata={"help": "Use dropout at inference time"}, + default=False, metadata={"help": "Use dropout at inference time"}, ) # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # retain_dropout_modules: Optional[List[str]] = field( @@ -859,8 +895,7 @@ class GenerationConfig(FairseqDataclass): @dataclass class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, - 
metadata={"help": "path(s) to model file(s), colon separated"}, + default=None, metadata={"help": "path(s) to model file(s), colon separated"}, ) post_process: Optional[str] = field( default=None, @@ -883,6 +918,11 @@ class CommonEvalConfig(FairseqDataclass): results_path: Optional[str] = field( default=None, metadata={"help": "path to save eval results (optional)"} ) + # GShard or Switch model + is_moe: bool = field( + default=False, + metadata={"help": "if set, use distributed init for MoE generation or evaluation"}, + ) @dataclass @@ -911,6 +951,10 @@ class EvalLMConfig(FairseqDataclass): "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory" }, ) + stats_path: Optional[str] = field(default=None, metadata={'argparse_alias': '--sp'}) + max_valid_steps: Optional[int] = field(default=None, metadata={'help': 'How many batches to evaluate', + "argparse_alias": "--nval"}) + @dataclass @@ -922,8 +966,7 @@ class InteractiveConfig(FairseqDataclass): }, ) input: str = field( - default="-", - metadata={"help": "file to read from; use - for stdin"}, + default="-", metadata={"help": "file to read from; use - for stdin"}, ) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 27c9006fdb..3acc103d43 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -58,8 +58,7 @@ def gen_parser_from_dataclass( """convert a dataclass instance to tailing parser arguments""" def argparse_name(name: str): - if name == "data": - # normally data is positional args + if name in dataclass_instance.positional_args(): return name if name == "_name": # private member, skip diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py index 9d74398325..d065ecabc3 100644 --- a/fairseq/distributed/fully_sharded_data_parallel.py +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -4,13 +4,15 @@ # LICENSE file in the root 
directory of this source tree. import contextlib +import os +import re +from glob import glob from typing import Optional import torch - from fairseq.dataclass.configs import DistributedTrainingConfig from fairseq.distributed import utils as dist_utils - +from fairseq.file_io import load_and_pop_last_optimizer_state try: from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP @@ -26,6 +28,7 @@ class FullyShardedDataParallel(FSDP): fairseq-specific checkpoint saving/loading logic. Args: + is_moe (bool): if True, use MoE-specific checkpointing logic use_sharded_state (bool): if True, then ``state_dict`` will return ``FSDP.local_state_dict`` and ``load_state_dict`` will call ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will @@ -34,20 +37,38 @@ class FullyShardedDataParallel(FSDP): from rank 0 to other ranks. """ - def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + def __init__(self, *args, is_moe: bool = None, use_sharded_state: bool = False, **kwargs): if not has_FSDP: raise ImportError( "Cannot find FullyShardedDataParallel. 
" "Please install fairscale with: pip install fairscale" ) + if is_moe is None: + if torch.distributed.get_rank() == 0: + from fairseq import pdb; pdb.set_trace() + else: + import time; time.sleep(1000) + assert is_moe is not None super().__init__(*args, **kwargs) + self.is_moe = is_moe self.use_sharded_state = use_sharded_state + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + def state_dict(self, destination=None, prefix='', keep_vars=False): if self.use_sharded_state: return super().local_state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars ) + elif self.is_moe: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) else: if self.rank == 0: return super().state_dict( @@ -62,6 +83,8 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): def load_state_dict(self, state_dict, strict=True, model_cfg=None): if self.use_sharded_state: return super().load_local_state_dict(state_dict, strict=strict) + elif self.is_moe: + return super().load_state_dict(state_dict, strict=strict) else: state_dict = dist_utils.broadcast_object( state_dict, src_rank=0, group=self.process_group @@ -70,7 +93,7 @@ def load_state_dict(self, state_dict, strict=True, model_cfg=None): @contextlib.contextmanager -def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): +def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False, **kwargs): try: from fairscale.nn import enable_wrap except ImportError: @@ -93,8 +116,14 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F "cpu_offload": cfg.cpu_offload, "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, "bucket_cap_mb": cfg.bucket_cap_mb, + "state_dict_device": torch.device("cpu"), + **kwargs } - with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + with 
enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=use_sharded_state, + **fsdp_config, + ): yield @@ -109,14 +138,80 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): """ try: from fairscale.nn import wrap - cls = FullyShardedDataParallel if min_num_params is not None: num_params = sum(p.numel() for p in module.parameters()) if num_params >= min_num_params: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) else: return module else: - return wrap(module, cls=cls, **kwargs) + return wrap(module, **kwargs) except ImportError: return module + + +def consolidate_fsdp_shards(pth_prefix: str) -> str: + if pth_prefix.endswith(".pt"): + pth_prefix = pth_prefix[:-3] + save_prefix = pth_prefix + "_consolidated" # .pt' + moe_paths = glob(f"{pth_prefix}*rank*shard*.pt") + all_ckpt_files = sorted(glob(f"{pth_prefix}*shard*.pt")) + assert all_ckpt_files, f"no paths matched {pth_prefix}*shard*.pt" + weights = [] + metadata = [] + expert_paths = [] + expert_dest_paths = [] + expert_ranks = [] + dense = not bool(moe_paths) + for p in all_ckpt_files: + if re.search("rank-(\d+)", os.path.basename(p)): # expert checkpoint + expert_paths.append(p) + r = re.search("rank-(\d+)", os.path.basename(p)).groups()[0] + assert r not in expert_ranks + expert_ranks.append(r) + expert_dest_paths.append(f"{save_prefix}-rank-{r}.pt") + else: + ckpt = load_and_pop_last_optimizer_state(p) + weights.append(ckpt["model"]) + metadata.append(ckpt["shard_metadata"]) + assert weights, f'all files were considered experts: {all_ckpt_files}' + consolidated_weights = FSDP.consolidate_shard_weights(shard_weights=weights, shard_metadata=metadata, strict=False) + del weights, metadata + + if dense: + ckpt_consolidated = dict( + model=consolidated_weights, + cfg=ckpt["cfg"], + extra_state=ckpt["extra_state"], + optimizer_history=ckpt["optimizer_history"], + args=ckpt["args"], + ) + save_path = f"{save_prefix}.pt" + torch.save(ckpt_consolidated, 
save_path) + print(f"saved to {save_path}") + return save_path + + ckpt_shared = dict( + model=consolidated_weights, + cfg=ckpt["cfg"], + extra_state=ckpt["extra_state"], + optimizer_history=ckpt["optimizer_history"], + args=ckpt["args"], + ) + torch.save(ckpt_shared, f"{save_prefix}-shared.pt") + # Process experts + for src, dst in zip(expert_paths, expert_dest_paths): + ckpt = load_and_pop_last_optimizer_state(src) + expert_wt = FSDP.consolidate_shard_weights( + shard_weights=[ckpt["model"]], shard_metadata=[ckpt["shard_metadata"]], strict=False + ) + full_ckpt = dict( + model=expert_wt, + cfg=ckpt["cfg"], + extra_state=ckpt["extra_state"], + optimizer_history=ckpt["optimizer_history"], + args=ckpt["args"], + ) + torch.save(full_ckpt, dst) + print(f"saved consolidated MoE with prefix {save_prefix}.pt") + return f"{save_prefix}.pt" diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index b586e76b7f..a447c2a09e 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -134,6 +134,17 @@ def reduction_fn(): for param in params: if not param.requires_grad: continue + + if hasattr(param, 'base_expert'): + # Skip gradient sync for unshared parameters + continue + + if hasattr(param, 'expert'): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.data.div_(self.world_size) + continue if param.grad is None: param.grad = torch.zeros_like(param) if param.grad.requires_grad: diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 710ca18628..030de5f1b3 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import contextlib import io import logging import os @@ -51,15 +52,15 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): if cfg.pipeline_model_parallel: num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) - if all( + if cfg.distributed_port > 0: + # we can determine the init method automatically for Slurm + _infer_slurm_init(cfg, num_pipelines_per_node) + elif all( key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): # support torch.distributed.launch _infer_torch_distributed_launch_init(cfg) - elif cfg.distributed_port > 0: - # we can determine the init method automatically for Slurm - _infer_slurm_init(cfg, num_pipelines_per_node) elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs _infer_single_node_init(cfg) @@ -90,51 +91,55 @@ def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node): hostnames = subprocess.check_output( ["scontrol", "show", "hostnames", node_list] ) - cfg.distributed_init_method = "tcp://{host}:{port}".format( - host=hostnames.split()[0].decode("utf-8"), - port=cfg.distributed_port, - ) - nnodes = int(os.environ.get("SLURM_NNODES")) - ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") - if ntasks_per_node is not None: - ntasks_per_node = int(ntasks_per_node) - else: - ntasks = int(os.environ.get("SLURM_NTASKS")) - nnodes = int(os.environ.get("SLURM_NNODES")) - assert ntasks % nnodes == 0 - ntasks_per_node = int(ntasks / nnodes) - if ntasks_per_node == 1: - gpus_per_node = torch.cuda.device_count() - node_id = int(os.environ.get("SLURM_NODEID")) - cfg.distributed_rank = node_id * gpus_per_node - cfg.distributed_world_size = nnodes * gpus_per_node - elif cfg.pipeline_model_parallel: - assert ntasks_per_node == num_pipelines_per_node, ( - "SLURM --ntasks-per-node must match number of pipelines per " - "node (={})".format(num_pipelines_per_node) - ) - cfg.distributed_no_spawn = 
True - # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on - # the first node, [1, 2] on the second node, etc. This - # matches torch.distributed.launch. - node_id = int(os.environ.get("SLURM_NODEID")) - local_id = int(os.environ.get("SLURM_LOCALID")) - cfg.distributed_rank = node_id * num_pipelines_per_node + local_id - # In the above example, device_id will always be in [0, 1], - # which also matches torch.distributed.launch. - cfg.device_id = local_id - # We also want to set distributed_world_size to be the total - # number of pipelines across all nodes. - cfg.distributed_world_size = nnodes * num_pipelines_per_node - else: - assert ntasks_per_node == cfg.distributed_world_size // nnodes - cfg.distributed_no_spawn = True - cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) - cfg.device_id = int(os.environ.get("SLURM_LOCALID")) + host = hostnames.split()[0].decode("utf-8") except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed - pass + # if we're in a container, then maybe MASTER_ADDR is set + host = os.environ.get("MASTER_ADDR", None) + if host is None: + return + cfg.distributed_init_method = "tcp://{host}:{port}".format( + host=host, port=cfg.distributed_port + ) + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") + if ntasks_per_node is not None: + ntasks_per_node = int(ntasks_per_node) + else: + ntasks = int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) + assert ntasks % nnodes == 0 + ntasks_per_node = int(ntasks / nnodes) + if ntasks_per_node == 1: + gpus_per_node = torch.cuda.device_count() + node_id = int(os.environ.get("SLURM_NODEID")) + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + "SLURM --ntasks-per-node must match number of pipelines per " 
+ "node (={})".format(num_pipelines_per_node) + ) + cfg.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [1, 2] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, device_id will always be in [0, 1], + # which also matches torch.distributed.launch. + cfg.device_id = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes. + cfg.distributed_world_size = nnodes * num_pipelines_per_node + else: + assert ntasks_per_node == torch.cuda.device_count() + cfg.distributed_world_size = ntasks_per_node * nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) def _infer_single_node_init(cfg: DistributedTrainingConfig): @@ -281,7 +286,6 @@ def distributed_init(cfg: FairseqConfig): cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers - xm.mark_step() if is_master(cfg.distributed_training): logging.getLogger().setLevel(logging.INFO) @@ -307,6 +311,11 @@ def distributed_init(cfg: FairseqConfig): model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + if ((getattr(cfg.model, "moe_freq", 0) > 0 and + getattr(cfg.model, "moe_expert_count", 0) > 0) or + getattr(cfg.model, "base_layers", 0) > 0): + cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" + return cfg.distributed_training.distributed_rank @@ -357,7 +366,10 @@ def call_main(cfg: FairseqConfig, main, **kwargs): xmp.spawn( fn=distributed_main, args=(main, cfg, kwargs), - 
nprocs=8, # use all 8 TPU cores + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), ) else: # single GPU main @@ -423,6 +435,52 @@ def get_global_group(): else: return None +def get_moe_group(moe_expert_count): + if torch.distributed.is_initialized(): + if not hasattr(get_moe_group, "_moe_groups"): + world_size = get_global_world_size() + + # more experts than world size + if world_size <= moe_expert_count: + assert moe_expert_count % world_size == 0 + moe_groups = [[i] for i in range(world_size)] + + # larger world than num experts + else: + assert world_size % moe_expert_count == 0 + ranks_per_group = world_size // moe_expert_count + moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] + for i in range(moe_expert_count)] + + get_moe_group._moe_group_idx = moe_groups + get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups] + + my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx) + return get_moe_group._moe_groups[my_group_idx] + +def get_all2all_group(moe_expert_count): + if torch.distributed.is_initialized(): + if not hasattr(get_all2all_group, "_all2all_groups"): + world_size = get_global_world_size() + + # more experts than world size + if world_size <= moe_expert_count: + assert moe_expert_count % world_size == 0 + all2all_groups = [[i for i in range(world_size)]] + + # larger world than num experts + else: + assert world_size % moe_expert_count == 0 + ranks_per_group = world_size // moe_expert_count + all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)] + for i in range(ranks_per_group)] + + get_all2all_group._all2all_group_idx = all2all_groups + get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups] + + my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx) + return 
get_all2all_group._all2all_groups[my_group_idx] + def get_global_rank(): if use_xla(): @@ -455,12 +513,20 @@ def get_data_parallel_group(): def get_data_parallel_rank(): """Return my rank for the data parallel group.""" - return get_rank(get_data_parallel_group()) + dp_group = get_data_parallel_group() + if dp_group is not None: + return get_rank(dp_group) + else: + return get_global_rank() def get_data_parallel_world_size(): """Return world size for the data parallel group.""" - return get_world_size(get_data_parallel_group()) + dp_group = get_data_parallel_group() + if dp_group is not None: + return get_world_size(dp_group) + else: + return get_global_world_size() def get_model_parallel_group(): diff --git a/fairseq/fb_pathhandlers.py b/fairseq/fb_pathhandlers.py new file mode 100644 index 0000000000..ca71c19971 --- /dev/null +++ b/fairseq/fb_pathhandlers.py @@ -0,0 +1,546 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +import io +import logging +import os +import shutil +import types +from typing import Any, Dict, IO, List, Optional, Tuple, Union + +import boto3 +from boto3.s3.transfer import TransferConfig +import botocore +import datetime as dt +from iopath.common.file_io import file_lock, get_cache_dir, PathHandler + +logger = logging.getLogger(__name__) + + +# Override for close() on files to write to Amazon S3 +def s3_close_and_upload(self, client, bucket, s3_path, transfer_config): + # Seek to start, for use by upload_fileobj. + self.seek(0) + + # Reinstall the proper close. + self.close = self._close + + # upload_fileobj needs bytes + # NOTE: This a very undesirable hack. 
+ if isinstance(self, io.StringIO): + self = io.BytesIO(self.getvalue().encode('utf-8')) + + # Upload + try: + client.upload_fileobj( + self, + bucket, + s3_path, + Config=transfer_config, + ) + except botocore.exceptions.ClientError as e: + raise OSError( + f"Error in file upload - {e}" + f"{type(e).__name__}: {e}" + ) from e + +class S3PathHandler(PathHandler): + """ + Support for Amazon Simple Storage Service (S3) + + PathHandler methods, at a glance: + + File --torch.load-> In --open(..., 'w')-> Amazon <- _exists,_isfile,_isdir,_ls,_rm ... + System <-torch.save-- Mem. <-open(..., 'r')-- S3 + <----------------_copy_from_local----------------- + ----------------_get_local_path -----------------> + + Mem usage, for processing N bytes: + open(..., mode) + mode=='w': 2N, due to fully buffering user input, + *and doing naive conversion from StringIO -> BytesIO*, + before writing to S3 + ^ Potential for optimization. + mode=='wb': N, due to fully buffering user input, before writing to S3. + mode=='r': N, due to fully buffering file in memory + mode=='rb': N, due to fully buffering file in memory + _copy_from_local: ≈0. boto3 streams from file system directly to s3 + _get_local_path: ≈0. boto3 streams from s3 directly to file system + """ + # Disable failures if not all args are specified. + _strict_kwargs_check = False + + S3_PREFIX = "s3://" + CACHE_SUBDIR_NAME = "s3_cache" + + def __init__( + self, + cache_dir: Optional[str] = None, + transfer_config_kwargs: Optional[Dict] = None + ): + """ + Args: + cache_dir (str): Local filesystem directory to use for caching. If None, + uses default from `file_io.get_cache_dir()`. + transfer_config_kwargs (dict): Settings for boto3.s3.transfer.TransferConfig. + Used to specify settings for multipart transfers. + See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3.html for details.
+ """ + self.cache_dir = cache_dir + self.transfer_config = TransferConfig( + **(transfer_config_kwargs if transfer_config_kwargs else {}) + ) + + def _get_supported_prefixes(self) -> List[str]: + """ + Returns: + List[str]: the list of URI prefixes this PathHandler can support + """ + return [self.S3_PREFIX] + + def _parse_uri(self, uri: str) -> Tuple[str, str]: + """ + Parses a "s3://bucket/path" URI into `bucket` and `path` strings. + + Args: + uri (str): A s3:// URI. + + Returns: + bucket (str): the s3 bucket. + path (str): the path on the s3 system. + """ + splits = uri.replace(self.S3_PREFIX, '').split('/') + bucket = splits[0] + path = '/'.join(splits[1:]) + return bucket, path + + def _get_client(self, bucket: str): + if not hasattr(self, "client"): + try: + session = boto3.Session() + self.client = session.client('s3') + except botocore.exceptions.NoCredentialsError as e: + logger.error( + " See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html " + " for method of using environment variable to point to aws credentials, and the " + " order in which boto will search for said credentials. " + ) + logger.error( + "Boto3 searches via the order below. If on FAIR Cluster, method 4 may be most convenient." + "" + "The order in which Boto3 searches for credentials is:" + "1) [UNUSED] Passing credentials as parameters in the boto.client() method" + "2) [UNUSED] Passing credentials as parameters when creating a Session object" + "3) Environment variables" + " AWS_ACCESS_KEY_ID - The access key for your AWS account." + " AWS_SECRET_ACCESS_KEY - The secret key for your AWS account." + " AWS_SESSION_TOKEN - The session key for your AWS account." + " This is only needed when you are using temporary credentials. 
" + "4) Shared credential file (~/.aws/credentials)" + " default: ~/.aws/credentials" + " changed via: AWS_SHARED_CREDENTIALS_FILE" + " *for FAIR cluster usage: `export AWS_SHARED_CREDENTIALS_FILE=~/.fairusers_aws/credentials`" + "5) AWS config file (~/.aws/config)" + " default: ~/.aws/config" + " changed via: AWS_CONFIG_FILE" + "6) Assume Role provider" + "7) Boto2 config file (/etc/boto.cfg and ~/.boto)" + "8) Instance metadata service on an Amazon EC2 instance that has an IAM role configured." + ) + raise OSError( + f"Error in making s3 client for bucket {bucket}" + f"{type(e).__name__}: {e}" + ) from e + + return self.client + + def _local_cache_path( + self, + path: str, + ): + """ + Helper that returns a local cache path for a given uri. + Args: + path (str): A URI supported by this PathHandler. + Returns: + local_cache_path (str): a file path which exists on the local file system, + in a cache directory. + """ + bucket, file_path = self._parse_uri(path) + return os.path.join( + get_cache_dir(self.cache_dir), self.CACHE_SUBDIR_NAME, file_path + ) + + def _get_local_path( + self, + path: str, + **kwargs: Any + ) -> str: + """ + Get a filepath which is compatible with native Python I/O such as `open` + and `os.path`. + If URI points to a remote resource, this function may download and cache + the resource to local disk. In this case, the cache stays on filesystem + (under `file_io.get_cache_dir()`) and will be used by a different run. + Therefore this function is meant to be used with read-only resources. + Args: + path (str): A URI supported by this PathHandler + Returns: + local_path (str): a file path which exists on the local file system + """ + self._check_kwargs(kwargs) + + # Cheap check first. 
+ if path.endswith("/"): + raise NotImplementedError( + "S3PathHandler does not currently support downloading directories" + ) + assert self._isfile(path) + + local_path = self._local_cache_path(path) + with file_lock(local_path): + if os.path.exists(local_path): + # If local object's last modified time is *not after* remote object's last modified + # time, do not use the cache. Instead, redownload. + response = self._head_object(path) + if response is not None: + remote_dt = response['LastModified'] + local_dt = dt.datetime.fromtimestamp(os.path.getmtime(local_path)).astimezone() + # NOTE: may consider still avoid cache if times are close, to avoid a race condition. + # Currently, a lengthy download of a very recent but stale file would have a late + # local last modified timestamp, and would be improperly used. + # Better fix: set last modified time via the remote object's last modified time, + # in download_file(). + if (local_dt - remote_dt) > dt.timedelta(minutes=0): + logger.info("URL {} was already cached in {}".format(path, local_path)) + return local_path + + logger.info("Caching {} ...".format(path)) + tmp = local_path + ".tmp" + # clean-up tmp if found, because if tmp exists, it must be a dirty + # result of a previous process that didn't clean up after itself. + if os.path.isfile(tmp): + os.unlink(tmp) + + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + try: + response = client.download_file( + bucket, + s3_path, + tmp, + Config=self.transfer_config + ) + + # First download to tmp, then move it, because move is + # (almost?) atomic when src and dst are in the same file + # system. This will avoid partial cache state if the + # process is killed.
+ shutil.move(tmp, local_path) + finally: + try: + os.unlink(tmp) + except Exception: + pass + + logger.info("URL {} cached in {}".format(path, local_path)) + return local_path + + def _copy_from_local( + self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any + ) -> bool: + """ + Copies a local file to the specified URI. + If the URI is another local path, this should be functionally identical + to copy. + Args: + local_path (str): a file path which exists on the local file system + dst_path (str): A URI supported by this PathHandler + overwrite (bool): Bool flag for forcing overwrite of existing URI + Returns: + status (bool): True on success + """ + self._check_kwargs(kwargs) + + # Just checking this to avoid expensive API calls in self._isdir(). + if local_path.endswith("/") or dst_path.endswith("/"): + raise NotImplementedError( + "S3PathHandler does not currently support uploading directories" + ) + + bucket, s3_path = self._parse_uri(dst_path) + client = self._get_client(bucket) + try: + client.upload_file( + local_path, + bucket, + s3_path, + Config=self.transfer_config + ) + return True + except botocore.exceptions.ClientError as e: + logger.error("Error in file upload - {}".format(str(e))) + return False + + def _decorate_buf_with_s3_methods( + self, buffer: Union[IO[str], IO[bytes]], client: Any, bucket: str, s3_path: str, transfer_config: Any + ): + # Save old close method. + buffer._close = buffer.close + + # Add in our new close method. + fn = partial(s3_close_and_upload, client=client, bucket=bucket, s3_path=s3_path, transfer_config=transfer_config) + buffer.close = types.MethodType(fn, buffer) + + def _open( + self, + path: str, + mode: str = "r", + buffering: int = -1, + # The following three arguments are unused, + # But are included to avoid triggering WARNING + # messages from _check_kargs. 
+ encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + **kwargs: Any + ) -> Union[IO[str], IO[bytes]]: + """ + Open a stream to a URI, similar to the built-in `open`. + Args: + path (str): A URI supported by this PathHandler + mode (str): Specifies the mode in which the file is opened. It defaults + to 'r'. + buffering (int): An optional integer used to set the buffering policy. + Pass 0 to switch buffering off and an integer >= 1 to indicate the + size in bytes of a fixed-size chunk buffer. When no buffering + argument is given, the default buffering policy depends on the + underlying I/O implementation. + Returns: + file: a file-like object. + """ + self._check_kwargs(kwargs) + + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + + # AWS methods download_fileobj() and upload_fileobj() + # both expect binary file-like objects. + if 'r' in mode: + # 1. Download into io.BytesIO. + # (binary format is required by download_fileobj.) + buffer = io.BytesIO() + try: + # NOTE: Will download entire file! Further optimization to + # only read a portion of the file could be implemented here. + # NOTE: We download into an in-memory buffer. If downloading to + # filesystem is desirable, use _get_local_path(). + client.download_fileobj( + bucket, + s3_path, + buffer, + Config=self.transfer_config + ) + except botocore.exceptions.ClientError as e: + raise OSError( + f"Error in making s3 client for bucekt {bucket}" + f"{type(e).__name__}: {e}" + ) from e + + # 2. Set file-pointer to beginning of file. + buffer.seek(0) + + # 3. Use convenient wrapper to make object look like StringIO, + # if user wants non-binary. + if 'b' not in mode: + buffer = io.TextIOWrapper(buffer, encoding='utf-8') + + return buffer + + elif 'w' in mode: + # 1. For writing, we give the user io.BytesIO or io.StringIO. + if 'b' in mode: + buffer = io.BytesIO() + else: + buffer = io.StringIO() + + # 2. 
Decorate buffer so that we upload when it's closed by user. + # If StringIO, decorator does a simple+expensive conversion + # to bytesIO before uploading. + # (because upload_fileobj requires binary) + self._decorate_buf_with_s3_methods(buffer, client, bucket, s3_path, self.transfer_config) + + return buffer + + else: + raise OSError(f"Unsupported open mode {mode}") + + def _copy( + self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any + ) -> bool: + """ + Copies a source path to a destination path. + Args: + src_path (str): A URI supported by this PathHandler + dst_path (str): A URI supported by this PathHandler + overwrite (bool): Bool flag for forcing overwrite of existing file + Returns: + status (bool): True on success + """ + self._check_kwargs(kwargs) + + src_bucket, src_s3_path = self._parse_uri(src_path) + dst_bucket, dst_s3_path = self._parse_uri(dst_path) + assert src_bucket == dst_bucket, \ + "For now, can only _copy() within a bucket." + client = self._get_client(src_bucket) + + try: + client.copy( + { + 'Bucket': src_bucket, + 'Key': src_s3_path, + }, + dst_bucket, + dst_s3_path, + Config=self.transfer_config, + ) + return True + except botocore.exceptions.ClientError as e: + logger.error("Error in file copy - {}".format(str(e))) + return False + + def _head_object(self, path: str) -> Optional[Dict]: + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + + try: + # Raises exception if not exists, else it exists. + response = client.head_object( + Bucket=bucket, + Key=s3_path + ) + return response + except botocore.exceptions.ClientError as e: + if e.response['Error']['Message'] == 'Bad Request': + raise OSError( + f"Error in checking s3 path {path} - " + f"{type(e).__name__}: {e}" + ) from e + return None + + def _exists(self, path: str, **kwargs: Any) -> bool: + """ + Checks if there is a resource at the given URI. 
+ Args: + path (str): A URI supported by this PathHandler + Returns: + bool: true if the path exists + """ + self._check_kwargs(kwargs) + + return self._head_object(path) is not None + + def _isfile(self, path: str, **kwargs: Any) -> bool: + """ + Checks if the resource at the given URI is a file. + Args: + path (str): A URI supported by this PathHandler + Returns: + bool: true if the path is a file + """ + self._check_kwargs(kwargs) + + # NOTE: this incurs an API call. + return not path.endswith('/') and self._exists(path, **kwargs) + + def _isdir(self, path: str, **kwargs: Any) -> bool: + """ + Checks if the resource at the given URI is a directory. + Args: + path (str): A URI supported by this PathHandler + Returns: + bool: true if the path is a directory + """ + self._check_kwargs(kwargs) + + # NOTE: this incurs an API call. + return path.endswith('/') and self._exists(path, **kwargs) + + def _ls(self, path: str, **kwargs: Any) -> List[str]: + """ + List the contents of the directory at the provided URI. + Args: + path (str): A URI supported by this PathHandler + Returns: + List[str]: list of contents in given path + """ + self._check_kwargs(kwargs) + + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + + try: + # Pagination needed if >1000 entries. + paginator = client.get_paginator('list_objects_v2') + pages = paginator.paginate( + Bucket=bucket, + Prefix=s3_path, + ) + return [obj['Key'] for page in pages for obj in page.get('Contents', [])] + except botocore.exceptions.ClientError as e: + raise OSError( + f"Error in ls path {path} - " + f"{type(e).__name__}: {e}" + ) from e + + def _mkdirs(self, path: str, **kwargs: Any) -> None: + """ + Recursive directory creation function. Like mkdir(), but makes all + intermediate-level directories needed to contain the leaf directory. + Similar to the native `os.makedirs`. 
+ Args: + path (str): A URI supported by this PathHandler + """ + self._check_kwargs(kwargs) + + assert path.endswith('/'), path + + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + + try: + client.put_object( + Bucket=bucket, + Key=s3_path + ) + except botocore.exceptions.ClientError as e: + raise OSError( + f"Error in mkdirs path {path} - " + f"{type(e).__name__}: {e}" + ) from e + + def _rm(self, path: str, **kwargs: Any) -> None: + """ + Remove the file (not directory) at the provided URI. + Args: + path (str): A URI supported by this PathHandler + """ + self._check_kwargs(kwargs) + + bucket, s3_path = self._parse_uri(path) + client = self._get_client(bucket) + + try: + client.delete_object( + Bucket=bucket, + Key=s3_path + ) + except botocore.exceptions.ClientError as e: + raise OSError( + f"Error in rm path {path} - " + f"{type(e).__name__}: {e}" + ) from e diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 9a78ab505d..9291f50d1e 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -7,32 +7,29 @@ import logging import os +import errno import shutil from typing import List, Optional +import json +import torch logger = logging.getLogger(__file__) +from .fb_pathhandlers import S3PathHandler try: - from fvcore.common.file_io import PathManager as FVCorePathManager - - try: - # [FB only - for now] AWS PathHandler for PathManager - from .fb_pathhandlers import S3PathHandler - - FVCorePathManager.register_handler(S3PathHandler()) - except KeyError: - logging.warning("S3PathHandler already registered.") - except ImportError: - logging.debug( - "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." 
- ) - + from iopath.common.file_io import PathManager + IOPathPathManager = PathManager() except ImportError: - FVCorePathManager = None + IOPathPathManager = None -IOPathPathManager = None +try: + IOPathPathManager.register_handler(S3PathHandler()) +except KeyError: + pass +except Exception: + logging.exception("Failed to register S3 Path Handler. Try pip install boto3") class PathManager: @@ -51,8 +48,8 @@ def open( errors: Optional[str] = None, newline: Optional[str] = None, ): - if FVCorePathManager: - return FVCorePathManager.open( + if IOPathPathManager: + return IOPathPathManager.open( path=path, mode=mode, buffering=buffering, @@ -71,47 +68,63 @@ def open( @staticmethod def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: - if FVCorePathManager: - return FVCorePathManager.copy( + if IOPathPathManager: + return IOPathPathManager.copy( src_path=src_path, dst_path=dst_path, overwrite=overwrite ) return shutil.copyfile(src_path, dst_path) + @staticmethod + def symlink(src_path: str, dst_path: str): + try: + os.symlink(src_path, dst_path) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(dst_path) + os.symlink(src_path, dst_path) + @staticmethod def get_local_path(path: str, **kwargs) -> str: - if FVCorePathManager: - return FVCorePathManager.get_local_path(path, **kwargs) + if IOPathPathManager: + return IOPathPathManager.get_local_path(path, **kwargs) return path @staticmethod def exists(path: str) -> bool: - if FVCorePathManager: - return FVCorePathManager.exists(path) + if IOPathPathManager: + return IOPathPathManager.exists(path) return os.path.exists(path) @staticmethod def isfile(path: str) -> bool: - if FVCorePathManager: - return FVCorePathManager.isfile(path) + if IOPathPathManager: + return IOPathPathManager.isfile(path) return os.path.isfile(path) + @staticmethod + def islink(path: str) -> Optional[bool]: + if not PathManager.path_requires_pathmanager(path): + return os.path.islink(path) + return None + 
@staticmethod def ls(path: str) -> List[str]: - if FVCorePathManager: - return FVCorePathManager.ls(path) + if IOPathPathManager: + return IOPathPathManager.ls(path) return os.listdir(path) @staticmethod def mkdirs(path: str) -> None: - if FVCorePathManager: - return FVCorePathManager.mkdirs(path) + if IOPathPathManager: + return IOPathPathManager.mkdirs(path) os.makedirs(path, exist_ok=True) @staticmethod def rm(path: str) -> None: - if FVCorePathManager: - return FVCorePathManager.rm(path) + if IOPathPathManager: + return IOPathPathManager.rm(path) os.remove(path) + assert not os.path.exists(path) @staticmethod def chmod(path: str, mode: int) -> None: @@ -120,15 +133,15 @@ def chmod(path: str, mode: int) -> None: @staticmethod def register_handler(handler) -> None: - if FVCorePathManager: - return FVCorePathManager.register_handler(handler=handler) + if IOPathPathManager: + return IOPathPathManager.register_handler(handler=handler) @staticmethod def copy_from_local( local_path: str, dst_path: str, overwrite: bool = False, **kwargs ) -> None: - if FVCorePathManager: - return FVCorePathManager.copy_from_local( + if IOPathPathManager: + return IOPathPathManager.copy_from_local( local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs ) return shutil.copyfile(local_path, dst_path) @@ -136,8 +149,8 @@ def copy_from_local( @staticmethod def path_requires_pathmanager(path: str) -> bool: """Do we require PathManager to access given path?""" - if FVCorePathManager: - for p in FVCorePathManager._path_handlers.keys(): + if IOPathPathManager: + for p in IOPathPathManager._path_handlers.keys(): if path.startswith(p): return True return False @@ -162,18 +175,12 @@ def opena( encoding: Optional[str] = None, errors: Optional[str] = None, newline: Optional[str] = None, + callback_after_file_close=None, ): """ Return file descriptor with asynchronous write operations. 
""" global IOPathPathManager - if not IOPathPathManager: - logging.info("ioPath is initializing PathManager.") - try: - from iopath.common.file_io import PathManager - IOPathPathManager = PathManager() - except Exception: - logging.exception("Failed to initialize ioPath PathManager object.") return IOPathPathManager.opena( path=path, mode=mode, @@ -181,6 +188,7 @@ def opena( encoding=encoding, errors=errors, newline=newline, + callback_after_file_close=callback_after_file_close ) @staticmethod @@ -194,3 +202,31 @@ def async_close() -> bool: if IOPathPathManager: return IOPathPathManager.async_close() return False + + +def torch_load_cpu(path): + state = torch.load(path, map_location=torch.device("cpu")) + # If model was trained with fp16, model from loaded state_dict can be moved to fp16 + if isinstance(state, dict) and 'cfg' in state: + if state['cfg']['common']['fp16'] or state['cfg']['common']['memory_efficient_fp16']: + state['model'] = {k: v.half() for k, v in state['model'].items()} + return state + + +def save_json(content, path, indent=4): + with open(path, "w") as f: + json.dump(content, f, indent=indent) + + +def load_json(p): + return json.load(open(p)) + +def load_jsonl(path): + with open(path).read() as jsonl_content: + result = [json.loads(jline) for jline in jsonl_content.splitlines()] + return result + +def load_and_pop_last_optimizer_state(pth): + st = torch_load_cpu(pth) + st.pop('last_optimizer_state', None) + return st diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 7de2e2b0d4..96e1b3422f 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -9,6 +9,7 @@ import logging import os from typing import Any, Dict, Iterator, List +from fairseq.distributed.utils import get_data_parallel_rank, get_data_parallel_world_size import torch from fairseq import utils @@ -25,7 +26,7 @@ def from_pretrained( checkpoint_file="model.pt", data_name_or_path=".", archive_map=None, - **kwargs + **kwargs, ): from fairseq import checkpoint_utils, 
file_utils @@ -73,6 +74,8 @@ def from_pretrained( models, args, task = checkpoint_utils.load_model_ensemble_and_task( [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)], arg_overrides=kwargs, + suffix=kwargs.get("suffix", ""), + is_moe=kwargs.get("is_moe", False) ) return { @@ -88,17 +91,47 @@ class GeneratorHubInterface(nn.Module): translation or language model. """ - def __init__(self, cfg, task, models): + lang_tokens = {} + langs = None + add_lang_bos_token = False + + def to_lang_token(self, lang): + return f"<{lang}>" + + def __init__(self, cfg, task, models, moe_disable_padding=True, skip_prepare_for_inference=False): super().__init__() self.cfg = cfg + self.task = task self.models = nn.ModuleList(models) self.src_dict = task.source_dictionary self.tgt_dict = task.target_dictionary + if "langs" in cfg.task: + self.langs = self.cfg.task.langs + lang_tokens = [ + self.to_lang_token(x.strip()) for x in self.cfg.task.langs.split(",") + ] + + # for debug purpose + for lang_token in lang_tokens: + if lang_token not in self.src_dict: + self.src_dict.add_symbol(lang_token) + + if lang_token not in self.tgt_dict: + self.tgt_dict.add_symbol(lang_token) + + self.lang_tokens = set(lang_tokens) + + if "add_bos_token" in cfg.task: + #self.add_lang_bos_token = True + self.add_lang_bos_token = cfg.task.add_bos_token + # optimize model for generation - for model in self.models: - model.prepare_for_inference_(cfg) + if not skip_prepare_for_inference: + for model in self.models: + # For moe models and eval_lm + model.prepare_for_inference_(cfg, moe_disable_padding=moe_disable_padding) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) @@ -151,11 +184,12 @@ def generate( verbose: bool = False, skip_invalid_size_inputs=False, inference_step_args=None, - **kwargs + batch_size=None, + **kwargs, ) -> List[List[Dict[str, torch.Tensor]]]: if 
torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: return self.generate( - tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs + tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, batch_size=batch_size, **kwargs )[0] # build generator using current args as well as any kwargs @@ -168,11 +202,27 @@ def generate( inference_step_args = inference_step_args or {} results = [] - for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + rank, world_size = get_data_parallel_rank(), get_data_parallel_world_size() + batches = self._build_batches( + tokenized_sentences, skip_invalid_size_inputs, rank=rank, + world_size=world_size, batch_size=batch_size, + ) + # To ensure even batch count across workers, some batches might be dummy batches. We shouldn't score these. + first_batch = None + for batch in batches: + is_dummy_batch = False + if not first_batch and "net_input" in batch: + first_batch = batch + if "net_input" not in batch: + if first_batch is not None: + batch = first_batch + is_dummy_batch = True + else: + continue batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) - translations = self.task.inference_step( - generator, self.models, batch, **inference_step_args - ) + translations = self.task.inference_step(generator, self.models, batch, **inference_step_args) + if is_dummy_batch: # Don't score it or add it to hypotheses + continue for id, hypos in zip(batch["id"].tolist(), translations): results.append((id, hypos)) @@ -215,15 +265,64 @@ def getarg(name, default): ) return outputs + def get_sentence_and_language(self, sentence: str): + """ + If sentence is prefixed with the language, it is striped and both are replaced. 
+ + input: 'en-ENSome sentence here' + output: en-EN, 'Some sentence here' + """ + + lang_begin = "" + lang_end = "" + + lang = None + if sentence.startswith(lang_begin): + idx = sentence.find(lang_end) + if idx > 0: + lang = sentence[: idx + len(lang_end)] + lang = lang.replace(lang_begin, "").replace(lang_end, "") + sentence = sentence[idx + len(lang_end) :] + + return lang, sentence + + def add_language_to_sentence(self, sentence: str, lang_token): + lang_begin = "" + lang_end = "" + + lang_prefix = lang_begin + lang_token + lang_end + sentence = lang_prefix + sentence + + return sentence + def encode(self, sentence: str) -> torch.LongTensor: + lang, sentence = self.get_sentence_and_language(sentence) + sentence = self.tokenize(sentence) sentence = self.apply_bpe(sentence) + + if lang is not None: + sentence = f"{lang} {sentence}" + return self.binarize(sentence) def decode(self, tokens: torch.LongTensor) -> str: sentence = self.string(tokens) + + # Remove the lang token + sent_split = sentence.split(" ", 1) + lang_token = None + if sent_split[0] in self.lang_tokens: + lang_token = sent_split[0] + sentence = sent_split[1] + sentence = self.remove_bpe(sentence) - return self.detokenize(sentence) + sentence = self.detokenize(sentence) + + if lang_token is not None: + sentence = self.add_language_to_sentence(sentence, lang_token) + + return sentence def tokenize(self, sentence: str) -> str: if self.tokenizer is not None: @@ -252,16 +351,20 @@ def string(self, tokens: torch.LongTensor) -> str: return self.tgt_dict.string(tokens) def _build_batches( - self, tokens: List[List[int]], skip_invalid_size_inputs: bool + self, tokens: List[torch.LongTensor], skip_invalid_size_inputs: bool, world_size=None, rank=None, batch_size=None ) -> Iterator[Dict[str, Any]]: lengths = torch.LongTensor([t.numel() for t in tokens]) + if batch_size is None: + batch_size = self.cfg.dataset.batch_size batch_iterator = self.task.get_batch_iterator( 
dataset=self.task.build_dataset_for_inference(tokens, lengths), max_tokens=self.cfg.dataset.max_tokens, - max_sentences=self.cfg.dataset.batch_size, + max_sentences=batch_size, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, disable_iterator_cache=True, + num_shards=world_size, + shard_id=rank, ).next_epoch_itr(shuffle=False) return batch_iterator diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 7b56e31592..4b24052420 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -12,6 +12,8 @@ """ import contextlib +import logging +import subprocess import time import uuid from collections import OrderedDict, defaultdict @@ -286,3 +288,36 @@ def load_state_dict(state_dict): for name, agg_state in state_dict.items(): _aggregators[name] = MetersDict() _aggregators[name].load_state_dict(agg_state) + +def nvidia_smi_gpu_memory_stats(): + """ + Parse the nvidia-smi output and extract the memory used stats. + """ + out_dict = {} + try: + sp = subprocess.Popen( + ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + ) + out_str = sp.communicate() + out_list = out_str[0].decode("utf-8").split("\n") + out_dict = {} + for item in out_list: + if " MiB" in item: + gpu_idx, mem_used = item.split(',') + gpu_key = f"gpu_{gpu_idx}_mem_used_gb" + out_dict[gpu_key] = int(mem_used.strip().split(" ")[0]) / 1024 + except FileNotFoundError: + logging.error( + "Failed to find the 'nvidia-smi' executable for printing GPU stats" + ) + except subprocess.CalledProcessError as e: + logging.error(f"nvidia-smi returned non zero error code: {e.returncode}") + + return out_dict + + +def get_nvidia_smi_gpu_memory_stats_str(): + return "nvidia-smi stats: {}".format(nvidia_smi_gpu_memory_stats()) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 0ae2bc006d..b6528e657d 100644 --- 
a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -29,6 +29,7 @@ def progress_bar( iterator, log_format: Optional[str] = None, log_interval: int = 100, + log_file: Optional[str] = None, epoch: Optional[int] = None, prefix: Optional[str] = None, tensorboard_logdir: Optional[str] = None, @@ -39,6 +40,10 @@ def progress_bar( ): if log_format is None: log_format = default_log_format + if log_file is not None: + handler = logging.FileHandler(filename=log_file) + logger.addHandler(handler) + if log_format == "tqdm" and not sys.stderr.isatty(): log_format = "simple" @@ -344,6 +349,9 @@ def _writer(self, key): _writers[key].add_text("sys.argv", " ".join(sys.argv)) return _writers[key] + def __len__(self): + return len(self.wrapped_bar) + def __iter__(self): return iter(self.wrapped_bar) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index d393c02ae6..932e437e64 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -20,6 +20,7 @@ gen_parser_from_dataclass, ) from fairseq.models import FairseqDecoder, FairseqEncoder +from fairseq.modules.moe import MOELayer from omegaconf import DictConfig from torch import Tensor @@ -29,9 +30,10 @@ def check_type(module, expected_type): if hasattr(module, "unwrapped_module"): - assert isinstance(module.unwrapped_module, expected_type) + assert isinstance(module.unwrapped_module, expected_type), \ + f"{type(module.unwrapped_module)} != {expected_type}" else: - assert isinstance(module, expected_type) + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" class BaseFairseqModel(nn.Module): @@ -152,14 +154,11 @@ def do_upgrade(m, prefix): def set_num_updates(self, num_updates): """State from trainer to pass along to model at every update.""" - - def _apply(m): + for m in self.modules(): if hasattr(m, "set_num_updates") and m != self: m.set_num_updates(num_updates) - self.apply(_apply) - - def prepare_for_inference_(self, cfg: 
DictConfig): + def prepare_for_inference_(self, cfg: DictConfig, moe_disable_padding=True): """Prepare model for inference.""" kwargs = {} kwargs["beamable_mm_beam_size"] = ( @@ -172,6 +171,9 @@ def prepare_for_inference_(self, cfg: DictConfig): kwargs["retain_dropout"] = cfg.generation.retain_dropout kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules self.make_generation_fast_(**kwargs) + for n, m in self.named_modules(): + if isinstance(m, MOELayer) and moe_disable_padding: + m.prepare_for_inference_() def make_generation_fast_(self, **kwargs): """ @@ -238,6 +240,8 @@ def from_pretrained( model_name_or_path, checkpoint_file="model.pt", data_name_or_path=".", + moe_disable_padding=True, + skip_prepare_for_inference=False, **kwargs, ): """ @@ -271,7 +275,13 @@ def from_pretrained( **kwargs, ) logger.info(x["args"]) - return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) + return hub_utils.GeneratorHubInterface( + x["args"], + x["task"], + x["models"], + moe_disable_padding=moe_disable_padding, + skip_prepare_for_inference=skip_prepare_for_inference, + ) @classmethod def hub_models(cls): diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index a0a0b8dcd5..2065e326ff 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -3,13 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import math from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn from fairseq import utils -from fairseq.distributed import fsdp_wrap +from fairseq.distributed import utils as dist_utils, fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -30,12 +31,40 @@ from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from torch import Tensor - +import logging +logger = logging.getLogger(__name__) DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +def fsdp_wrap_expert(args, layer, min_num_params=0): + # Wrap MoE layer with FSDP using a process group with all replicated ranks + process_group = layer.moe_layer.expert_group + world_size = dist_utils.get_data_parallel_group().size() + pg_size = process_group.size() + num_experts = world_size/pg_size + + for i, expert in enumerate(layer.moe_layer.experts): + layer.moe_layer.experts[i] = fsdp_wrap( + expert, process_group=process_group, min_num_params=0 + ) + if getattr(args, "moe_normalize_expert_grad", "world_size") == "sqrt_world_size": + expert_normalization_term = math.sqrt(num_experts) + else: + expert_normalization_term = num_experts + for p in layer.moe_layer.experts.parameters(): + p.expert = True + # Scale grads by world_size/pg_size so that grads match the equivalent replicated + # world size expected within Trainer + p.register_hook(functools.partial(div_by_world_size, expert_normalization_term)) + + # Everything else gets wrapped as normal. 
+ layer = fsdp_wrap(layer, min_num_params=min_num_params) + return layer + @register_model("transformer") class TransformerModel(FairseqEncoderDecoderModel): """ @@ -191,6 +220,47 @@ def add_args(parser): help='block size of quantization noise at training time') parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, help='scalar quantization noise and scalar quantization at training time') + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + 'minimum number of params for a layer to be wrapped with FSDP() when ' + 'training with --ddp-backend=fully_sharded. Smaller values will ' + 'improve memory efficiency, but may make torch.distributed ' + 'communication less efficient due to smaller input sizes. This option ' + 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' + '--offload-activations are passed.' + ) + ) + # args for mixture-of-expert layers + parser.add_argument('--moe-freq', type=int, metavar='D', default=0, + help='Frequency at which we insert MoE Transformer layers') + parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0, + help='Frequency at which we insert MoE Transformer encoder layers') + parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0, + help='Frequency at which we insert MoE Transformer decoder layers') + parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0, + help='Number of experts in each MoE Layer') + parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true', + help="Use FP32 computations in MoE top2 gating function") + parser.add_argument('--moe-second-expert-policy', type=str, default='sampling', + help="policy for second expert, options: all/sampling/random") + parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true', + help="whether to normalize 
gate probs before or after dropping experts for capacity and randomization") + parser.add_argument('--moe-expert-ffn-dim', type=int, default=0, + help="MoE Expert FFN dimension") + parser.add_argument('--moe-top1-expert', default=False, action='store_true', + help="Use top1 gate instead of top2") + parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25, + help="Fraction of tokens as capacity during validation" + \ + "if set to negative, use same as training. range: (0.0, 1.0].") + parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size', + help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'") + parser.add_argument('--use-moe-pad-mask', default=False, action='store_true', + help="Don't route padding tokens to any expert") + # args for pseudo-MoE layers + parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0, + help="FFN embed dim of alternate pseudo-MoE blocks") # fmt: on @classmethod @@ -242,16 +312,25 @@ def build_model(cls, args, task): encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) if not args.share_all_embeddings: - encoder = fsdp_wrap(encoder, min_num_params=1e8) - decoder = fsdp_wrap(decoder, min_num_params=1e8) + min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded + encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): num_embeddings = len(dictionary) padding_idx = dictionary.pad() - - emb = Embedding(num_embeddings, embed_dim, padding_idx) + if getattr(args, 'use_stable_embedding', False): + import bitsandbytes as bnb + if not args.no_scale_embedding: + logger.warning('It is recommended to pass 
--no-scale-embedding with --use-stable-embedding') + emb = bnb.nn.StableEmbedding(num_embeddings, embed_dim, padding_idx) + else: + emb = Embedding(num_embeddings, embed_dim, padding_idx) # if provided, load from preloaded dictionaries if path: embed_dict = utils.parse_embedding(path) @@ -317,6 +396,10 @@ def get_normalized_probs( return self.get_normalized_probs_scriptable(net_output, log_probs, sample) +def div_by_world_size(world_size, tensor): + return tensor / world_size + + class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer @@ -375,9 +458,10 @@ def __init__(self, args, dictionary, embed_tokens): self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) - self.layers.extend( - [self.build_encoder_layer(args) for i in range(args.encoder_layers)] - ) + moe_freq = max(getattr(args, 'encoder_moe_freq', 0), getattr(args, 'moe_freq', 0)) + for i in range(args.encoder_layers): + is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 + self.layers.append(self.build_encoder_layer(args, is_moe_layer=is_moe_layer)) self.num_layers = len(self.layers) if args.encoder_normalize_before: @@ -385,12 +469,22 @@ def __init__(self, args, dictionary, embed_tokens): else: self.layer_norm = None - def build_encoder_layer(self, args): - layer = TransformerEncoderLayer(args) - if getattr(args, "checkpoint_activations", False): + def build_encoder_layer(self, args, is_moe_layer=False): + layer = TransformerEncoderLayer(args, is_moe_layer=is_moe_layer) + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", 
DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + if not is_moe_layer or getattr(args, "ddp_backend", None) != "fully_sharded": + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + else: + layer = fsdp_wrap_expert(args, layer, min_num_params=min_params_to_wrap) return layer def forward_embedding( @@ -485,7 +579,7 @@ def forward_scriptable( x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) # account for padding while computing the representation - if encoder_padding_mask is not None: + if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) # B x T x C -> T x B x C @@ -497,13 +591,15 @@ def forward_scriptable( encoder_states.append(x) # encoder layers + l_aux = [] for layer in self.layers: - x = layer( + x, l_aux_i = layer( x, encoder_padding_mask=encoder_padding_mask if has_pads else None ) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) + l_aux.append(l_aux_i) if self.layer_norm is not None: x = self.layer_norm(x) @@ -519,6 +615,7 @@ def forward_scriptable( "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], "src_lengths": [], + "l_aux": l_aux, } @torch.jit.export @@ -678,12 +775,15 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) - self.layers.extend( - [ - self.build_decoder_layer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ] - ) + moe_freq = max(getattr(args, 'decoder_moe_freq', 0), getattr(args, 'moe_freq', 0)) + for i in range(args.decoder_layers): + is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 + self.layers.append( + self.build_decoder_layer( + args, no_encoder_attn=no_encoder_attn, is_moe_layer=is_moe_layer, + ) + ) + self.num_layers = len(self.layers) if args.decoder_normalize_before and not getattr( @@ -726,14 +826,64 @@ def __init__(self, args, dictionary, embed_tokens, 
no_encoder_attn=False): self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) - def build_decoder_layer(self, args, no_encoder_attn=False): - layer = TransformerDecoderLayer(args, no_encoder_attn) - if getattr(args, "checkpoint_activations", False): + + def build_decoder_layer(self, args, no_encoder_attn=False, is_moe_layer=False): + layer = TransformerDecoderLayer( + args, no_encoder_attn=no_encoder_attn, is_moe_layer=is_moe_layer + ) + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) - layer = fsdp_wrap(layer, min_num_params=1e8) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + if not is_moe_layer or getattr(args, "ddp_backend", None) != "fully_sharded": + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + else: + layer = fsdp_wrap_expert(args, layer, min_num_params=min_params_to_wrap) return layer + def forward_embedding( + self, + tokens, + token_embedding: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + # embed tokens and positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions(tokens, incremental_state=incremental_state) + + if incremental_state is not None: + tokens = tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + if token_embedding is None: + token_embedding = self.embed_tokens(tokens) + + x = embed = self.embed_scale * token_embedding + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is 
not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + return x, embed + def forward( self, prev_output_tokens, @@ -745,8 +895,13 @@ def forward( alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[Tensor] = None, ): """ + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing @@ -758,6 +913,14 @@ def forward( applying output layer (default: False). full_context_alignment (bool, optional): don't apply auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). 
+ token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + self_attn_padding_mask (torch.Tensor, optional): precomputed padding + mask for self-attention (default None will recompute mask) Returns: tuple: @@ -771,6 +934,8 @@ def forward( full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, + token_embeddings=token_embeddings, + self_attn_padding_mask=self_attn_padding_mask, ) if not features_only: x = self.output_layer(x) @@ -784,22 +949,20 @@ def extract_features( full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, + token_embeddings: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[Tensor] = None, ): return self.extract_features_scriptable( prev_output_tokens, - encoder_out, - incremental_state, - full_context_alignment, - alignment_layer, - alignment_heads, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + token_embeddings=token_embeddings, + self_attn_padding_mask=self_attn_padding_mask, ) - """ - A scriptable subclass of this class has an extract_features method and calls - super().extract_features, but super() is not supported in torchscript. A copy of - this function is made to be used in the subclass instead. - """ - def extract_features_scriptable( self, prev_output_tokens, @@ -808,75 +971,49 @@ def extract_features_scriptable( full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, + token_embeddings: Optional[Tensor] = None, + self_attn_padding_mask: Optional[Tensor] = None, ): """ - Similar to *forward* but only return features. - - Includes several features from "Jointly Learning to Align and - Translate with Transformer Models" (Garg et al., EMNLP 2019). 
- - Args: - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). - alignment_layer (int, optional): return mean alignment over - heads at this layer (default: last layer). - alignment_heads (int, optional): only average alignment over - this many heads (default: all heads). - - Returns: - tuple: - - the decoder's features of shape `(batch, tgt_len, embed_dim)` - - a dictionary with any model-specific outputs + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy + of this function is made to be used in the subclass instead. """ if alignment_layer is None: alignment_layer = self.num_layers - 1 - # embed positions - positions = None - if self.embed_positions is not None: - positions = self.embed_positions( - prev_output_tokens, incremental_state=incremental_state + # compute self-attention padding mask (involves device-to-host transfer, + # so put it at the top of the forward) + if ( + self_attn_padding_mask is None + and ( + self.cross_self_attention + or prev_output_tokens.device.type == "xla" + or prev_output_tokens.eq(self.padding_idx).any() ) - - if incremental_state is not None: - prev_output_tokens = prev_output_tokens[:, -1:] - if positions is not None: - positions = positions[:, -1:] + ): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # embed tokens and positions - x = self.embed_scale * self.embed_tokens(prev_output_tokens) - - if self.quant_noise is not None: - x = self.quant_noise(x) - - if self.project_in_dim is not None: - x = self.project_in_dim(x) - - if positions is not None: - x += positions - - if self.layernorm_embedding is not None: - x = self.layernorm_embedding(x) - - x = self.dropout_module(x) + x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state) # B x T x C -> T x B x C x = x.transpose(0, 1) - self_attn_padding_mask: 
Optional[Tensor] = None - if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): - self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) - # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] + if encoder_out is None: + l_aux = [] + else: + l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else [] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None - x, layer_attn, _ = layer( + x, layer_attn, _, l_aux_i = layer( x, encoder_out["encoder_out"][0] if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) @@ -893,6 +1030,7 @@ def extract_features_scriptable( need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), ) + l_aux.append(l_aux_i) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) @@ -913,7 +1051,7 @@ def extract_features_scriptable( if self.project_out_dim is not None: x = self.project_out_dim(x) - return x, {"attn": [attn], "inner_states": inner_states} + return x, {"attn": [attn], "inner_states": inner_states, "l_aux": l_aux} def output_layer(self, features): """Project features to the vocabulary size.""" @@ -1071,6 +1209,8 @@ def base_architecture(args): args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.is_moe = getattr(args, "is_moe", False) + args.selected_expert_count = getattr(args, "selected_expert_count", 2) @register_model_architecture("transformer", "transformer_iwslt_de_en") diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index f12470d033..f26f75f522 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -14,13 +14,15 @@ 
register_model, register_model_architecture, ) -from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.transformer import ( + DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder, +) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from omegaconf import II - DEFAULT_MAX_TARGET_POSITIONS = 1024 - +import logging +logger = logging.getLogger(__name__) @dataclass class TransformerLanguageModelConfig(FairseqDataclass): @@ -126,15 +128,6 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "use learned positional embeddings in the decoder"}, ) - decoder_layerdrop: float = field( - default=0.0, metadata={"help": "LayerDrop probability for decoder"} - ) - decoder_layers_to_keep: Optional[str] = field( - default=None, - metadata={ - "help": "which layers to *keep* when pruning as a comma-separated list" - }, - ) layernorm_embedding: bool = field( default=False, metadata={"help": "add layernorm to embedding"} ) @@ -148,6 +141,17 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "move checkpointed activations to CPU after they are used."}, ) + # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ -156,17 +160,134 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=8, metadata={"help": "block size of quantization noise at training time"}, ) - # TODO common var add to parent quant_noise_scalar: float 
= field( default=0.0, metadata={ "help": "scalar quantization noise and scalar quantization at training time" }, ) + # config for Fully Sharded Data Parallel (FSDP) training + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": ( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + ) + } + ) + # Mixture of Expert Layer arguments + alternate_decoder_ffn_embed_dim: int = field( + default=0, + metadata={ + "help": "decoder FFN embed dim of alternate decoder layers" + }, + ) + moe_freq: int = field( + default=0, + metadata={ + "help": "Frequency at which we insert MoE Transformer layers" + }, + ) + moe_expert_count: int = field( + default=0, + metadata={ + "help": "Number of experts in each MoE Layer" + } + ) + moe_gating_use_fp32: bool = field( + default=False, + metadata={ + "help": "Use FP32 computations in MoE top2 gating function" + } + ) + moe_second_expert_policy: str = field( + default='sampling', + metadata={ + "help": "policy for second expert, options: all/sampling/random" + } + ) + moe_normalize_gate_prob_before_dropping: bool = field( + default=False, + metadata={ + "help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization' + } + ) + moe_expert_ffn_dim: Optional[int] = field( + default=None, + metadata={ + "help": "MoE expert FFN dimension" + } + ) + moe_top1_expert: Optional[bool] = field( + default=False, + metadata={ + "help": "Use top1 gate instead of top2" + } + ) + moe_eval_capacity_token_fraction: Optional[float] = field( + default=0.25, + metadata={ + "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set 
to negative, use same as training. range: (0.0, 1.0]." + } + ) + moe_normalize_expert_grad: Optional[str] = field( + default='world_size', + metadata={ + "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'" + } + ) + use_moe_pad_mask: Optional[bool] = field( + default=False, + metadata={ + "help": "Don't route padding tokens to any expert", + } + ) + record_a2a_perf_stats: Optional[bool] = field( + default=False, metadata={"help": "records all to all perf stats during distributed training"} + ) + dummy_a2a: Optional[bool] = field( + default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"} + ) + moe_batch_prioritized_routing: Optional[bool] = field( + default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."} + ) + use_stable_embedding: Optional[bool] = field( + default=False, + metadata={"help": 'Use bitsandbytes StableEmbeddingLayer which saves embedding state in fp32', + 'argparse_alias': "--stable-emb"} + ) + + # options from other parts of the config + + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, metadata={"help": "shuffle tokens between workers before computing assignment"} + ) + # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") tpu: bool = II("common.tpu") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") + fp16: bool = II("common.fp16") + fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads") + ddp_backend: str = 
II("distributed_training.ddp_backend") + world_size: int = II("distributed_training.distributed_world_size") + distributed_rank: int = II("distributed_training.distributed_rank") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") + @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) @@ -212,9 +333,6 @@ def __init__(self, decoder): def build_model(cls, args, task): """Build a new model instance.""" - # make sure all arguments are present in older models - base_lm_architecture(args) - if args.decoder_layers_to_keep: args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) @@ -257,15 +375,31 @@ def build_model(cls, args, task): ) assert args.decoder_input_dim == args.decoder_output_dim + if ( + getattr(args, 'moe_freq', 0) > 0 + and ( + getattr(args, 'fp16', False) + and not getattr(args, 'memory_efficient_fp16', False) + and getattr(args, 'ddp_backend', None) != "fully_sharded" + ) + ): + assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm" + decoder = TransformerDecoder( - args, task.target_dictionary, embed_tokens, no_encoder_attn=True + args, task.target_dictionary, embed_tokens, no_encoder_attn=True, ) return cls(decoder) @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): - embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) - return embed_tokens + if getattr(args, 'use_stable_embedding', False): + import bitsandbytes as bnb + if not args.no_scale_embedding: + logger.warning('It is recommended to pass --no-scale-embedding with --use-stable-embedding') + return bnb.nn.StableEmbedding(len(dictionary), embed_dim, dictionary.pad()) + + else: + return Embedding(len(dictionary), embed_dim, dictionary.pad()) def base_lm_architecture(args): @@ -298,6 +432,10 @@ def base_lm_architecture(args): args.quant_noise_pq_block_size = getattr(args, 
"quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.base_layers = getattr(args, "base_layers", 0) + args.base_sublayers = getattr(args, "base_sublayers", 1) + args.base_shuffle = getattr(args, "base_shuffle", False) + args.add_bos_token = getattr(args, "add_bos_token", False) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False @@ -422,9 +560,118 @@ def transformer_lm_gpt2_medium(args): def transformer_lm_gpt2_big(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600) args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400) - args.decoder_layers = getattr(args, "decoder_layers", 48) + args.decoder_layers = getattr(args, "decoder_layers", 40) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25) args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big_wide") +def transformer_lm_gpt2_big_wide(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192) + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_bigger") +def transformer_lm_gpt2_bigger(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192) + args.decoder_layers = getattr(args, "decoder_layers", 48) + 
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +def base_gpt3_architecture(args): + args.decoder_input_dim = args.decoder_embed_dim + args.decoder_output_dim = args.decoder_embed_dim + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + # GPT-3 used learned positional embeddings, rather than sinusoidal + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.dropout = getattr(args, "dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.share_decoder_input_output_embed = True + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small") +def transformer_lm_gpt3_small(args): + # 125M params + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium") +def transformer_lm_gpt3_medium(args): + # 350M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large") +def transformer_lm_gpt3_large(args): + # 760M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + 
base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl") +def transformer_lm_gpt3_xl(args): + # 1.3B params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7") +def transformer_lm_gpt3_2_7(args): + # 2.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7") +def transformer_lm_gpt3_6_7(args): + # 6.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13") +def transformer_lm_gpt3_13(args): + # 13B params + args.decoder_layers = getattr(args, "decoder_layers", 40) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 40) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175") +def transformer_lm_gpt3_175(args): + # 175B params + args.decoder_layers = getattr(args, "decoder_layers", 96) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 12288) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 96) + base_gpt3_architecture(args) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index ae07dcfaa0..2697e126be 100644 
--- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -3,219 +3,29 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import functools -from typing import Any, Dict, List, Tuple, Union -import torch -import torch.utils.checkpoint as checkpoint - -from fairseq import utils - - -def checkpoint_wrapper(m, offload_to_cpu=False): - """ - A friendlier wrapper for performing activation checkpointing. - - Compared to the PyTorch version, this version: - - wraps an nn.Module, so that all subsequent calls will use checkpointing - - handles keyword arguments in the forward - - handles non-Tensor outputs from the forward - - Usage:: - - checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) - a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) - """ - m.forward = functools.partial( - _checkpointed_forward, - m.forward, # original_forward - offload_to_cpu, - ) - return m - - -def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs): - # Autograd Functions in PyTorch work best with positional args, since - # the backward must return gradients (or None) for every input argument. - # We can flatten keyword arguments to make this easier. 
- kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) - parent_ctx_dict = {"offload": offload_to_cpu} - output = CheckpointFunction.apply( - original_forward, parent_ctx_dict, kwarg_keys, *flat_args - ) - if isinstance(output, torch.Tensor): - return output - else: - packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] - if packed_non_tensor_outputs: - output = unpack_non_tensors(output, packed_non_tensor_outputs) - return output - - -def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: - """ - Usage:: - - kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) - args, kwargs = unpack_kwargs(kwarg_keys, flat_args) - assert args == [1, 2] - assert kwargs == {"a": 3, "b": 4} - """ - kwarg_keys = [] - flat_args = list(args) - for k, v in kwargs.items(): - kwarg_keys.append(k) - flat_args.append(v) - return kwarg_keys, flat_args - - -def unpack_kwargs( - kwarg_keys: List[str], flat_args: List[Any] -) -> Tuple[List[Any], Dict[str, Any]]: - if len(kwarg_keys) == 0: - return flat_args, {} - args = flat_args[: -len(kwarg_keys)] - kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} - return args, kwargs - - -def split_non_tensors( - mixed: Union[torch.Tensor, Tuple[Any]] -) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: - """ - Usage:: - - x = torch.Tensor([1]) - y = torch.Tensor([2]) - tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) - recon = unpack_non_tensors(tensors, packed_non_tensors) - assert recon == (x, y, None, 3) - """ - if isinstance(mixed, torch.Tensor): - return (mixed,), None - tensors = [] - packed_non_tensors = {"is_tensor": [], "objects": []} - for o in mixed: - if isinstance(o, torch.Tensor): - packed_non_tensors["is_tensor"].append(True) - tensors.append(o) - else: - packed_non_tensors["is_tensor"].append(False) - packed_non_tensors["objects"].append(o) - return tuple(tensors), packed_non_tensors - - -def unpack_non_tensors( - tensors: Tuple[torch.Tensor], - packed_non_tensors: 
Dict[str, List[Any]], -) -> Tuple[Any]: - if packed_non_tensors is None: - return tensors - assert isinstance(packed_non_tensors, dict) - mixed = [] - is_tensor_list = packed_non_tensors["is_tensor"] - objects = packed_non_tensors["objects"] - assert len(tensors) + len(objects) == len(is_tensor_list) - obj_i = tnsr_i = 0 - for is_tensor in is_tensor_list: - if is_tensor: - mixed.append(tensors[tnsr_i]) - tnsr_i += 1 - else: - mixed.append(objects[obj_i]) - obj_i += 1 - return tuple(mixed) - - -class CheckpointFunction(torch.autograd.Function): - """Similar to the torch version, but support non-Tensor outputs. - - The caller is expected to provide a dict (*parent_ctx_dict*) that will hold - the non-Tensor outputs. These should be combined with the Tensor *outputs* - by calling ``unpack_non_tensors``. - """ - - @staticmethod - def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): - if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation - checkpoint.check_backward_validity(args) - - ctx.run_function = run_function - ctx.kwarg_keys = kwarg_keys - ctx.fwd_rng_state = utils.get_rng_state() - - tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) - if parent_ctx_dict["offload"]: - ctx.fwd_device = tuple(x.device for x in tensor_inputs) - ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) - tensor_inputs = tuple(x.cpu() for x in tensor_inputs) - - else: - ctx.fwd_device, ctx.grad_requirements = None, None - - ctx.save_for_backward(*tensor_inputs) - ctx.packed_non_tensor_inputs = packed_non_tensor_inputs - - with torch.no_grad(): - unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) - outputs = run_function(*unpacked_args, **unpacked_kwargs) - - if isinstance(outputs, torch.Tensor): - return outputs - else: - # Autograd Functions don't like non-Tensor outputs. 
We can split the - # non-Tensor and Tensor outputs, returning the former by reference - # through *parent_ctx_dict* and returning the latter directly. - outputs, packed_non_tensor_outputs = split_non_tensors(outputs) - parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad(), please use .backward() if possible" +def checkpoint_wrapper(module, *args, **kwargs): + try: + from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper as _checkpoint_wrapper + except ImportError: + try: + from fairscale.nn import checkpoint_wrapper as _checkpoint_wrapper + except ImportError: + raise ImportError( + "Cannot find fairscale.nn.misc.checkpoint_activations. " + "Please install fairscale with: pip install fairscale" ) - tensor_inputs: Tuple = ctx.saved_tensors - tensor_inputs = checkpoint.detach_variable(tensor_inputs) - if ctx.fwd_device is not None: - tensor_inputs = [ - t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs) - ] - for i, need_grad in enumerate(ctx.grad_requirements): - tensor_inputs[i].requires_grad = need_grad - inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + module = _checkpoint_wrapper(module, *args, **kwargs) - # Store the current states. - bwd_rng_state = utils.get_rng_state() - - # Set the states to what it used to be before the forward pass. - utils.set_rng_state(ctx.fwd_rng_state) - - with torch.enable_grad(): - unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) - outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) - tensor_outputs, _ = split_non_tensors(outputs) - # Set the states back to what it was at the start of this function. 
- utils.set_rng_state(bwd_rng_state) + if hasattr(module, "extra_repr"): + orig_extra_repr = module.extra_repr + else: + orig_extra_repr = None - # Run backward() with only Tensors that require grad - outputs_with_grad = [] - args_with_grad = [] - for i in range(len(tensor_outputs)): - if tensor_outputs[i].requires_grad: - outputs_with_grad.append(tensor_outputs[i]) - args_with_grad.append(args[i]) - if len(outputs_with_grad) == 0: - raise RuntimeError( - "None of the outputs have requires_grad=True, " - "this checkpoint() is not necessary" - ) + def extra_repr(): + return f"[checkpointed] {orig_extra_repr()}" if orig_extra_repr is not None else "" - torch.autograd.backward(outputs_with_grad, args_with_grad) + module.extra_repr = extra_repr - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs - ) - return (None, None, None) + grads + return module diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py index f070a804e6..a7da9f1bc5 100644 --- a/fairseq/modules/fairseq_dropout.py +++ b/fairseq/modules/fairseq_dropout.py @@ -20,8 +20,11 @@ def __init__(self, p, module_name=None): self.module_name = module_name self.apply_during_inference = False + def extra_repr(self) -> str: + return "p={}".format(self.p) + def forward(self, x, inplace: bool = False): - if self.training or self.apply_during_inference: + if self.p > 0 and (self.training or self.apply_during_inference): return F.dropout(x, p=self.p, training=True, inplace=inplace) else: return x diff --git a/fairseq/modules/fused_bias_gelu.py b/fairseq/modules/fused_bias_gelu.py new file mode 100644 index 0000000000..750579ca9d --- /dev/null +++ b/fairseq/modules/fused_bias_gelu.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from argparse import Namespace +import logging +import torch + + +logger = logging.getLogger(__name__) + + +try: + from megatron.model.fused_bias_gelu import bias_gelu_impl + has_fused_bias_gelu = True +except ImportError: + has_fused_bias_gelu = False + + +def load_megatron_fused_kernel(): + """Compile and load fused kernels from Megatron.""" + if getattr(load_megatron_fused_kernel, "has_run", False): + return + load_megatron_fused_kernel.has_run = True + + from megatron import fused_kernels + from argparse import Namespace + + if not torch.distributed.is_initialized(): + args = Namespace(rank=0, masked_softmax_fusion=True) + fused_kernels.load(args) + return + + global_rank = torch.distributed.get_rank() + args = Namespace(rank=global_rank, masked_softmax_fusion=True) + + # Always build on rank zero first. + if global_rank == 0: + logger.info("Compiling and loading fused kernels") + fused_kernels.load(args) + torch.distributed.barrier() + else: + torch.distributed.barrier() + fused_kernels.load(args) + + # Simple barrier to make sure all ranks have passed the + # compilation phase successfully before moving on to the + # rest of the program. We think this might ensure that + # the lock is released. + torch.distributed.barrier() + + logger.info("Done with compiling and loading fused kernels.") + + +def fused_bias_gelu(x, bias): + if not has_fused_bias_gelu: + raise ImportError( + "Cannot find fused Megatron kernels, please install Megatron from: " + "github.com/NVIDIA/Megatron-LM" + ) + load_megatron_fused_kernel() + return bias_gelu_impl(x, bias) diff --git a/fairseq/modules/moe/__init__.py b/fairseq/modules/moe/__init__.py new file mode 100644 index 0000000000..acab13a168 --- /dev/null +++ b/fairseq/modules/moe/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .moe_layer import MOELayer +from .top2gate import Top2Gate +from .top1gate import Top1Gate diff --git a/fairseq/modules/moe/moe_layer.py b/fairseq/modules/moe/moe_layer.py new file mode 100644 index 0000000000..1c1668d3b9 --- /dev/null +++ b/fairseq/modules/moe/moe_layer.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# NOTE: This is a mirror of the code in +# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe + +import logging +import time +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.cuda import Event as CudaEvent +from torch.nn import Module, ModuleList +from fairseq import distributed_utils + +if TYPE_CHECKING: + Base = Module[Tensor] +else: + Base = Module + + +logger = logging.getLogger(__name__) + + +# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity +# See https://arxiv.org/pdf/2006.16668.pdf for details. + +# Based on https://github.com/pytorch/pytorch/pull/40762 +class _AllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore + ctx.group = group + input = input.contiguous() + output = torch.empty_like(input) + if torch.distributed.is_initialized(): + dist.all_to_all_single(output, input, group=group) + else: + assert group is None + output = input + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: + return (None, _AllToAll.apply(ctx.group, *grad_output)) + + +class MOELayer(Base): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. 
+ :: + + gate = Top2Gate(model_dim, num_experts) + moe = MOELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (torch.nn.Module): + gate network + expert (torch.nn.Module): + expert network + """ + + def __init__(self, gate: Module, experts: Union[Module, ModuleList], args, group: Optional[Any] = None, all2all_group: Optional[Any] = None) -> None: + super().__init__() + self.gate = gate + if type(experts) == ModuleList: + self.experts = cast(ModuleList, experts) + else: + self.experts = ModuleList([experts]) + self.expert_group = group if group is not None else distributed_utils.get_moe_group(args.moe_expert_count) + self.all2all_group = all2all_group if all2all_group is not None else distributed_utils.get_all2all_group(args.moe_expert_count) + for p in experts.parameters(): + p.expert = True # type: ignore + self.world_size = distributed_utils.get_world_size(self.expert_group) + self.all2all_size = distributed_utils.get_world_size(self.all2all_group) + self.num_local_experts = len(self.experts) + self.args = args + self.in_generation = False + self.a2a_cuda_event_intervals = [] + self.a2a_cpu_time_ms = 0.0 + + def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor: + assert len(input) == 1, "only single input Tensor supported" + input = input[0] + assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel" + if input_padding_mask is not None: + assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken" + assert input_padding_mask.shape[0] == input.shape[0] + assert input_padding_mask.shape[1] == input.shape[1] + # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts" + + # Implement Algorithm 2 from GShard paper. 
+ d_model = input.shape[2] + # Pad to expected batch size + input_shape = list(input.shape) + expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0) + # This indicates that --batch-size or --max-sentences is not specified + if expected_bsz is None: + expected_bsz = 0 + # Note: Padding is not necessary at generation time at present + # because all DDP workers process the same batch. Also, batch size at generation time + # can be different from that present in the checkpoint state + if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz: + logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})") + assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}" + padded_input = torch.zeros( + (expected_bsz, input_shape[1], input_shape[2]), + dtype=input.dtype, layout=input.layout, device=input.device) + padded_input[:input_shape[0], :, :] = input + input = padded_input + + padded_input_padding_mask = torch.ones( + (expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device + ) + if input_padding_mask is not None: + padded_input_padding_mask[:input_shape[0], :] = input_padding_mask + else: + padded_input_padding_mask[:input_shape[0], :] = False + input_padding_mask = padded_input_padding_mask + + # Reshape into S tokens by dropping sequence dimension. 
+ reshaped_input = input.reshape(-1, d_model) + reshaped_input_shape = reshaped_input.shape + reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None + + # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences + # Pro of --max-tokens: more flexible for MT variable sequence lengths + # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM + if expected_bsz == 0: + expected_dim = int(distributed_utils.all_reduce( + reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device), + group=dist.group.WORLD, + op="max", + ).item()) + padded_input = torch.zeros( + (expected_dim, reshaped_input_shape[1]), + dtype=input.dtype, layout=input.layout, device=input.device) + padded_input[:reshaped_input_shape[0], :] = reshaped_input + reshaped_input = padded_input + + padded_input_padding_mask = torch.ones( + (expected_dim,), dtype=torch.bool, device=padded_input.device + ) + if reshaped_input_padding_mask is not None: + padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask + else: + padded_input_padding_mask[:reshaped_input_shape[0]] = False + reshaped_input_padding_mask = padded_input_padding_mask + + l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask) + + dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S + E, C, S = dispatch_mask.size() + M = reshaped_input.size(1) + assert reshaped_input.size() == (S, M) + # einsum("sec,sm->ecm") + dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M + + if self.all2all_size > 1: + dispatched_input = self.all_to_all_wrapper(dispatched_input) + + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model) + chunks = dispatched_input.chunk(self.num_local_experts, dim=1) + 
expert_outputs = [] + for chunk, expert in zip(chunks, self.experts): + expert_outputs += [expert(chunk)] + expert_output = torch.cat(expert_outputs, dim=1) + + if self.all2all_size > 1: + expert_output = self.all_to_all_wrapper(expert_output) + + # Re-shape back: gecm -> ecm + expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model) + + # einsum("sec,ecm->sm") + combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M)) + + # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences + combined_output = combined_output[:reshaped_input_shape[0], :] + combined_output = combined_output.reshape(input.shape) + combined_output = combined_output[:input_shape[0], :, :] + + self.record_all_to_all_stats() + + return combined_output, l_aux + + def prepare_for_inference_(self): + self.in_generation = True + + def all_to_all_wrapper(self, input: Tensor): + dummy_a2a = getattr(self.args, 'dummy_a2a', False) + if dummy_a2a: + input = input.contiguous() + output = input.detach().clone() + return input + # always record times, since it is not a lot of overhead + # if we do not log it we simply clear it off in record_all_to_all_stats + cuda_start = torch.cuda.Event(enable_timing=True) + cuda_end = torch.cuda.Event(enable_timing=True) + cpu_start = time.time() * 1000 + cuda_start.record() + output = _AllToAll.apply(self.all2all_group, input) + cuda_end.record() + cpu_end = time.time() * 1000 + self.a2a_cpu_time_ms += (cpu_end - cpu_start) + self.a2a_cuda_event_intervals.append((cuda_start, cuda_end)) + return output + + def record_all_to_all_stats(self): + # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize() + record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False) + if record_a2a_perf_stats: + torch.cuda.synchronize() + self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms + a2a_cuda_time_ms = 0.0 + for ev_start, ev_end in 
self.a2a_cuda_event_intervals: + a2a_cuda_time_ms += ev_start.elapsed_time(ev_end) + self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms + # reset stats + self.a2a_cpu_time_ms = 0.0 + self.a2a_cuda_event_intervals = [] diff --git a/fairseq/modules/moe/top1gate.py b/fairseq/modules/moe/top1gate.py new file mode 100644 index 0000000000..2a79b5fa60 --- /dev/null +++ b/fairseq/modules/moe/top1gate.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf +# Code is inspired by Top2GatingOnLogits from lingvo: +# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477 + +# NOTE: This is a mirror of the code in +# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe + +from typing import Callable, Dict, Tuple, Optional + +import math +import torch +from torch import Tensor +import torch.nn.functional as F + +from .top2gate import one_hot, entropy + + +# maximum capacity of 1 expert as a fraction of number of tokens in the batch +# Note: setting this to 1.0 causes inference to significantly slow down +EVAL_CAPACITY_TOKEN_FRACTION = 0.25 + +# logging +SAMPLE_FRACTION = 0.2 + + +def top1gating( + logits: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + use_fp32=False, + capacity_factor=1.0, + eval_mode=False, + moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION, +) -> Tuple[Tensor, Tensor, Tensor, Dict]: + """Implements Top2Gating on logits.""" + metadata = {} + if use_fp32: + orig_dtype = logits.dtype + logits = logits.float() + + gates = F.softmax(logits, dim=1) + metadata["entropy_gating"] = entropy(probs=gates).mean().detach() + + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + 
if moe_eval_capacity_token_fraction > 0.0 and eval_mode: + capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens) + else: + # capacity = capacity_factor * S/E + capacity = int(capacity_factor * math.ceil(num_tokens / num_experts)) + + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1) + mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True) + if input_mask is not None and input_mask.any(): + nonpadding = ~ input_mask + mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) + + # for logging (percent of tokens routed to each expert) + expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens + metadata["unused_expert1_count"] = (expert1_hist == 0).sum() + expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny + + sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1) + metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum() + metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum() + + + gates1_s = (gates * mask1).sum(dim=1) + + # Compute locations in capacity buffer + locations1 = torch.cumsum(mask1, dim=0) - 1 + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.to(gates.dtype), dim=0) + l_aux = torch.mean(me * ce) + l_aux = l_aux * num_experts * num_experts + # Remove locations outside capacity from mask + mask1 = mask1 * torch.lt(locations1, capacity) + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se") + # locations1_sc = num_tokens * capacity + locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True) + combine1_sec = torch.bmm( + # einsum("se,sc->sec") + gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1) + ) 
+ dispatch_mask = combine1_sec.bool() + if use_fp32: + return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata + else: + return l_aux, combine1_sec, dispatch_mask, metadata + + +class Top1Gate(torch.nn.Module): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + wg: torch.nn.Linear + + def __init__( + self, + model_dim: int, + num_experts: int, + use_fp32=False, + input_noise_type=None, + capacity_factor=1.0, + moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION, + ) -> None: + # TODO: merge this to top2gate.py + # + super().__init__() + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) + self.use_fp32 = use_fp32 + self.input_noise_type = input_noise_type + self.capacity_factor = capacity_factor + self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction + + def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None,) -> Tuple[Tensor, Tensor, Tensor, Dict]: # type: ignore + logits = self.wg(input) + return top1gating( + logits, + mask, + use_fp32=self.use_fp32, + capacity_factor=self.capacity_factor, + eval_mode=not self.training, + moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction, + ) diff --git a/fairseq/modules/moe/top2gate.py b/fairseq/modules/moe/top2gate.py new file mode 100644 index 0000000000..0f88d5af3c --- /dev/null +++ b/fairseq/modules/moe/top2gate.py @@ -0,0 +1,249 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf +# Code is inspired by Top2GatingOnLogits from lingvo: +# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477 + +# NOTE: This is a mirror of the code in +# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe + +from typing import Callable, Dict, Tuple, Optional + +import math +import torch +from torch import Tensor +from torch.distributions import Categorical +import torch.nn.functional as F + + +gumbel_map: Dict[torch.device, Callable] = {} + +# logging +SAMPLE_FRACTION = 0.2 + + +def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: + gumbel = gumbel_map.get(device) + if gumbel is None: + one = torch.tensor(1.0, device=device) + zero = torch.tensor(0.0, device=device) + gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore + gumbel_map[device] = gumbel + return gumbel(shape) + + +def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor: + if unsqueeze_indices: + indices = indices.unsqueeze(-1) + assert indices.shape[-1] == 1, "last dimension of indices must be have size 1" + output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype) + output.scatter_( + len(output.shape) - 1, indices, 1 + ) + return output + + +def entropy(probs): + logits = torch.distributions.utils.probs_to_logits(probs) + p_log_p = probs * logits + return -p_log_p.sum(-1) + + +def top2gating( + logits: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + use_fp32=False, + second_expert_policy='sampling', + normalize_gate_prob_before_dropping=False, + eval_mode=False, + moe_eval_capacity_token_fraction=0.25, + batch_prioritized_routing=False, +) -> Tuple[Tensor, Tensor, Tensor]: + """Implements Top2Gating on logits.""" + metadata = {} + if use_fp32: + orig_dtype = logits.dtype + logits = logits.float() + gates = 
F.softmax(logits, dim=1) + metadata["entropy_gating"] = entropy(probs=gates).mean().detach() + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + if moe_eval_capacity_token_fraction > 0.0 and eval_mode: + capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens) + else: + # capacity = 2S/E + capacity = 2 * math.ceil(num_tokens / num_experts) + + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1, keepdim=True) + mask1 = one_hot(indices1_s, num_experts) + if second_expert_policy == 'sampling': + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + else: + logits_w_noise = logits + # Replace top-expert with min value + logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) + indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True) + mask2 = one_hot(indices2_s, num_experts) + gates1_s = (gates * mask1).sum(dim=1) + gates2_s = (gates * mask2).sum(dim=1) + + if normalize_gate_prob_before_dropping: + # Normalize gate probabilities + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) + gates1_s = gates1_s / denom_s + gates2_s = gates2_s / denom_s + + if second_expert_policy == 'random': + sampled = (2 * gates2_s) > torch.rand_like(gates2_s) + mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0) + + # Compute locations in capacity buffer + if input_mask is not None and input_mask.any(): + nonpadding = ~ input_mask + mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) + mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype) + + if batch_prioritized_routing: + # if batch_prioritized_routing: + importance_scores = -1 * gates.max(dim=1)[0] + sorted_mask1 = mask1[importance_scores.argsort(dim=0)] + sorted_cumsum1 = 
(torch.cumsum(sorted_mask1, dim=0) - 1) * sorted_mask1 + importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)] + + sorted_mask2 = mask2[importance_scores.argsort(dim=0)] + sorted_cumsum2 = (torch.cumsum(sorted_mask2, dim=0) - 1) * sorted_mask2 + importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)] + + importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True) + + locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2 + else: + locations1 = torch.cumsum(mask1, dim=0) - 1 + locations2 = torch.cumsum(mask2, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(mask1, dim=0, keepdim=True) + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.to(gates.dtype), dim=0) + l_aux = torch.mean(me * ce) + l_aux = l_aux * num_experts * num_experts + + + # for logging purposes + metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1) + metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2) + + # Remove locations outside capacity from mask + mask1 = mask1 * torch.lt(locations1, capacity) + mask2 = mask2 * torch.lt(locations2, capacity) + + # for logging (percent of tokens routed to each expert) + expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens + metadata["unused_expert1_count"] = (expert1_hist == 0).sum() + expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny + + expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens + metadata["unused_expert2_count"] = (expert2_hist == 0).sum() + expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny + + sample_count = 
max(math.ceil(num_experts * SAMPLE_FRACTION), 1) + metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum() + metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum() + + + metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum() + metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum() + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + + if not normalize_gate_prob_before_dropping: + # Normalize gate probabilities + gates1_s = (gates * mask1).sum(dim=1) + gates2_s = (gates * mask2).sum(dim=1) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se") + gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se") + locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True) + locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True) + combine1_sec = torch.bmm( + # einsum("se,sc->sec") + gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1) + ) + combine2_sec = torch.bmm( + # einsum("se,sc->sec") + gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1) + ) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.bool() + if use_fp32: + return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata + else: + return l_aux, combine_weights, dispatch_mask, metadata + + +class Top2Gate(torch.nn.Module): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + wg: torch.nn.Linear + + def __init__( + self, + model_dim: int, + num_experts: int, + use_fp32=False, + second_expert_policy='sampling', + normalize_gate_prob_before_dropping=False, + moe_eval_capacity_token_fraction=0.25, + batch_prioritized_routing=False, + ) -> None: + super().__init__() + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) + self.use_fp32 = use_fp32 + self.second_expert_policy = second_expert_policy + self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping + self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction + self.batch_prioritized_routing = batch_prioritized_routing + + def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None,) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + logits = self.wg(input) + return top2gating( + logits, + mask, + use_fp32=self.use_fp32, + second_expert_policy=self.second_expert_policy, + normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping, + eval_mode=not self.training, + moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction, + batch_prioritized_routing=self.batch_prioritized_routing, + ) diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 6ab86245d2..1fcba9a958 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -147,8 +147,15 @@ def forward( is_tpu = query.device.type == "xla" tgt_len, bsz, embed_dim = query.size() - assert embed_dim == self.embed_dim + src_len = tgt_len + assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is 
not None + assert src_len, bsz == value.shape[:2] if ( not self.onnx_trace @@ -262,6 +269,7 @@ def forward( else: assert k is not None k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None @@ -290,7 +298,7 @@ def forward( assert incremental_state is not None incremental_state = self._set_input_buffer(incremental_state, saved_state) assert k is not None - src_len = k.size(1) + assert k.size(1) == src_len # This is part of a workaround to get around fork/join parallelism # not supporting Optional types. @@ -327,6 +335,10 @@ def forward( assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] if attn_mask is not None: + # Replace any non-finite values with finite equivalents, since otherwise + # we may get NaN when adding attn_mask or computing softmax. + attn_weights = torch.nan_to_num(attn_weights) + attn_mask = attn_mask.unsqueeze(0) if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) @@ -395,21 +407,27 @@ def _append_prev_key_padding_mask( # leaves the frame, there will be a time when prev or current # is None elif prev_key_padding_mask is not None: - filler = torch.zeros( - (batch_size, src_len - prev_key_padding_mask.size(1)), - device=prev_key_padding_mask.device, - ) - new_key_padding_mask = torch.cat( - [prev_key_padding_mask.float(), filler.float()], dim=1 - ) + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: - filler = torch.zeros( - (batch_size, src_len - key_padding_mask.size(1)), - device=key_padding_mask.device, - ) - new_key_padding_mask = torch.cat( - [filler.float(), 
key_padding_mask.float()], dim=1 - ) + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index f9ada37bde..2d56ea0acf 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -7,12 +7,102 @@ import torch import torch.nn as nn -from fairseq import utils -from fairseq.modules import LayerNorm, MultiheadAttention +import torch.nn.functional as F +from fairseq import distributed_utils as dist_utils, utils +from fairseq.modules import gelu, LayerNorm, MultiheadAttention from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.moe import Top1Gate, Top2Gate, MOELayer from fairseq.modules.quant_noise import quant_noise +from fairseq.modules.fused_bias_gelu import fused_bias_gelu, has_fused_bias_gelu from torch import Tensor +def _linear(x, weight, bias=None): + return F.linear(x, weight, bias) + + +def _ffn( + x, + fc1, + activation_fn, + activation_dropout_module, + fc2, + dropout_module, +): + x_shape = x.shape + x = x.reshape(-1, x.size(-1)) + if has_fused_bias_gelu and activation_fn == gelu: + x = _linear(x, fc1.weight) + x = fused_bias_gelu(x, fc1.bias) + x = activation_dropout_module(x) + x = _linear(x, fc2.weight, fc2.bias) + else: + x = _linear(x, fc1.weight, fc1.bias) + x = activation_fn(x) + x = activation_dropout_module(x) + x = _linear(x, fc2.weight, fc2.bias) + x = x.view(x_shape) + x = dropout_module(x) + return x + + +class FeedForwardNetwork(nn.Module): + """ + Feed Forward Network layer in the Transformer model + """ + def __init__(self, args, embed_dim, ffn_dim, 
dropout_module=None): + super().__init__() + self.embed_dim = embed_dim + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.fc1 = self.build_fc1( + self.embed_dim, + ffn_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + ffn_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) if not dropout_module else dropout_module + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def forward(self, x): + return _ffn( + x, + fc1=self.fc1, + activation_fn=self.activation_fn, + activation_dropout_module=self.activation_dropout_module, + fc2=self.fc2, + dropout_module=self.dropout_module, + ) + return x + class TransformerEncoderLayer(nn.Module): """Encoder layer block. 
@@ -29,7 +119,7 @@ class TransformerEncoderLayer(nn.Module): args (argparse.Namespace): parsed command-line arguments """ - def __init__(self, args): + def __init__(self, args, is_moe_layer=False): super().__init__() self.args = args self.embed_dim = args.encoder_embed_dim @@ -40,30 +130,57 @@ def __init__(self, args): self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) - self.activation_fn = utils.get_activation_fn( - activation=getattr(args, 'activation_fn', 'relu') or "relu" - ) - activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 - if activation_dropout_p == 0: - # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 - self.activation_dropout_module = FairseqDropout( - float(activation_dropout_p), module_name=self.__class__.__name__ - ) self.normalize_before = args.encoder_normalize_before - self.fc1 = self.build_fc1( - self.embed_dim, - args.encoder_ffn_embed_dim, - self.quant_noise, - self.quant_noise_block_size, - ) - self.fc2 = self.build_fc2( - args.encoder_ffn_embed_dim, - self.embed_dim, - self.quant_noise, - self.quant_noise_block_size, - ) - + self.is_moe_layer = is_moe_layer + ffn_dim = args.encoder_ffn_embed_dim + if self.is_moe_layer and getattr(args, "alternate_ffn_embed_dim", 0.0) > 0: + ffn_dim = getattr(args, "alternate_ffn_embed_dim", 0.0) + # the second condition is for a "pseudo" MoE layer + # (shared FFN with expert FFN dimension) that tries + # to replicate FLOPs used by an expert MoE layer with perfectly balanced load + if not self.is_moe_layer or getattr(args, "alternate_ffn_embed_dim", 0.0) > 0: + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, 'activation_fn', 'relu') or "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = 
getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.fc1 = self.build_fc1( + self.embed_dim, + ffn_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + ffn_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + else: + if args.moe_top1_expert: + gate = Top1Gate( + self.embed_dim, + args.moe_expert_count, + use_fp32=args.moe_gating_use_fp32, + moe_eval_capacity_token_fraction=getattr(args, "moe_eval_capacity_token_fraction", 0.25), + ) + else: + gate = Top2Gate( + self.embed_dim, + args.moe_expert_count, + args.moe_gating_use_fp32, + args.moe_second_expert_policy, + args.moe_normalize_gate_prob_before_dropping, + getattr(args, "moe_eval_capacity_token_fraction", 0.25), + getattr(args, "moe_batch_prioritized_routing", False), + ) + experts = make_experts(args, self.embed_dim, ffn_dim, self.dropout_module) + self.moe_layer = MOELayer(gate, experts, args) self.final_layer_norm = LayerNorm(self.embed_dim) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): @@ -146,14 +263,24 @@ def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional residual = x if self.normalize_before: x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) + if not self.is_moe_layer or getattr(self.args, "alternate_ffn_embed_dim", 0.0) > 0: + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + l_aux = None + else: + # x - seq_len, batch_size, model_dim + x = x.transpose(0, 1) # batch_size, seq_len, model_dim + if getattr(self.args, "use_moe_pad_mask", False): + x, l_aux = self.moe_layer(x, input_padding_mask=encoder_padding_mask) + else: + x, l_aux = self.moe_layer(x) + x = x.transpose(0, 1) # seq_len, batch_size, 
model_dim x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) - return x + return x, l_aux class TransformerDecoderLayer(nn.Module): @@ -174,7 +301,7 @@ class TransformerDecoderLayer(nn.Module): """ def __init__( - self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, is_moe_layer=False, ): super().__init__() self.embed_dim = args.decoder_embed_dim @@ -193,24 +320,13 @@ def __init__( add_zero_attn=add_zero_attn, ) - self.activation_fn = utils.get_activation_fn( - activation=str(args.activation_fn) - if getattr(args, "activation_fn", None) is not None - else "relu" - ) - activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 - if activation_dropout_p == 0: - # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 - self.activation_dropout_module = FairseqDropout( - float(activation_dropout_p), module_name=self.__class__.__name__ - ) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. 
# TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: @@ -220,24 +336,67 @@ def __init__( self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) - self.fc1 = self.build_fc1( - self.embed_dim, - args.decoder_ffn_embed_dim, - self.quant_noise, - self.quant_noise_block_size, - ) - self.fc2 = self.build_fc2( - args.decoder_ffn_embed_dim, - self.embed_dim, - self.quant_noise, - self.quant_noise_block_size, - ) + self.is_moe_layer = is_moe_layer + + ffn_dim = args.decoder_ffn_embed_dim + if self.is_moe_layer and getattr(args, "alternate_decoder_ffn_embed_dim", 0.0) > 0: + ffn_dim = getattr(args, "alternate_decoder_ffn_embed_dim", 0.0) + + if not self.is_moe_layer or getattr(args, "alternate_decoder_ffn_embed_dim", 0.0) > 0: + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.fc1 = self.build_fc1( + self.embed_dim, + ffn_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + ffn_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + else: + + if args.moe_top1_expert: + gate = Top1Gate( + self.embed_dim, + args.moe_expert_count, + use_fp32=args.moe_gating_use_fp32, + moe_eval_capacity_token_fraction=getattr(args, "moe_eval_capacity_token_fraction", 0.25), + ) + else: + gate = Top2Gate( + self.embed_dim, + args.moe_expert_count, + 
args.moe_gating_use_fp32, + args.moe_second_expert_policy, + args.moe_normalize_gate_prob_before_dropping, + getattr(args, "moe_eval_capacity_token_fraction", 0.25), + getattr(args, "moe_batch_prioritized_routing", False), + ) + experts = make_experts(args, self.embed_dim, ffn_dim, self.dropout_module) + self.moe_layer = MOELayer(gate, experts, args) + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False + self.args = args + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) @@ -390,11 +549,24 @@ def forward( residual = x if self.normalize_before: x = self.final_layer_norm(x) - - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) + if not self.is_moe_layer or getattr(self.args, "alternate_decoder_ffn_embed_dim", 0.0) > 0: + x = _ffn( + x, + fc1=self.fc1, + activation_fn=self.activation_fn, + activation_dropout_module=self.activation_dropout_module, + fc2=self.fc2, + dropout_module=self.dropout_module, + ) + l_aux = None + else: + # x - seq_len, batch_size, model_dim + x = x.transpose(0, 1) # batch_size, seq_len, model_dim + if getattr(self.args, "use_moe_pad_mask", False): + x, l_aux = self.moe_layer(x, input_padding_mask=self_attn_padding_mask) + else: + x, l_aux = self.moe_layer(x) + x = x.transpose(0, 1) # seq_len, batch_size, model_dim x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) @@ -410,7 +582,29 @@ def forward( else: self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] return x, attn, self_attn_state - return x, attn, None + return x, attn, None, l_aux def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn + + +def make_experts(args, embed_dim, expert_ffn_dim, dropout_module) -> nn.ModuleList: + world_size = 1 if not 
torch.distributed.is_initialized() else torch.distributed.get_world_size() + expert_list = [] + ddp_rank = dist_utils.get_data_parallel_rank() + start_seed = torch.randint(1000000, (1,)).item() + # at least as many experts than gpus + if args.moe_expert_count >= world_size: + assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}' + local_moe_expert_count = args.moe_expert_count // world_size + for i in range(local_moe_expert_count): + with utils.set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): + expert_list.append(FeedForwardNetwork(args, embed_dim, expert_ffn_dim, dropout_module)) + # less experts than gpus + else: + assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}' + # initialize each FFN with the same seed on different GPUs + with utils.set_torch_seed(start_seed + ddp_rank % args.moe_expert_count): + expert_list.append(FeedForwardNetwork(args, embed_dim, expert_ffn_dim, dropout_module)) + experts = nn.ModuleList(expert_list) + return experts
diff --git a/fairseq/moe_checkpoint_utils.py b/fairseq/moe_checkpoint_utils.py
new file mode 100644
index 0000000000..2616837711
--- /dev/null
+++ b/fairseq/moe_checkpoint_utils.py
@@ -0,0 +1,227 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import re
+import torch
+import numpy as np
+from collections import defaultdict, OrderedDict
+from glob import glob
+
+from fairseq import distributed_utils
+from fairseq.file_io import torch_load_cpu
+from typing import List, Dict
+from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
+
+
+OPT_KEY = 'last_optimizer_state'
+logger = logging.getLogger(__name__)
+
+# Captures the numeric expert index in keys that contain "experts.<N>".
+# The dot is escaped so that only a literal "." is accepted.
+_EXPERT_ID_RE = re.compile(r"experts\.([0-9]+)")
+
+
+def merge_expert_and_shared_state(expert_state, shared_state):
+    """Combine an expert checkpoint and a shared checkpoint into one state.
+
+    ``cfg``, ``args``, ``extra_state`` and ``optimizer_history`` are taken from
+    *expert_state*. The ``model`` entries of both inputs are merged, with
+    *shared_state* winning on duplicate keys. Optimizer state is merged only
+    when *expert_state* contains ``last_optimizer_state``; *shared_state* is
+    then expected to contain it as well.
+    """
+    state = {}
+    for key in ['cfg', 'args', 'extra_state', 'optimizer_history']:
+        state[key] = expert_state[key]
+    state['model'] = {**expert_state['model'], **shared_state['model']}
+
+    if OPT_KEY in expert_state:
+        state[OPT_KEY] = {}
+        for key in ['loss_scale', 'param_groups']:
+            if key in expert_state[OPT_KEY]:
+                state[OPT_KEY][key] = expert_state[OPT_KEY][key]
+
+        if 'param_id_map' in shared_state[OPT_KEY]:  # FSDP
+            # the expert optimizer state is split into per-global-id entries
+            # before it is merged with the shared state
+            unflat_expert_state = _unflat_expert_tensor_state(expert_state[OPT_KEY], shared_state[OPT_KEY])
+            state[OPT_KEY]['state'] = {
+                **shared_state[OPT_KEY]['state'],
+                **unflat_expert_state
+            }
+
+            # copy remaining shared entries without overwriting keys set above
+            state[OPT_KEY].update({k: v for k, v in shared_state[OPT_KEY].items()
+                                   if k not in state[OPT_KEY]})
+        else:
+            state[OPT_KEY]['state'] = {
+                **expert_state[OPT_KEY]['state'],
+                **shared_state[OPT_KEY]['state'],
+            }
+    return state
+
+
+def split_shared_and_expert_states(model, optimizer):
+    """Partition model and optimizer state into shared and expert parts.
+
+    Returns:
+        ``((shared_model_state, shared_optimizer_state),
+        (expert_model_state, expert_optimizer_state))``
+    """
+    model_state_dict = model.state_dict()
+    shared_model_state_dict = OrderedDict()
+    expert_model_state_dict = OrderedDict()
+    for name, value in model_state_dict.items():
+        # TODO: this is a bit hacky - find a better way determine expert params
+        if 'expert' in name and 'expert_centroids' not in name:
+            expert_model_state_dict[name] = value
+        else:
+            shared_model_state_dict[name] = value
+
+    shared_optimizer_state_dict = {}
+    expert_optimizer_state_dict = {}
+    optimizer_state_dict = optimizer.state_dict()
+    for key in ['param_groups', 'loss_scale']:
+        if key in optimizer_state_dict:
+            expert_optimizer_state_dict[key] = optimizer_state_dict[key]
+            shared_optimizer_state_dict[key] = optimizer_state_dict[key]
+
+    # Number the parameters consecutively across param groups and record
+    # whether each one carries an ``expert`` or ``base_expert`` attribute.
+    # The keys of optimizer.state_dict()['state'] are looked up in this table
+    # below, so they have to be these indices.
+    param_mappings = {}
+    param_id_to_is_expert = {}
+    start_index = 0
+    for group in optimizer.param_groups:
+        for i, p in enumerate(group['params'], start_index):
+            if id(p) not in param_mappings:
+                param_mappings[id(p)] = i
+                param_id_to_is_expert[i] = hasattr(p, 'expert') or hasattr(p, 'base_expert')
+        start_index += len(group['params'])
+
+    expert_optimizer_state_dict['state'] = {
+        k: v for k, v in optimizer_state_dict['state'].items()
+        if param_id_to_is_expert[k]
+    }
+    shared_optimizer_state_dict['state'] = {
+        k: v for k, v in optimizer_state_dict['state'].items()
+        if not param_id_to_is_expert[k]
+    }
+    return (
+        (shared_model_state_dict, shared_optimizer_state_dict),
+        (expert_model_state_dict, expert_optimizer_state_dict),
+    )
+
+
+def merge_multi_local_expert_states(expert_states: List[Dict]) -> Dict:
+    """Stitch several expert checkpoints into one with renumbered experts.
+
+    Expert ``i`` of the ``g``-th checkpoint is renamed to expert
+    ``g * num_local_experts + i``, where ``num_local_experts`` is inferred from
+    the highest expert index found in that checkpoint. Optimizer state is not
+    merged; the non-model entries are taken from the first checkpoint.
+    """
+    merged_expert_state = {}
+    for key in ['cfg', 'args', 'extra_state', 'optimizer_history']:
+        merged_expert_state[key] = expert_states[0][key]
+
+    if OPT_KEY in expert_states[0]:
+        logger.warning(
+            "Not stitching last optimizer state while merging experts. "
+            "This is okay for inference but not for continued training. "
+        )
+
+    model_state_dict = {}
+    for expert_group_id, expert_state in enumerate(expert_states):
+        num_local_experts_in_chkpt = 1
+        for key in expert_state['model']:
+            match = _EXPERT_ID_RE.search(key)
+            if match and int(match.group(1)) + 1 > num_local_experts_in_chkpt:
+                num_local_experts_in_chkpt = int(match.group(1)) + 1
+        logger.info(f"found {num_local_experts_in_chkpt} local experts in expert_group_id={expert_group_id}")
+        for key, val in expert_state['model'].items():
+            match = _EXPERT_ID_RE.search(key)
+            # f-string so that the offending key actually shows up in the message
+            assert match is not None, f"\"experts.([0-9][0-9]*)\" pattern expected in key {key}"
+            local_chkpt_expert_id = int(match.group(1))
+            target_expert_id = expert_group_id * num_local_experts_in_chkpt + local_chkpt_expert_id
+            key = key.replace(f"experts.{local_chkpt_expert_id}", 'experts.{}'.format(target_expert_id))
+            model_state_dict[key] = val
+    merged_expert_state['model'] = model_state_dict
+    return merged_expert_state
+
+
+def load_expert_state(local_path):
+    """Load the expert checkpoint(s) for the current data-parallel rank.
+
+    If more ``rank-*`` checkpoint files exist than data-parallel workers,
+    this rank loads a contiguous range of them and merges them with
+    :func:`merge_multi_local_expert_states`; otherwise *local_path* is loaded
+    as is.
+    """
+    checkpoint_files_count = len(glob(re.sub('rank-[0-9]+', 'rank-*', local_path)))
+    world_size = distributed_utils.get_data_parallel_world_size()
+    rank = distributed_utils.get_data_parallel_rank()
+    if world_size < checkpoint_files_count:
+        assert checkpoint_files_count % world_size == 0
+        logger.info(
+            f"Found total {checkpoint_files_count} expert files and"
+            f" current distributed world size: {world_size},"
+            " Stitching experts to able to load on current world size."
+        )
+        # exact integer division; divisibility is asserted above
+        local_expert_count = checkpoint_files_count // world_size
+        start_rank = local_expert_count * rank
+        expert_states = []
+        for expert_rank in range(start_rank, start_rank + local_expert_count):
+            fname = re.sub(
+                'rank-[0-9]+',
+                'rank-{0}'.format(expert_rank),
+                local_path,
+            )
+            expert_states.append(torch_load_cpu(fname))
+        expert_state = merge_multi_local_expert_states(expert_states)
+    else:
+        expert_state = torch_load_cpu(local_path)
+    return expert_state
+
+
+def assert_equal(a, b, msg=''):
+    """Assert ``a == b``, prefixing the failure message with *msg*."""
+    assert a == b, f"{msg}{a} != {b}"
+
+
+def _unflat_expert_tensor_state(expert, shared) -> Dict:
+    """called from merge_expert_and_shared_state, for FSDP only.
+
+    Splits each flat, non-singleton tensor buffer of the expert optimizer
+    state into views shaped like the matching buffers in ``shared['state']``,
+    keyed by global id. Non-tensor and singleton values are copied to every
+    global id that shares the local id.
+    """
+
+    local_to_globals = defaultdict(list)
+    for global_id, local_id in shared['param_id_map'].items():
+        if local_id in shared['uncollected_local_ids']:
+            local_to_globals[local_id].append(global_id)
+
+    flat_expert_state = expert['state']
+    unflat_state = {}
+    for local_id, global_ids in local_to_globals.items():
+        global_ids = sorted(global_ids)
+        unflat_state.update({g: {} for g in global_ids})
+        already_unflat = {k: v for k, v in flat_expert_state[local_id].items() if not torch.is_tensor(v) or is_singleton_tensor(v)}
+        for buffer_name, flat_param in flat_expert_state[local_id].items():
+            if torch.is_tensor(flat_param) and not is_singleton_tensor(flat_param):
+                unflat_shapes = [shared['state'][g][buffer_name].shape for g in global_ids]
+                # int(): np.prod returns a numpy scalar (a float for an empty shape)
+                numels = [int(np.prod(s)) for s in unflat_shapes]
+                unflat = zip(global_ids, (t.view(s) for (t, s) in zip(flat_param.split(numels), unflat_shapes)))
+                for gid, t in unflat:
+                    unflat_state[gid][buffer_name] = t
+                    unflat_state[gid].update(already_unflat)
+    return unflat_state
diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index c969b9fbc0..d4c5956002 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -29,6 +29,10 @@ def add_args(parser): help='decay rate of the second moment estimator') parser.add_argument('--beta1', type=float,
default=None, metavar="B", help='beta for first moment estimator. Optional') + parser.add_argument('--first-moment-fp16', action='store_true', + help='store momentum in fp16') + parser.add_argument('--no-relative-lr', action='store_true', + help='skip section 8 of the paper') parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', help='weight decay') parser.add_argument('--scale-parameter', action='store_true', @@ -60,9 +64,11 @@ def optimizer_config(self): "scale_parameter": self.args.scale_parameter, # defaults to False "relative_step": self.args.relative_step, # defaults to False "warmup_init": self.args.warmup_init, + "first_moment_fp16": self.args.first_moment_fp16, + "no_relative_lr": self.args.no_relative_lr, } - +FLOAT16_MAX = 65504.0 class Adafactor(torch.optim.Optimizer): """Implements Adafactor algorithm. @@ -109,6 +115,8 @@ def __init__( scale_parameter=True, relative_step=True, warmup_init=False, + first_moment_fp16=False, + no_relative_lr=False, ): if lr is not None and relative_step: raise ValueError("Cannot combine manual lr and relative_step options") @@ -125,7 +133,10 @@ def __init__( scale_parameter=scale_parameter, relative_step=relative_step, warmup_init=warmup_init, + first_moment_fp16=first_moment_fp16, + ) + self.no_relative_lr = no_relative_lr super(Adafactor, self).__init__(params, defaults) @property @@ -138,7 +149,8 @@ def supports_flat_params(self): def _get_lr(self, param_group, param_state): rel_step_sz = param_group["lr"] - if param_group["relative_step"]: + if param_group["relative_step"]: # NOTE(SS): disable this + # disable scaling of learning rate relative to weight norms, which is the default feature in Adafactor. 
min_step = ( 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 ) @@ -151,7 +163,8 @@ def _get_lr(self, param_group, param_state): def _get_options(self, param_group, param_shape): factored = len(param_shape) >= 2 use_first_moment = param_group["beta1"] is not None - return factored, use_first_moment + first_moment_fp16 = param_group['first_moment_fp16'] + return factored, use_first_moment, first_moment_fp16 def _rms(self, tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) @@ -181,6 +194,9 @@ def step(self, closure=None): if p.grad is None: continue grad = p.grad.data + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: @@ -189,7 +205,7 @@ def step(self, closure=None): state = self.state[p] grad_shape = grad.shape - factored, use_first_moment = self._get_options(group, grad_shape) + factored, use_first_moment, first_moment_fp16 = self._get_options(group, grad_shape) # State Initialization if len(state) == 0: state["step"] = 0 @@ -197,11 +213,10 @@ def step(self, closure=None): if use_first_moment: # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(grad) + state["exp_avg_scale"] = 1.0 if factored: state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros( - grad_shape[:-2] + grad_shape[-1:] - ).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) else: state["exp_avg_sq"] = torch.zeros_like(grad) @@ -215,13 +230,11 @@ def step(self, closure=None): else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() state["step"] += 1 state["RMS"] = self._rms(p_data_fp32) - group["lr"] = self._get_lr(group, state) + if not self.no_relative_lr: + group["lr"] = self._get_lr(group, 
state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) update = (grad ** 2) + group["eps"][0] @@ -251,7 +264,7 @@ def step(self, closure=None): update.mul_(group["lr"]) if use_first_moment: - exp_avg = state["exp_avg"] + exp_avg = state["exp_avg"].float() * state["exp_avg_scale"] exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) update = exp_avg @@ -265,4 +278,11 @@ def step(self, closure=None): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) + + # copied idea from fp16_adam_stats implem + # which copied from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + if first_moment_fp16: + state["exp_avg_scale"] = 1e-8 + torch.norm(exp_avg, float('inf')) / FLOAT16_MAX + state["exp_avg"] = (exp_avg / state["exp_avg_scale"]).half() + return loss diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index f73804718a..d5ba5785d5 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -33,9 +33,13 @@ class FairseqAdamConfig(FairseqDataclass): use_old_adam: bool = field( default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} ) + fp16_adam_stats: bool = field( + default=False, metadata={"help": "use FP16 stats (with automatic scaling)"} + ) # TODO common vars below in parent tpu: bool = II("common.tpu") lr: List[float] = II("optimization.lr") + block_wise: bool = field(default=False, metadata={"help": "Enables block-wise optimization for 8-bit Adam"}) @register_optimizer("adam", dataclass=FairseqAdamConfig) @@ -56,13 +60,21 @@ def __init__(self, cfg: DictConfig, params): and torch.cuda.is_available() ) if getattr(cfg, "tpu", False): + if self.cfg.fp16_adam_stats: + raise NotImplementedError("--fp16-adam-stats is only supported on GPU") # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) elif use_fused_adam: logger.info("using FusedAdam") - self._optimizer = fused_adam_cls(params, 
**self.optimizer_config) + self._optimizer = fused_adam_cls( + params, + use_fp16_stats=self.cfg.fp16_adam_stats, + **self.optimizer_config + ) else: + if self.cfg.fp16_adam_stats: + raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") self._optimizer = Adam(params, **self.optimizer_config) @property @@ -83,7 +95,7 @@ def optimizer_config(self): } def average_params(self): - """Reduce Params is only used during BMUF distributed training.""" + """average Params is only used during BMUF distributed training.""" state_dict = self.optimizer.state_dict() total_gpus = float(dist.get_world_size()) @@ -93,6 +105,37 @@ def average_params(self): dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM) dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM) +@register_optimizer("adam8bit", dataclass=FairseqAdamConfig) +class FairseqAdam8Bit(FairseqOptimizer): + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError('adam8bit requires bits and bytes: see https://gist.github.com/TimDettmers/c4ffe346f095ee4481aa3d4b4ad2ffe0') + bnb.optim.GlobalOptimManager.get_instance().register_parameters(params) + self._optimizer = bnb.optim.Adam(params, optim_bits=8, **self.optimizer_config) # equivalent + + @property + def optimizer_config(self): + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, + "block_wise": self.cfg.block_wise, + } + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + class Adam(torch.optim.Optimizer): r"""Implements Adam algorithm. 
diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index 5e935df1a5..488b4fa6d9 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -12,15 +12,25 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from omegaconf import II, DictConfig +import logging try: - from deepspeed.ops.op_builder import CPUAdamBuilder - has_deepspeed_cpu_adam = True -except ImportError: - has_deepspeed_cpu_adam = False + import deepspeed + has_deepspeed = True +except ImportError as e: + has_deepspeed = False +def _get_cpu_adam(): + try: + from deepspeed.ops.op_builder import CPUAdamBuilder + return CPUAdamBuilder().load() + except ImportError: + # fbcode + from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam + return ds_opt_adam + @dataclass class FairseqCPUAdamConfig(FairseqDataclass): adam_betas: str = field( @@ -95,18 +105,22 @@ def __init__( self.use_fp16_stats = use_fp16_stats self.FLOAT16_MAX = 65504.0 - if not has_deepspeed_cpu_adam: + if not has_deepspeed: raise ImportError("Please install DeepSpeed: pip install deepspeed") self.opt_id = CPUAdam.optimizer_id CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1 - self.ds_opt_adam = CPUAdamBuilder().load() + self.ds_opt_adam = _get_cpu_adam() adamw_mode = True self.ds_opt_adam.create_adam( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_memory_efficient_fp16(self): + return True + @property def supports_flat_params(self): return True @@ -118,6 +132,8 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + torch.cuda.synchronize() + for group_id, group in enumerate(self.param_groups): for param_id, p in enumerate(group["params"]): if p.grad is None: diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 00ea1bbb76..1d6119b19e 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -64,6 +64,10 @@ def 
build_fp32_params(cls, args, params, flatten=True): fp32_params = [] for p in params: p32 = torch.nn.Parameter(p.data.float()) + if hasattr(p, 'expert'): + p32.expert = True + elif hasattr(p, 'base_expert'): + p32.base_expert = True p32.grad = torch.zeros_like(p32.data) if hasattr(p, "param_group"): p32.param_group = p.param_group diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index e2b8e1bcd1..7a6d1f73d5 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -80,6 +80,7 @@ def __init__( weight_decay=0.0, max_grad_norm=0.0, amsgrad=False, + use_fp16_stats=False, ): global fused_adam_cuda import importlib @@ -99,6 +100,9 @@ def __init__( super().__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 + self.use_fp16_stats = use_fp16_stats + self.FLOAT16_MAX = 65504.0 + @property def supports_memory_efficient_fp16(self): return True @@ -173,29 +177,42 @@ def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): "please consider SparseAdam instead" ) - p_data_fp32 = p.data.float() + if p.device.type == "cpu": + p_data_fp32 = p.data.cuda(non_blocking=True).float() + out_p = torch.tensor([], dtype = torch.float) + else: + p_data_fp32 = p.data.float() + out_p = p.data state = self.state[p] # State initialization + dtype = torch.float16 if self.use_fp16_stats else p_data_fp32.dtype if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg"] = torch.zeros_like(p_data_fp32, dtype=dtype) # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32, dtype=dtype) + if self.use_fp16_stats: + state["exp_avg_scale"] = 1.0 + state["exp_avg_sq_scale"] = 1.0 else: - state["exp_avg"] = state["exp_avg"].to(p_data_fp32) - state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + device = p_data_fp32.device + 
state["exp_avg"] = state["exp_avg"].to(device, dtype) + state["exp_avg_sq"] = state["exp_avg_sq"].to(device, dtype) exp_avg = state["exp_avg"] exp_avg_sq = state["exp_avg_sq"] + if self.use_fp16_stats: + assert exp_avg.dtype == torch.float16 + exp_avg = exp_avg.float() * state["exp_avg_scale"] + exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] beta1, beta2 = group["betas"] state["step"] += 1 - out_p = p.data - with torch.cuda.device(p.device): + with torch.cuda.device(p_data_fp32.device): fused_adam_cuda.adam( p_data_fp32, out_p, @@ -213,6 +230,23 @@ def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): group["weight_decay"], ) + if p.device.type == "cpu": + p.data.copy_(p_data_fp32, non_blocking=True) + + if self.use_fp16_stats: + def inf_norm(t): + return torch.norm(t, float("inf")) + + # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + state["exp_avg_scale"], state["exp_avg_sq_scale"] = ( + 1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX, + 1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX, + ) + state["exp_avg"], state["exp_avg_sq"] = ( + (exp_avg / state["exp_avg_scale"]).half(), + (exp_avg_sq / state["exp_avg_sq_scale"]).half(), + ) + return loss @@ -226,7 +260,9 @@ class FusedAdamV2(FusedAdam): and params to FP32 internally to support ``--memory-efficient-fp16``. 
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, use_fp16_stats=False, **kwargs): + if use_fp16_stats: + raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") super().__init__(*args, **kwargs) if not hasattr(self, "multi_tensor_adam"): raise Exception( diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 375b5277b9..43a1512a69 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List import torch -from fairseq import metrics, search, tokenizer, utils +from fairseq import metrics, search, tokenizer, utils, distributed_utils from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -210,7 +210,7 @@ def can_reuse_epoch_itr(self, dataset): def get_batch_iterator( self, - dataset, + dataset: FairseqDataset, max_tokens=None, max_sentences=None, max_positions=None, @@ -223,6 +223,7 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + batch_by_size=True, ): """ Get an iterator that yields batches of data from the given dataset. @@ -255,6 +256,9 @@ def get_batch_iterator( disable_iterator_cache (bool, optional): don't cache the EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) (default: False). + batch_by_size (bool, optional): + batch sequences of similar length together to reduce padding. + If false, each batch will be of size max_sentences. 
Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split @@ -281,14 +285,18 @@ def get_batch_iterator( indices, dataset, max_positions, ignore_invalid_inputs ) - # create mini-batches with given size constraints - batch_sampler = dataset.batch_by_size( - indices, - max_tokens=max_tokens, - max_sentences=max_sentences, - required_batch_size_multiple=required_batch_size_multiple, - ) - + if batch_by_size: + # create mini-batches with given size constraints + batch_sampler = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + else: + assert max_sentences is not None, 'If batch_by_size=False, max_sentences must be passed. Got None' + starts = indices[::max_sentences] + batch_sampler = [indices[s: s + max_sentences] for s in starts] # return a reusable, sharded iterator epoch_iter = iterators.EpochBatchIterator( dataset=dataset, @@ -348,6 +356,7 @@ def build_generator( return SequenceScorer( self.target_dictionary, compute_alignment=getattr(args, "print_alignment", False), + compute_vocab_dist=getattr(args, "compute_vocab_dist", False) ) from fairseq.sequence_generator import ( diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 579bf69785..db71a10834 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -68,6 +68,9 @@ class LanguageModelingConfig(FairseqDataclass): add_bos_token: bool = field( default=False, metadata={"help": "prepend beginning of sentence token ()"} ) + max_source_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the source sequence"} + ) max_target_positions: Optional[int] = field( default=None, metadata={"help": "max number of tokens in the target sequence"} ) @@ -84,13 +87,24 @@ class LanguageModelingConfig(FairseqDataclass): 'e.g., "train,valid" (default: all dataset splits)' }, ) + 
pad_to_fixed_length: Optional[bool] = field( + default=False, metadata={"help": "pad to fixed length"}, + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, metadata={"help": "boolean to pad to fixed batch size"}, + ) + # TODO common vars below add to parent seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( "dataset.dataset_impl" ) data_buffer_size: int = II("dataset.data_buffer_size") tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") @register_task("language_modeling", dataclass=LanguageModelingConfig) @@ -179,7 +193,7 @@ def build_model(self, args): for target in self.targets: if target not in model.supported_targets: raise ValueError( - "Unsupported language modeling target: {}".format(target) + f"Unsupported language modeling target: {target} not in {model.supported_targets}" ) return model @@ -190,7 +204,7 @@ def load_dataset( """Load a given dataset split. 
Args: - split (str): name of the split (e.g., train, valid, test) + split (str): name of the split (e.g., train, valid, valid1, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 @@ -198,13 +212,12 @@ def load_dataset( data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) + # each process has its own copy of the raw data (likely to be an np.memmap) dataset = data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine ) if dataset is None: - raise FileNotFoundError( - "Dataset not found: {} ({})".format(split, split_path) - ) + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") dataset = maybe_shorten_dataset( dataset, @@ -214,7 +227,6 @@ def load_dataset( self.args.tokens_per_sample, self.args.seed, ) - dataset = TokenBlockDataset( dataset, dataset.sizes, @@ -223,12 +235,22 @@ def load_dataset( eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, ) add_eos_for_other_targets = ( self.args.sample_break_mode is not None and self.args.sample_break_mode != "none" ) + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = self.args.batch_size_valid if 'valid' in split else self.args.batch_size self.datasets[split] = MonolingualDataset( dataset=dataset, @@ -239,6 +261,8 @@ def load_dataset( shuffle=True, targets=self.targets, add_bos_token=self.args.add_bos_token, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, ) def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 4d47d39897..f42d31402e 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List import 
torch -from fairseq import checkpoint_utils, models, optim, utils +from fairseq import checkpoint_utils, moe_checkpoint_utils, models, optim, utils from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed import utils as distributed_utils @@ -26,7 +26,7 @@ from fairseq.optim import lr_scheduler from omegaconf import OmegaConf - +import re logger = logging.getLogger(__name__) @@ -63,7 +63,8 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): else: self.device = torch.device("cpu") - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.is_fsdp: + import fairscale if self.cfg.common.bf16: raise ValueError( "FullyShardedDataParallel is not compatible with --bf16 or " @@ -74,6 +75,16 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "FullyShardedDataParallel is not compatible with --zero-sharding " "option (it's already built in)" ) + if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0": + raise RuntimeError( + "Please update to fairscale 0.4.0 or newer when combining " + "--update-freq with FullyShardedDataParallel" + ) + if self.cfg.optimizer == 'adam8bit': + assert self.use_sharded_state, 'adam8bit + FSDP requires --use-sharded-state' + if self.use_sharded_state: + import fairscale + assert fairscale.__version__ >= '0.3.9', '--use-sharded-state requires newer fairscale. 
pip install -U fairscale' else: if self.cfg.distributed_training.cpu_offload: raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") @@ -81,7 +92,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.distributed_training.ddp_backend != "fully_sharded": + if not self.is_fsdp: if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() @@ -111,6 +122,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "detected shared parameter: {} <- {}".format(shared_param[0], path) ) _set_module_by_path(self._model, path, ref) + logger.info(metrics.get_nvidia_smi_gpu_memory_stats_str()) self._dummy_batch = None # indicates we don't have a dummy batch at first self._lr_scheduler = None @@ -188,14 +200,16 @@ def use_distributed_wrapper(self) -> bool: self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf ) or ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" + self.is_fsdp and self.cfg.distributed_training.cpu_offload ) @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + has_alt_ffn_dim = getattr(self.cfg.model, "alternate_decoder_ffn_embed_dim", 0) != 0 + if (self.is_fsdp or (self.is_moe and not has_alt_ffn_dim) or + getattr(self.cfg.model, "base_layers", 0) > 0): return True else: return self.is_data_parallel_master @@ -203,7 +217,9 @@ def should_save_checkpoint_on_current_rank(self) -> bool: @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint file name.""" - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if (self.is_moe or self.is_base_moe) and not self.use_sharded_state: + return self.cfg.checkpoint.checkpoint_suffix + elif 
self.is_fsdp: return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) else: return self.cfg.checkpoint.checkpoint_suffix or "" @@ -259,10 +275,7 @@ def _build_optimizer(self): ) ) - if ( - self.cfg.distributed_training.ddp_backend == "fully_sharded" - and self.cfg.common.fp16 - ): + if self.is_fsdp and self.cfg.common.fp16: # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, # mostly for the grad scaling. But if we don't have the # --memory-efficient-fp16 flag set, then we're effectively doing @@ -292,7 +305,7 @@ def _build_optimizer(self): logger.info("NOTE: your device may support faster training with --fp16") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) - if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.is_fsdp: assert not self.cfg.optimization.use_bmuf, \ "--ddp-backend=fully_sharded is not compatible with BMUF" assert self._optimizer.supports_flat_params, ( @@ -329,55 +342,141 @@ def _build_optimizer(self): ) self._lr_scheduler.step_update(0) + @property + def is_fsdp(self): + return self.cfg.distributed_training.ddp_backend == "fully_sharded" + + @property + def is_moe(self): + return getattr(self.cfg.model, "moe_freq", 0) > 0 + + @property + def is_base_moe(self) -> bool: + return getattr(self.cfg.model, "base_layers", 0) > 0 + def use_sharded_state(self): + return self.cfg.distributed_training.use_sharded_state + def consolidate_optimizer(self): """For OSS, we need to consolidate the state dict.""" + self._gathered_optim_state = None + if self.cfg.checkpoint.no_save_optimizer_state: + return if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + elif self.is_fsdp and not self.use_sharded_state: + st = self.model.gather_full_optim_state_dict(self.optimizer) # only returns on rank 0 + if st is None: + st = -1 # sentinel so that workers do not save optimizer.state_dict() + self._gathered_optim_state = 
st + assert self._gathered_optim_state is not None + + def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]: + if self.is_moe or self.is_base_moe: + ( + (shared_model_state_dict, shared_optimizer_state_dict), + (expert_model_state_dict, expert_optimizer_state_dict), + ) = moe_checkpoint_utils.split_shared_and_expert_states( + self.model, + self.optimizer, + ) + model_save_list = [( + filename, + expert_model_state_dict, + expert_optimizer_state_dict, + )] + if self.is_data_parallel_master: + if self.is_fsdp and not self.use_sharded_state: + assert self._gathered_optim_state is not None + if 'loss_scale' in expert_model_state_dict: + self._gathered_optim_state['loss_scale'] = expert_model_state_dict['loss_scale'] + model_save_list.append(( + filename.replace('rank-0', 'shared'), + shared_model_state_dict, + self._gathered_optim_state + )) + self._gathered_optim_state = None # let it get garbage collected + else: + model_save_list.append(( + filename.replace('rank-0', 'shared'), + shared_model_state_dict, + shared_optimizer_state_dict, # local + )) + elif self.is_fsdp and self.use_sharded_state: + model_save_list.append(( + filename.replace(f'rank-{self.data_parallel_rank}', 'shared'), + shared_model_state_dict, + shared_optimizer_state_dict, + )) - def state_dict(self): - state_dict = { - "args": None, # legacy - "cfg": ( - OmegaConf.to_container(self.cfg) - if OmegaConf.is_config(self.cfg) else self.cfg - ), - "model": self.model.state_dict(), - "criterion": ( - self.criterion.state_dict() - if utils.has_parameters(self.criterion) else None - ), - "optimizer_history": (self._optim_history or []) - + [ - { - "criterion_name": self.get_criterion().__class__.__name__, - "optimizer_name": self.optimizer.__class__.__name__, - "lr_scheduler_state": self.lr_scheduler.state_dict(), - "num_updates": self.get_num_updates(), + else: + model_state_dict = self.model.state_dict() + optim_state = None + if not self.cfg.checkpoint.no_save_optimizer_state: + 
optim_state = self._gathered_optim_state or self.optimizer.state_dict() + model_save_list = [( + filename, + model_state_dict, + optim_state, + )] + state_dicts = {} + for filename, model_state_dict, optimizer_state_dict in model_save_list: + state_dict = { + "args": None, # legacy + "cfg": OmegaConf.to_container(self.cfg) if OmegaConf.is_config(self.cfg) else self.cfg, + "model": model_state_dict, + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), } - ], - "task_state": self.task.state_dict() if self.task is not None else {}, - "extra_state": { - "metrics": metrics.state_dict(), - "previous_training_time": self.cumulative_training_time(), } - } - if not self.cfg.checkpoint.no_save_optimizer_state: - state_dict["last_optimizer_state"] = self.optimizer.state_dict() - return state_dict + if ( + not self.cfg.checkpoint.no_save_optimizer_state or + ( + self.cfg.checkpoint.no_save_optimizer_state_on_training_finished + and training_finished + ) + ): + state_dict["last_optimizer_state"] = optimizer_state_dict + + if self.is_fsdp and self.use_sharded_state: + state_dict['shard_metadata'] = self.model.local_metadata_dict() # save FSDP flattening and padding info + state_dicts[filename] = state_dict + return state_dicts - def save_checkpoint(self, filename, extra_state): + def save_checkpoint(self, filename, extra_state, training_finished=False, async_callback_fn=None): """Save all training state in a checkpoint file.""" - logger.info(f"Saving checkpoint to 
{filename}") # call state_dict on all ranks in case it needs internal communication - state_dict = utils.move_to_cpu(self.state_dict()) - state_dict["extra_state"].update(extra_state) - if self.should_save_checkpoint_on_current_rank: - checkpoint_utils.torch_persistent_save( + state_dicts = self.state_dict(filename, training_finished) + for filename, state_dict in state_dicts.items(): + logger.info(f"Saving checkpoint to {filename}") + state_dict = utils.move_to_cpu( state_dict, - filename, - async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, + # keep params in FP16 when training with --memory-efficient-fp16 + cast_to_fp32=not self.cfg.common.memory_efficient_fp16, ) - logger.info(f"Finished saving checkpoint to {filename}") + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.torch_persistent_save( + state_dict, + filename, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, + async_callback_fn=async_callback_fn, + ) + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, @@ -394,28 +493,37 @@ def load_checkpoint( """ extra_state, self._optim_history, last_optim_state = None, [], None - logger.info(f"Preparing to load checkpoint {filename}") + is_distributed = self.data_parallel_world_size > 1 bexists = PathManager.isfile(filename) if bexists: + logger.info(f"Preparing to load checkpoint {filename}") load_on_all_ranks = ( self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu # FSDP requires loading checkpoint shards on all ranks - or self.cfg.distributed_training.ddp_backend == "fully_sharded" + or self.is_fsdp + or getattr(self.cfg.model, "base_layers", 0) > 0 ) - if load_on_all_ranks or self.data_parallel_rank == 0: + if load_on_all_ranks or self.is_data_parallel_master or self.is_moe or self.is_base_moe: state = 
checkpoint_utils.load_checkpoint_to_cpu( - filename, load_on_all_ranks=load_on_all_ranks + filename, + load_on_all_ranks=load_on_all_ranks, + is_moe=self.is_moe or self.is_base_moe, ) last_optim_state = state.get("last_optimizer_state", None) + if last_optim_state == -1: + master_path = re.sub("shard[0-9]+", "shard0", filename) + last_optim_state = torch.load(master_path, map_location='cpu')['last_optimizer_state'] + + logger.info(f"Loaded state for {filename}") # If doing zero_sharding, do not broadcast global optimizer # state. Later we will broadcast sharded states to each rank - # to avoid memory from exploding. + # to avoid memory exploding. if ( not load_on_all_ranks and self.cfg.distributed_training.zero_sharding == "os" @@ -427,7 +535,14 @@ def load_checkpoint( last_optim_state = None state = None - if is_distributed and not load_on_all_ranks: + if ( + self.data_parallel_world_size > 1 + and not load_on_all_ranks + # disable on TPUs until they support broadcast + and not self.tpu + and not self.is_moe + and not self.is_base_moe + ): state = distributed_utils.broadcast_object( state, src_rank=0, @@ -442,10 +557,14 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) + # save memory for later steps + del state["model"] if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( state["criterion"], strict=True ) + del state["criterion"] + except Exception: raise Exception( "Cannot load model parameters from checkpoint {}; " @@ -474,8 +593,12 @@ def load_checkpoint( last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) - self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + elif self.is_fsdp and not self.use_sharded_state: + last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state) + logger.info(f"FSDP got shard from optim_state for {filename}") + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) 
+ logger.info(f"Loaded optim_state for {filename}") self.set_num_updates(last_optim["num_updates"]) if extra_state is not None: @@ -500,12 +623,7 @@ def load_checkpoint( if isinstance(meter, meters.TimeMeter): meter.reset() - logger.info( - "Loaded checkpoint {} (epoch {} @ {} updates)".format( - filename, epoch, self.get_num_updates() - ) - ) - + logger.info(f"Loaded checkpoint {filename} (epoch {epoch} @ {self.get_num_updates()} updates)") else: logger.info("No existing checkpoint found {}".format(filename)) @@ -570,7 +688,7 @@ def get_valid_iterator( seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.cfg.dataset.num_workers, + num_workers=self.cfg.dataset.num_workers_valid, # always pass a fixed "epoch" to keep validation data consistent # across training epochs epoch=1, @@ -622,6 +740,23 @@ def train_step(self, samples, raise_oom=False): for i, sample in enumerate(samples): # delayed update loop sample, is_dummy_batch = self._prepare_sample(sample) + # MoE training with --batch-size or --max-sentences set + if self.is_moe and getattr(self.cfg.dataset, 'batch_size', None) is not None: + try: + fixed_src_seq_length = getattr(self.cfg.task, 'tokens_per_sample', None) or self.cfg.task.max_source_positions + assert sample['net_input']['src_tokens'].shape[1] == fixed_src_seq_length + except: + logger.warning(str(sample.keys())) + logger.warning(str(sample['net_input'].keys())) + logger.warning(is_dummy_batch) + logger.warning( + "wrong seq len {} on rank {}".format( + sample['net_input']['src_tokens'].shape[1], + torch.distributed.get_rank(), + ) + ) + raise AssertionError + def maybe_no_sync(): """ Whenever *samples* contains more than one mini-batch, we @@ -632,6 +767,11 @@ def maybe_no_sync(): self.data_parallel_world_size > 1 and hasattr(self.model, "no_sync") and i < len(samples) - 1 + # The no_sync context manager results in increased memory + # usage with FSDP, since full-size gradients 
will be + # accumulated on each GPU. It's typically a better tradeoff + # to do the extra communication with FSDP. + and not self.is_fsdp ): return self.model.no_sync() else: @@ -710,6 +850,7 @@ def maybe_no_sync(): ) overflow = False + logger.debug(f"[{self.get_num_updates()}] done with fwd, bwd") try: with torch.autograd.profiler.record_function("reduce-grads"): # reduce gradients across workers @@ -746,6 +887,8 @@ def maybe_no_sync(): if ( not self.cfg.optimization.use_bmuf and self.cfg.distributed_training.ddp_backend != "slow_mo" + and not self.is_moe + and not self.is_base_moe ): self._check_grad_norms(grad_norm) if not torch.isfinite(grad_norm).all(): @@ -757,6 +900,7 @@ def maybe_no_sync(): self.task.optimizer_step( self.optimizer, model=self.model, update_num=self.get_num_updates() ) + logger.debug(f"[{self.get_num_updates()}] done with optimizer step") except FloatingPointError: # re-run the forward and backward pass with hooks attached to print @@ -833,12 +977,7 @@ def maybe_no_sync(): else: if self.cuda and self.cuda_env is not None: # log minimum free memory over the iteration - gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 - torch.cuda.reset_peak_memory_stats() - gb_free = self.cuda_env.total_memory_in_GB - gb_used - metrics.log_scalar( - "gb_free", gb_free, priority=1500, round=1, weight=0 - ) + self._log_gpu_mem_stats() # log stats logging_output = self._reduce_and_log_stats( @@ -885,6 +1024,13 @@ def valid_step(self, sample, raise_oom=False): sample, is_dummy_batch = self._prepare_sample(sample) try: + if (getattr(self.cfg.model, 'moe_freq', 0) > 0 and + getattr(self.cfg.dataset, 'batch_size', None) is not None): + fixed_src_seq_length = getattr(self.cfg.task, 'tokens_per_sample', None) or \ + self.cfg.task.max_source_positions + assert sample['net_input']['src_tokens'].shape[1] == fixed_src_seq_length, \ + f"got src_seq_length {sample['net_input']['src_tokens'].shape[1]}, " + \ + f"expected {fixed_src_seq_length}" _loss, 
sample_size, logging_output = self.task.valid_step( sample, self.model, self.criterion ) @@ -920,7 +1066,6 @@ def valid_step(self, sample, raise_oom=False): # log validation stats logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) - return logging_output def zero_grad(self): @@ -1015,21 +1160,7 @@ def set_num_updates(self, num_updates): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - - def agg_norm_fn(total_norm): - if self.cfg.distributed_training.ddp_backend == "fully_sharded": - total_norm = total_norm ** 2 - if ( - self.data_parallel_process_group is not None - or torch.distributed.is_initialized() - ): - total_norm = distributed_utils.all_reduce( - total_norm.cuda(), group=self.data_parallel_process_group - ) - total_norm = total_norm ** 0.5 - return total_norm - - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) + return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) def cumulative_training_time(self): if self._cumulative_training_time is None: @@ -1162,10 +1293,7 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. 
fast_stat_sync_sum is @@ -1278,6 +1406,7 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: if key_to_delete in logging_output: del logging_output[key_to_delete] + return logging_output def _check_xla_compilation(self): @@ -1296,6 +1425,27 @@ def _check_xla_compilation(self): ) self._num_xla_compiles = num_xla_compiles + def _log_gpu_mem_stats(self): + # log minimum free memory over the iteration + cuda_gb_allocated = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024 + torch.cuda.reset_peak_memory_stats() + cuda_gb_free = self.cuda_env.total_memory_in_GB - cuda_gb_allocated + metrics.log_scalar( + "cuda_gb_allocated", cuda_gb_allocated, priority=1500, round=1, weight=0 + ) + metrics.log_scalar( + "cuda_gb_reserved", cuda_gb_reserved, priority=1500, round=1, weight=0 + ) + metrics.log_scalar( + "cuda_gb_free", cuda_gb_free, priority=1500, round=1, weight=0 + ) + # log nvidia smi stats + if self.cfg.common.log_nvidia_smi: + nvidia_smi_stats = metrics.nvidia_smi_gpu_memory_stats() + for key, val in nvidia_smi_stats.items(): + metrics.log_scalar(key, val, priority=1500, round=1, weight=0) + def _catalog_shared_params(module, memo=None, prefix=""): if memo is None: diff --git a/fairseq/utils.py b/fairseq/utils.py index d4bf73648b..a321c75129 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -15,9 +15,10 @@ from itertools import accumulate from typing import Callable, Dict, List, Optional +import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F -from fairseq.modules.multihead_attention import MultiheadAttention from torch import Tensor @@ -109,11 +110,11 @@ def _move_to_cuda(tensor): return apply_to_sample(_move_to_cuda, sample) -def move_to_cpu(sample): +def move_to_cpu(sample, cast_to_fp32=True): def _move_to_cpu(tensor): # PyTorch has poor support for half 
tensors (float16) on CPU. # Move any such tensors to float32. - if tensor.dtype in {torch.bfloat16, torch.float16}: + if cast_to_fp32 and tensor.dtype in {torch.bfloat16, torch.float16}: tensor = tensor.to(dtype=torch.float32) return tensor.cpu() @@ -121,7 +122,7 @@ def _move_to_cpu(tensor): def get_incremental_state( - module: MultiheadAttention, + module, # type: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, ) -> Optional[Dict[str, Optional[Tensor]]]: @@ -130,7 +131,7 @@ def get_incremental_state( def set_incremental_state( - module: MultiheadAttention, + module, # type: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, value: Dict[str, Optional[Tensor]], @@ -248,6 +249,9 @@ def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx +def assert_equal(a, b, msg=''): + assert a == b, f"{msg}{a} != {b}" + def strip_pad(tensor, pad): return tensor[tensor.ne(pad)] @@ -324,17 +328,28 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: @torch.no_grad() def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: + def grad_exists(p): + return p is not None and getattr(p, "grad", None) is not None if isinstance(params, torch.Tensor): params = [params] params = list(params) - grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] + params = list(filter(grad_exists, params)) + grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], [] + for p in params: + if hasattr(p, "expert"): + expert_grads.append(p.grad.detach()) + elif hasattr(p, "base_expert"): + base_expert_grads.append(p.grad.detach()) + elif hasattr(p, "_is_sharded"): + sharded_grads.append(p.grad.detach()) + else: + grads.append(p.grad.detach()) if len(grads) == 0: if len(params) > 0: - return params[0].new_tensor(0.0) + total_norm = 
params[0].new_tensor(0.0) else: - return torch.tensor(0.0) - - if len(grads) == 1: + total_norm = torch.tensor(0.0) + elif len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: if multi_tensor_l2norm_available: @@ -356,13 +371,27 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: ) ) + # calculate split_norm and all_reduce with other workers + norms = [total_norm] + for split_grads in [expert_grads, sharded_grads]: + if len(split_grads) == 0: + continue + split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])) + if dist.is_initialized(): + split_norm.pow_(2) + dist.all_reduce(split_norm) + split_norm.sqrt_() + norms.append(split_norm) + if len(norms) > 1: + total_norm = torch.norm(torch.stack(norms)) + if aggregate_norm_fn is not None: total_norm = aggregate_norm_fn(total_norm) if max_norm > 0: max_norm = float(max_norm) clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) - for g in grads: + for g in grads + expert_grads + sharded_grads + base_expert_grads: g.mul_(clip_coef) return total_norm @@ -738,3 +767,23 @@ def eval_bool(x, default=False): return bool(eval(x)) except TypeError: return default + + +def print_r0(x, file=None): + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + print(x, file=file, flush=True) + + +def round_safe(x): + if torch.is_tensor(x): + return float(np.round(x.cpu().numpy(), 4)) + else: + try: + return round(x, 4) + except Exception: + return x + +def print_mem(msg): + gb_denom = 1024**3 + mem_info = f'max_GB: {torch.cuda.max_memory_allocated()/gb_denom:.1f}, current_GB: {torch.cuda.memory_allocated()/gb_denom:.1f}' + print_r0(f'{msg}: {mem_info}') diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 4501cac67e..a0420a39f4 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -9,20 +9,25 @@ """ import logging +import json import math import os import sys from argparse 
import Namespace +from textwrap import indent from typing import Iterable, List, Optional import torch import fairseq +from fairseq.file_io import save_json +from fairseq.utils import round_safe from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter from fairseq.sequence_scorer import SequenceScorer from omegaconf import DictConfig +import time logging.basicConfig( @@ -45,6 +50,7 @@ def eval_lm( softmax_batch: int = False, remove_bos_token: bool = False, device: Optional[torch.device] = None, + max_valid_steps=None, ): """ Args: @@ -74,6 +80,7 @@ def eval_lm( device (Optional[torch.device]): device to use for evaluation (defaults to device of first model parameter) """ + start_time =time.time() if target_dictionary is None: target_dictionary = source_dictionary if device is None: @@ -103,10 +110,21 @@ def eval_lm( bpe_len = 0 word_stats = dict() - - for sample in batch_iterator: + first_batch = None + + for i, sample in enumerate(batch_iterator): + if max_valid_steps is not None and i > max_valid_steps: + break + is_dummy_batch = False + if not first_batch and "net_input" in sample: + first_batch = sample if "net_input" not in sample: - continue + if first_batch: + logger.warning("Adding a dummy batch") + sample = first_batch + is_dummy_batch = True + else: + continue sample = utils.move_to_cuda(sample, device=device) @@ -114,6 +132,10 @@ def eval_lm( hypos = scorer.generate(models, sample) gen_timer.stop(sample["ntokens"]) + # Don't calculate score for dummy batch + if is_dummy_batch: + continue + for i, hypos_i in enumerate(hypos): hypo = hypos_i[0] sample_id = sample["id"][i] @@ -182,15 +204,11 @@ def eval_lm( ) ) - avg_nll_loss = ( - -score_sum / count / math.log(2) if count > 0 else 0 - ) # convert to base 2 - logger.info( - "Evaluated {:,} tokens in {:.1f}s ({:.2f} 
tokens/s)".format( - gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 - ) - ) - + avg_nll_loss = get_aggregated_loss(score_sum, count) # convert to base 2 + tokens, gpu_seconds_taken, avg_time = get_aggregated_timer_stats(gen_timer) + end_time = time.time() + tot_time = end_time-start_time + logger.info(f"Evaluated {tokens:,} tokens in {tot_time:.1f}s ({tokens / tot_time:.2f} tokens/s)") if output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): logger.info(ws) @@ -198,9 +216,40 @@ def eval_lm( return { "loss": avg_nll_loss, "perplexity": 2 ** avg_nll_loss, + "r0_tps_step": 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0, + "ntok_total": tokens, + "gpu_step_seconds": gpu_seconds_taken, } +def _all_reduce_float(x): + if not torch.is_tensor(x): + x = torch.tensor(x) + x_tensor = x.cuda() + torch.distributed.all_reduce(x_tensor) + return x_tensor.item() + + +def get_aggregated_loss(score_sum, count): + if torch.distributed.is_initialized(): + logger.warning("Aggregating scores across the distributed world") + count = _all_reduce_float(count) + score_sum = _all_reduce_float(score_sum) + return ( + -score_sum / count / math.log(2) if count > 0 else 0 + ) + + +def get_aggregated_timer_stats(gen_timer): + tokens, time_taken, avg_time = gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 + if torch.distributed.is_initialized(): + logger.warning("Aggregating timer stats across the distributed world") + tokens = _all_reduce_float(tokens) + time_taken = _all_reduce_float(time_taken) + avg_time = _all_reduce_float(avg_time) / torch.distributed.get_world_size() + return tokens, time_taken, avg_time + + class WordStat(object): def __init__(self, word, is_bpe): self.word = word @@ -233,30 +282,113 @@ def __str__(self): ) +def eval_dataset(cfg: DictConfig, eval_split, task, models, start_time): + dataset = task.dataset(eval_split) + logger.info(f"{cfg.task.data} {eval_split} {len(dataset):,} 
examples") + num_shards = max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ) + shard_id = max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ) + itr = task.eval_lm_dataloader( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens or 36000, + batch_size=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), + num_shards=num_shards, + shard_id=shard_id, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + context_window=cfg.eval_lm.context_window, + ) + itr = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + load_time = time.time() - start_time + logger.info(f'load time: {load_time:.2f} seconds') + results = eval_lm( + models=models, + source_dictionary=task.source_dictionary, + batch_iterator=itr, + post_process=cfg.common_eval.post_process, + output_word_probs=cfg.eval_lm.output_word_probs, + output_word_stats=cfg.eval_lm.output_word_stats, + target_dictionary=task.target_dictionary, + softmax_batch=cfg.eval_lm.softmax_batch, + remove_bos_token=getattr(cfg.task, "add_bos_token", False), + max_valid_steps=cfg.eval_lm.max_valid_steps, + ) + + end_time = time.time() + total_time = end_time - start_time + logger.info( + "{} Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( + eval_split, results["loss"], results["perplexity"] + ) + ) + + if isinstance(cfg.eval_lm.stats_path, str): + rr = {k: round_safe(v) for k,v in results.items()} + rr['wall_time'] = round_safe(total_time) + rr['wall_time_load'] = round_safe(load_time) + rr['wall_time_model'] = round_safe(total_time - load_time) + else: + rr = None + + return results, rr, end_time + + def main(cfg: DictConfig, **unused_kwargs): + start_time = time.time() if isinstance(cfg, Namespace): cfg = 
convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) - logger.info(cfg) + logger.info('---------------------------') + logger.info('cfg:') + config_content = getattr(cfg, "_content") + for config_key in config_content: + logger.info(config_key + '\t' + str(config_content[config_key])) + logger.info('---------------------------') if cfg.eval_lm.context_window > 0: # reduce tokens per sample by the required context window size cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + model_overrides = eval(cfg.common_eval.model_overrides) + is_base_moe = model_overrides.get('is_base_moe', False) + if cfg.common_eval.is_moe or is_base_moe: + rank = distributed_utils.get_data_parallel_rank() + cfg.checkpoint.checkpoint_suffix = f"-rank-{rank}" + is_moe = True + # This is required for making all_to_all work on same sized tensors across gpus. + cfg['task']['pad_to_fixed_length'] = True + else: + is_moe = False + # Initialize the task using the current *cfg* task = tasks.setup_task(cfg.task) # Load ensemble - logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + model_overrides['batch_size_valid'] = cfg.dataset.batch_size models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( - [cfg.common_eval.path], - arg_overrides=eval(cfg.common_eval.model_overrides), + utils.split_paths(cfg.common_eval.path), + arg_overrides=model_overrides, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, task=task, + is_moe=is_moe or is_base_moe, ) use_fp16 = cfg.common.fp16 @@ -271,7 +403,10 @@ def main(cfg: DictConfig, **unused_kwargs): model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(cfg) + + if is_moe: + # For moe models, we want to enable padding in moe layer, so not calling this. 
+ model.prepare_for_inference_(cfg) assert len(models) > 0 @@ -281,57 +416,23 @@ def main(cfg: DictConfig, **unused_kwargs): # Load dataset splits task.load_dataset(cfg.dataset.gen_subset) - dataset = task.dataset(cfg.dataset.gen_subset) - logger.info( - "{} {} {:,} examples".format( - cfg.task.data, cfg.dataset.gen_subset, len(dataset) - ) - ) - - itr = task.eval_lm_dataloader( - dataset=dataset, - max_tokens=cfg.dataset.max_tokens or 36000, - batch_size=cfg.dataset.batch_size, - max_positions=utils.resolve_max_positions( - *[model.max_positions() for model in models] - ), - num_shards=max( - cfg.dataset.num_shards, - cfg.distributed_training.distributed_world_size, - ), - shard_id=max( - cfg.dataset.shard_id, - cfg.distributed_training.distributed_rank, - ), - num_workers=cfg.dataset.num_workers, - data_buffer_size=cfg.dataset.data_buffer_size, - context_window=cfg.eval_lm.context_window, - ) - - itr = progress_bar.progress_bar( - itr, - log_format=cfg.common.log_format, - log_interval=cfg.common.log_interval, - default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), - ) - - results = eval_lm( - models=models, - source_dictionary=task.source_dictionary, - batch_iterator=itr, - post_process=cfg.common_eval.post_process, - output_word_probs=cfg.eval_lm.output_word_probs, - output_word_stats=cfg.eval_lm.output_word_stats, - target_dictionary=task.target_dictionary, - softmax_batch=cfg.eval_lm.softmax_batch, - remove_bos_token=getattr(cfg.task, "add_bos_token", False), - ) - - logger.info( - "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( - results["loss"], results["perplexity"] - ) - ) + eval_splits = [cfg.dataset.gen_subset] + if cfg.task._name == 'multilingual_language_modeling': + languages = cfg.task.langs.split(',') + for lang in languages: + eval_splits.append(f'{cfg.dataset.gen_subset}_{lang}') + + all_split_results = dict() + for eval_split in eval_splits: + results, rr, end_time = eval_dataset(cfg, eval_split, task, models, 
start_time) + start_time = end_time + all_split_results[eval_split] = rr + + if isinstance(cfg.eval_lm.stats_path, str): + save_path = f'{cfg.eval_lm.stats_path}.json' + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + save_json(all_split_results, save_path) + logger.info('Evaluation results saved to {}'.format(save_path)) return results @@ -344,4 +445,4 @@ def cli_main(): if __name__ == "__main__": - cli_main() + cli_main() \ No newline at end of file diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 7bd582b256..a55fd83aea 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -17,10 +17,11 @@ import numpy as np import torch -from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq import checkpoint_utils, options, scoring, tasks, utils, distributed_utils from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.utils import print_r0 from omegaconf import DictConfig @@ -93,6 +94,11 @@ def _main(cfg: DictConfig, output_file): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + if cfg.common_eval.is_moe and torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + cfg.checkpoint.checkpoint_suffix = f"-rank-{torch.distributed.get_rank()}" + moe_freq = 1 + else: + moe_freq = 0 models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, @@ -100,6 +106,7 @@ def _main(cfg: DictConfig, output_file): suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, + is_moe=moe_freq > 0, ) # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config @@ -138,6 +145,12 @@ def _main(cfg: DictConfig, output_file): 
align_dict = utils.load_align_dict(cfg.generation.replace_unk) # Load dataset (possibly sharded) + num_shards = cfg.distributed_training.distributed_world_size + shard_id = cfg.distributed_training.distributed_rank + # We need all GPUs to process the same batch + if cfg.common_eval.is_moe: + num_shards = 1 + shard_id = 0 itr = task.get_batch_iterator( dataset=task.dataset(cfg.dataset.gen_subset), max_tokens=cfg.dataset.max_tokens, @@ -148,8 +161,8 @@ def _main(cfg: DictConfig, output_file): ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, seed=cfg.common.seed, - num_shards=cfg.distributed_training.distributed_world_size, - shard_id=cfg.distributed_training.distributed_rank, + num_shards=num_shards, + shard_id=shard_id, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) @@ -254,9 +267,9 @@ def decode_fn(x): if not cfg.common_eval.quiet: if src_dict is not None: - print("S-{}\t{}".format(sample_id, src_str), file=output_file) + print_r0("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: - print("T-{}\t{}".format(sample_id, target_str), file=output_file) + print_r0("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): @@ -273,16 +286,16 @@ def decode_fn(x): if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) - print( + print_r0( "H-{}\t{}\t{}".format(sample_id, score, hypo_str), file=output_file, ) # detokenized hypothesis - print( + print_r0( "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), file=output_file, ) - print( + print_r0( "P-{}\t{}".format( sample_id, " ".join( @@ -299,7 +312,7 @@ def decode_fn(x): ) if cfg.generation.print_alignment == "hard": - print( + print_r0( "A-{}\t{}".format( sample_id, " 
".join( @@ -312,7 +325,7 @@ def decode_fn(x): file=output_file, ) if cfg.generation.print_alignment == "soft": - print( + print_r0( "A-{}\t{}".format( sample_id, " ".join( @@ -326,7 +339,7 @@ def decode_fn(x): ) if cfg.generation.print_step: - print( + print_r0( "I-{}\t{}".format(sample_id, hypo["steps"]), file=output_file, ) @@ -341,7 +354,7 @@ def decode_fn(x): tgt_dict=tgt_dict, remove_bpe=None, ) - print( + print_r0( "E-{}_{}\t{}".format(sample_id, step, h_str), file=output_file, ) @@ -388,7 +401,7 @@ def decode_fn(x): "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" ) # use print to be consistent with other main outputs: S-, H-, T-, D- and so on - print( + print_r0( "Generate {} with beam={}: {}".format( cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() ), @@ -398,10 +411,13 @@ def decode_fn(x): return scorer + + def cli_main(): parser = options.get_generation_parser() args = options.parse_args_and_arch(parser) - main(args) + cfg = convert_namespace_to_omegaconf(args) + distributed_utils.call_main(cfg, main) if __name__ == "__main__": diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index d770e4e4ec..248f7cbed1 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -16,6 +16,8 @@ import numpy as np import torch +import torch.distributed as dist +import functools from fairseq import ( checkpoint_utils, options, @@ -23,7 +25,8 @@ tasks, utils, ) -from fairseq.data import iterators +from fairseq.data import iterators, data_utils +from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils @@ -58,11 +61,17 @@ def main(cfg: FairseqConfig) -> None: ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() 
+ if cfg.common.log_file is not None: + handler = logging.FileHandler(filename=cfg.common.log_file) + logger.addHandler(handler) + np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(cfg.distributed_training): - checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + + # Print nvidia smi stats + logger.info(metrics.get_nvidia_smi_gpu_memory_stats_str()) # Print args logger.info(cfg) @@ -79,29 +88,54 @@ def main(cfg: FairseqConfig) -> None: # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) - # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in cfg.dataset.valid_subset.split(","): - task.load_dataset(valid_sub_split, combine=False, epoch=1) assert cfg.criterion, "Please specify criterion to train a model" + if getattr(cfg.model, "moe_freq", 0) > 0 and getattr(cfg.model, "moe_expert_count", 0) < distributed_utils.get_global_world_size(): + assert cfg.distributed_training.ddp_backend == 'fully_sharded', 'num_experts < num_gpus only supported by FSDP' # Build model and criterion if cfg.distributed_training.ddp_backend == "fully_sharded": - with fsdp_enable_wrap(cfg.distributed_training): + #if cfg.distributed_training.use_sharded_state: assert cfg.checkpoint.no_save_optimizer_state, f'--use-sharded-state requires --no-save-optimizer-state' + extra = { + "is_moe": getattr(cfg.model, "moe_freq", 0) > 0, + "use_sharded_state": cfg.distributed_training.use_sharded_state, + } + + with fsdp_enable_wrap(cfg.distributed_training, **extra): model = fsdp_wrap(task.build_model(cfg.model)) else: model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) + + def is_expert_param(p): + return getattr(p, "expert", False) or getattr(p, "base_expert", False) + logger.info(model) logger.info("task: 
{}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( - "num. model params: {:,} (num. trained: {:,})".format( - sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), - sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), + "num. non-expert model params: {:,} (num. trained: {:,})".format( + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if not is_expert_param(p)), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if not is_expert_param(p) and p.requires_grad), + ) + ) + logger.info( + "num. expert model params: {:,} (num. trained: {:,})".format( + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if is_expert_param(p)), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if is_expert_param(p) and p.requires_grad), ) ) + logger.info(metrics.get_nvidia_smi_gpu_memory_stats_str()) + + # Load valid dataset (we load training data below, based on the latest checkpoint) + # We load the valid dataset AFTER building the model + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + if cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: @@ -118,7 +152,6 @@ def main(cfg: FairseqConfig) -> None: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( "training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size @@ -130,6 +163,7 @@ def main(cfg: FairseqConfig) -> None: cfg.dataset.batch_size, ) ) + logger.info(metrics.get_nvidia_smi_gpu_memory_stats_str()) # Load 
the latest checkpoint if one is available and restore the # corresponding train iterator @@ -230,6 +264,7 @@ def train( progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, + log_file=cfg.common.log_file, log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( @@ -255,8 +290,10 @@ def train( progress.update_config(_flatten_config(cfg)) trainer.begin_epoch(epoch_itr.epoch) - - valid_subsets = cfg.dataset.valid_subset.split(",") + if cfg.task._name in ["multilingual_language_modeling", "translation_multi_simple_epoch"]: + valid_subsets = task.args.valid_subset.split(",") + else: + valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() logger.info("Start iterating over samples") @@ -361,8 +398,6 @@ def validate_and_save( and num_updates % cfg.dataset.validate_interval_updates == 0 ) ) and not cfg.dataset.disable_validation - - # Validate valid_losses = [None] if do_validate: valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) @@ -372,11 +407,25 @@ def validate_and_save( # Save checkpoint if do_save or should_stop: checkpoint_utils.save_checkpoint( - cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + cfg.checkpoint, trainer, epoch_itr, valid_losses[0], training_finished=should_stop, + async_callback_fn=functools.partial(post_checkpoint_callback, cfg) if cfg.checkpoint.s3_upload_path else None, ) + trainer.reset_dummy_batch(epoch_itr.first_batch) return valid_losses, should_stop +def post_checkpoint_callback(cfg, filename): + if cfg.checkpoint.s3_upload_path is not None: + try: + # PathManager only supports writing to S3, but this function call + # can be replaced with other APIs for copying checkpoints. 
+ PathManager.copy_from_local( + filename, + os.path.join(cfg.checkpoint.s3_upload_path, os.path.basename(filename)), + overwrite=True, + ) + except (FileNotFoundError, AssertionError) as e: + logger.info(f'could not upload {filename}: {e}') def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) @@ -399,7 +448,8 @@ def validate( trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: - logger.info('begin validation on "{}" subset'.format(subset)) + logger.info('begin validation on "{}" subset on rank {}'.format( + subset, distributed_utils.get_global_rank())) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr( @@ -407,6 +457,13 @@ def validate( ) if cfg.common.tpu: itr = utils.tpu_data_loader(itr) + + logger.info('got valid iterator on "{}" subset on rank {}'.format( + subset, + distributed_utils.get_global_rank() + ) + ) + progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, @@ -429,16 +486,18 @@ def validate( ), ) + logger.info('Begin looping over validation "{}" subset with length "{}"'.format(subset, len(progress))) + # create a new root metrics aggregator so validation metrics # don't pollute other aggregators (e.g., train meters) with metrics.aggregate(new_root=True) as agg: - for sample in progress: + for i, sample in enumerate(progress): + if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps: + break trainer.valid_step(sample) - # log validation stats stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) return valid_losses @@ -465,6 +524,10 @@ def cli_main( cfg = convert_namespace_to_omegaconf(args) + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info(f"Started plasma server 
pid {server.server.pid} {cfg.common.plasma_path}") + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -472,6 +535,9 @@ def cli_main( else: distributed_utils.call_main(cfg, main) + # if cfg.common.use_plasma_view: + # server.server.kill() + if __name__ == "__main__": cli_main() From 1ef16122f039e45bf69098c1b2eebd27e6381db1 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 6 Sep 2021 07:09:13 -0700 Subject: [PATCH 02/63] Update README.moe.md with reference speeds --- README.md | 217 +------------------------------------------------- README.moe.md | 16 +++- 2 files changed, 16 insertions(+), 217 deletions(-) mode change 100644 => 120000 README.md diff --git a/README.md b/README.md deleted file mode 100644 index 5fedac7eec..0000000000 --- a/README.md +++ /dev/null @@ -1,216 +0,0 @@ -

- -
-
- MIT License - Latest Release - Build Status - Documentation Status -

- --------------------------------------------------------------------------------- - -Fairseq(-py) is a sequence modeling toolkit that allows researchers and -developers to train custom models for translation, summarization, language -modeling and other text generation tasks. - -We provide reference implementations of various sequence modeling papers: - -
List of implemented papers

- -* **Convolutional Neural Networks (CNN)** - + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) - + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* **LightConv and DynamicConv models** - + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* **Long Short-Term Memory (LSTM) networks** - + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) -* **Transformer (self-attention) networks** - + Attention Is All You Need (Vaswani et al., 2017) - + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) - + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) - + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - + 
[Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) -* **Non-autoregressive Transformers** - + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) - + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) - + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* **Finetuning** - + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) - -

- -### What's New: - -* December 2020: [GottBERT model and code released](examples/gottbert/README.md) -* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework - * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) -* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) -* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) -* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) -* October 2020: [Added CRISS models and code](examples/criss/README.md) -* September 2020: [Added Linformer code](examples/linformer/README.md) -* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) -* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) -* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) -* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - -
Previous updates

- -* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) -* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) -* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -* February 2020: [mBART model and code released](examples/mbart/README.md) -* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) -* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -* November 2019: [CamemBERT model and code released](examples/camembert/README.md) -* November 2019: [BART model and code released](examples/bart/README.md) -* November 2019: [XLM-R models and code released](examples/xlmr/README.md) -* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -* August 2019: [WMT'19 models released](examples/wmt19/README.md) -* July 2019: fairseq relicensed under MIT license -* July 2019: [RoBERTa models and code released](examples/roberta/README.md) -* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) - -

- -### Features: - -* multi-GPU training on one machine or across multiple machines (data and model parallel) -* fast generation on both CPU and GPU with multiple search algorithms implemented: - + beam search - + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - + sampling (unconstrained, top-k and top-p/nucleus) - + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) -* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU -* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers -* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration - -We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) -with a convenient `torch.hub` interface: - -``` python -en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') -en2de.translate('Hello world', beam=5) -# 'Hallo Welt' -``` - -See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) -and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. 
- -# Requirements and Installation - -* [PyTorch](http://pytorch.org/) version >= 1.5.0 -* Python version >= 3.6 -* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) -* **To install fairseq** and develop locally: - -``` bash -git clone https://github.com/pytorch/fairseq -cd fairseq -pip install --editable ./ - -# on MacOS: -# CFLAGS="-stdlib=libc++" pip install --editable ./ - -# to install the latest stable release (0.10.x) -# pip install fairseq -``` - -* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: - -``` bash -git clone https://github.com/NVIDIA/apex -cd apex -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ - --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ - --global-option="--fast_multihead_attn" ./ -``` - -* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` -* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` - as command line options to `nvidia-docker run` . - -# Getting Started - -The [full documentation](https://fairseq.readthedocs.io/) contains instructions -for getting started, training new models and extending fairseq with new model -types and tasks. - -# Pre-trained models and examples - -We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, -as well as example training and evaluation commands. 
- -* [Translation](examples/translation/README.md): convolutional and transformer models are available -* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available - -We also have more detailed READMEs to reproduce results from specific papers: - -* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) -* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) -* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) -* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) -* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) -* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) -* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) -* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) -* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) -* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) -* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* 
[Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) -* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) -* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) -* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) -* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) - -# Join the fairseq community - -* Twitter: https://twitter.com/fairseq -* Facebook page: https://www.facebook.com/groups/fairseq.users -* Google group: https://groups.google.com/forum/#!forum/fairseq-users - -# License - -fairseq(-py) is MIT-licensed. -The license applies to the pre-trained models as well. - -# Citation - -Please cite as: - -``` bibtex -@inproceedings{ott2019fairseq, - title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, - author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, - booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, - year = {2019}, -} -``` diff --git a/README.md b/README.md new file mode 120000 index 0000000000..4d567e7967 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +README.moe.md \ No newline at end of file diff --git a/README.moe.md b/README.moe.md index cb3bf500ef..d5fd5fcdef 100644 --- a/README.moe.md +++ b/README.moe.md @@ -106,6 +106,11 @@ salloc --gpus-per-node 8 --ntasks-per-node 8 --cpus-per-task 12 --nodes 8 --mem- NUM_EXPERTS=64 TOKENS_PER_SAMPLE=1024 +# we want 12 sequences per GPU. On <= 128 GPUs we can fit 6 sequences and use +# gradient accumulation to reach this target. 
+BATCH_SIZE=6 +GRAD_ACC=2 + # launch the job (adjust port and --cpu-bind if needed) DISTRIBUTED_PORT=12345 srun --cpu-bind=mask_cpu:000000ffffff000000ffffff,000000ffffff000000ffffff,000000ffffff000000ffffff,000000ffffff000000ffffff,ffffff000000ffffff000000,ffffff000000ffffff000000,ffffff000000ffffff000000,ffffff000000ffffff000000 \ @@ -125,7 +130,16 @@ srun --cpu-bind=mask_cpu:000000ffffff000000ffffff,000000ffffff000000ffffff,00000 --optimizer adam --fp16-adam-stats --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr 0.0005 --warmup-updates 750 \ --dropout 0.1 --attention-dropout 0.1 \ - --batch-size 12 --update-freq 1 \ + --batch-size $BATCH_SIZE --update-freq $GRAD_ACC \ --max-update 250 --disable-validation \ --log-format json --log-interval 10 ``` + +#### Expected performance on IB interconnect + +| num GPUs | num experts | batch size (x grad acc.) | words per second (wps) | +| -- | -- | -- | -- | +| 32 | 32 | 6 (x2) | 33k | +| 64 | 64 | 6 (x2) | 82k | +| 128 | 128 | 6 (x2) | 191k | +| 512 | 512 | 12 (x1) | 638k | From 9ea39b2ccf5c2bebcdbb7c0e7d690cecde8077d0 Mon Sep 17 00:00:00 2001 From: ghostplant Date: Tue, 16 Nov 2021 14:52:32 +0000 Subject: [PATCH 03/63] Add Tutel boost for Fairseq MoE acceleration (#3873) Signed-off-by: Wei CUI --- fairseq/modules/moe/moe_layer.py | 38 ++++++++++++++++++++++++-------- fairseq/modules/moe/top1gate.py | 8 ++++++- fairseq/modules/moe/top2gate.py | 23 ++++++++++++------- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/fairseq/modules/moe/moe_layer.py b/fairseq/modules/moe/moe_layer.py index 1c1668d3b9..541848dbb4 100644 --- a/fairseq/modules/moe/moe_layer.py +++ b/fairseq/modules/moe/moe_layer.py @@ -22,6 +22,14 @@ else: Base = Module +try: + # To enable Tutel MoE optimizations: + # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x + from tutel import moe as tutel_moe + + has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one +except ModuleNotFoundError: + 
has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1 logger = logging.getLogger(__name__) @@ -153,14 +161,23 @@ def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Ten padded_input_padding_mask[:reshaped_input_shape[0]] = False reshaped_input_padding_mask = padded_input_padding_mask - l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask) + if has_tutel: + l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask) + S, M = reshaped_input.size(0), reshaped_input.size(1) + + if not hasattr(self, '_tutel_dispatcher'): + self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype) + self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C) + dispatched_input = self._tutel_dispatcher.encode(reshaped_input) + else: + l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask) - dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S - E, C, S = dispatch_mask.size() - M = reshaped_input.size(1) - assert reshaped_input.size() == (S, M) - # einsum("sec,sm->ecm") - dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M + dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S + E, C, S = dispatch_mask.size() + M = reshaped_input.size(1) + assert reshaped_input.size() == (S, M) + # einsum("sec,sm->ecm") + dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M if self.all2all_size > 1: dispatched_input = self.all_to_all_wrapper(dispatched_input) @@ -179,8 +196,11 @@ def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Ten # Re-shape back: gecm -> ecm expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model) - # einsum("sec,ecm->sm") - combined_output 
= combine_weights.view(S, E*C).mm(expert_output.view(E*C, M)) + if has_tutel: + combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M)) + else: + # einsum("sec,ecm->sm") + combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M)) # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences combined_output = combined_output[:reshaped_input_shape[0], :] diff --git a/fairseq/modules/moe/top1gate.py b/fairseq/modules/moe/top1gate.py index 2a79b5fa60..cc22f84991 100644 --- a/fairseq/modules/moe/top1gate.py +++ b/fairseq/modules/moe/top1gate.py @@ -17,6 +17,7 @@ from torch import Tensor import torch.nn.functional as F +from .moe_layer import has_tutel, fused_cumsum_sub_one from .top2gate import one_hot, entropy @@ -74,13 +75,18 @@ def top1gating( gates1_s = (gates * mask1).sum(dim=1) # Compute locations in capacity buffer - locations1 = torch.cumsum(mask1, dim=0) - 1 + locations1 = fused_cumsum_sub_one(mask1) # Compute l_aux me = torch.mean(gates, dim=0) ce = torch.mean(mask1.to(gates.dtype), dim=0) l_aux = torch.mean(me * ce) l_aux = l_aux * num_experts * num_experts + + if has_tutel: + locations1_s = torch.sum(locations1 * mask1, dim=1) + return l_aux, metadata, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,] + # Remove locations outside capacity from mask mask1 = mask1 * torch.lt(locations1, capacity) # Store the capacity location for each token diff --git a/fairseq/modules/moe/top2gate.py b/fairseq/modules/moe/top2gate.py index 0f88d5af3c..425e42800d 100644 --- a/fairseq/modules/moe/top2gate.py +++ b/fairseq/modules/moe/top2gate.py @@ -18,6 +18,7 @@ from torch.distributions import Categorical import torch.nn.functional as F +from .moe_layer import has_tutel, fused_cumsum_sub_one gumbel_map: Dict[torch.device, Callable] = {} @@ -116,19 +117,19 @@ def top2gating( # if batch_prioritized_routing: importance_scores = -1 * gates.max(dim=1)[0] sorted_mask1 = 
mask1[importance_scores.argsort(dim=0)] - sorted_cumsum1 = (torch.cumsum(sorted_mask1, dim=0) - 1) * sorted_mask1 + sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1 importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)] sorted_mask2 = mask2[importance_scores.argsort(dim=0)] - sorted_cumsum2 = (torch.cumsum(sorted_mask2, dim=0) - 1) * sorted_mask2 + sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2 importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)] importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True) locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2 else: - locations1 = torch.cumsum(mask1, dim=0) - 1 - locations2 = torch.cumsum(mask2, dim=0) - 1 + locations1 = fused_cumsum_sub_one(mask1) + locations2 = fused_cumsum_sub_one(mask2) # Update 2nd's location by accounting for locations of 1st locations2 += torch.sum(mask1, dim=0, keepdim=True) @@ -144,6 +145,7 @@ def top2gating( metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2) # Remove locations outside capacity from mask + mask1_, mask2_ = mask1, mask2 mask1 = mask1 * torch.lt(locations1, capacity) mask2 = mask2 * torch.lt(locations2, capacity) @@ -164,10 +166,6 @@ def top2gating( metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum() metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum() - # Store the capacity location for each token - locations1_s = torch.sum(locations1 * mask1, dim=1) - locations2_s = torch.sum(locations2 * mask2, dim=1) - if not normalize_gate_prob_before_dropping: # Normalize gate probabilities gates1_s = (gates * mask1).sum(dim=1) @@ -178,6 +176,15 @@ def top2gating( gates1_s /= denom_s gates2_s /= denom_s + if has_tutel: + locations1_s = torch.sum(locations1 * mask1_, dim=1) + locations2_s = torch.sum(locations2 * mask2_, dim=1) + return l_aux, 
metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s] + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + # Calculate combine_weights and dispatch_mask gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se") gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se") From ebea0072062e2e8f4563644f27546df355357f5e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sat, 18 Dec 2021 07:10:36 -0800 Subject: [PATCH 04/63] Add instructions for using eval-lm with MoE models --- README.moe.md | 47 ++++++++++++++++++++++++++++++++++- fairseq/models/transformer.py | 2 +- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/README.moe.md b/README.moe.md index d5fd5fcdef..08501457f3 100644 --- a/README.moe.md +++ b/README.moe.md @@ -27,7 +27,7 @@ pip install hydra-core==1.0.7 omegaconf==2.0.6 megatron (must be installed from source to get fused kernels): ```bash -git clone --depth=1 --branch v2.4 https://github.com/NVIDIA/Megatron-LM.git +git clone --depth=1 --branch v2.6 https://github.com/NVIDIA/Megatron-LM.git cd Megatron-LM pip install -e . ``` @@ -137,9 +137,54 @@ srun --cpu-bind=mask_cpu:000000ffffff000000ffffff,000000ffffff000000ffffff,00000 #### Expected performance on IB interconnect +NOTE: The words-per-second estimates below are before the Tutel optimizations +introduced in [#3873](https://github.com/pytorch/fairseq/pull/3873). One should +expect an additional 15-20% speedup from those optimizations. + | num GPUs | num experts | batch size (x grad acc.) 
| words per second (wps) | | -- | -- | -- | -- | | 32 | 32 | 6 (x2) | 33k | | 64 | 64 | 6 (x2) | 82k | | 128 | 128 | 6 (x2) | 191k | | 512 | 512 | 12 (x1) | 638k | + +# Evaluating MoE language models + +#### Using `fairseq-eval-lm` + +The `fairseq-eval-lm` script can be used to score properly +[preprocessed/binarized datasets](https://github.com/pytorch/fairseq/tree/moe/examples/language_model#1-preprocess-the-data), +for example: + +```bash +DATA_PATH=/path/to/data-bin +MODEL_PATH=/path/to/model.pt +python -m fairseq_cli.eval_lm \ + $DATA_PATH \ + --path $MODEL_PATH \ + --gen-subset valid \ + --sample-break-mode none \ + --tokens-per-sample 2048 \ + --batch-size 1 \ + --fp16 \ + --output-word-probs \ + --is-moe \ + --distributed-world-size 8 \ + --model-overrides "{'world_size': 8, 'moe_eval_capacity_token_fraction': 0.05}" +``` + +#### Setting `moe_eval_capacity_token_fraction` + +When evaluating MoE models you may need to adjust the `--moe-eval-capacity-token-fraction` +option to match or exceed the training capacity. The logic is somewhat unintuitive: +* During training the capacity is set to `2 * math.ceil(local_bsz_in_tokens / global_num_experts)` +* During inference the capacity is set to `math.ceil(args.moe_eval_capacity_token_fraction * local_bsz_in_tokens)` + +For example, suppose you train a model with a batch size of 12 sequences per +GPU, each of length 1024, and 512 experts. This model will have a capacity of +48 during training (i.e., `2 * 12 * 1024 / 512`). + +Now suppose you want to match this setting at inference, but you are using +fewer GPUs so have reduced the batch size to 1 sequence of length 1024. In that +case you would want to set `--moe-eval-capacity-token-fraction=0.046875` (i.e., `48 / 1024`). 
+ diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 2065e326ff..89efcc9ff7 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -43,7 +43,7 @@ def fsdp_wrap_expert(args, layer, min_num_params=0): # Wrap MoE layer with FSDP using a process group with all replicated ranks process_group = layer.moe_layer.expert_group - world_size = dist_utils.get_data_parallel_group().size() + world_size = dist_utils.get_data_parallel_world_size() pg_size = process_group.size() num_experts = world_size/pg_size From f0e9c9b5615aa8ef4d2bfd021fdafd64d79cb2ce Mon Sep 17 00:00:00 2001 From: shumingma Date: Wed, 16 Nov 2022 08:48:54 -0800 Subject: [PATCH 05/63] support infinibatch --- fairseq/checkpoint_utils.py | 21 +- ...moothed_cross_entropy_latency_augmented.py | 108 ---- fairseq/criterions/squad.py | 122 ++++ fairseq/data/squad/__init__.py | 3 + fairseq/data/squad/basic_tokenizer.py | 186 ++++++ fairseq/data/squad/squad_extractor.py | 396 ++++++++++++ fairseq/data/squad/squad_metrics.py | 601 ++++++++++++++++++ fairseq/models/squad.py | 98 +++ fairseq/tasks/glue.py | 65 ++ fairseq/tasks/squad.py | 206 ++++++ fairseq/trainer.py | 12 +- fairseq_cli/train.py | 9 +- setup.py | 2 + 13 files changed, 1712 insertions(+), 117 deletions(-) delete mode 100644 fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py create mode 100644 fairseq/criterions/squad.py create mode 100755 fairseq/data/squad/__init__.py create mode 100755 fairseq/data/squad/basic_tokenizer.py create mode 100755 fairseq/data/squad/squad_extractor.py create mode 100755 fairseq/data/squad/squad_metrics.py create mode 100644 fairseq/models/squad.py create mode 100644 fairseq/tasks/glue.py create mode 100644 fairseq/tasks/squad.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 52e32cf7f8..9d3735d364 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -49,6 +49,19 @@ def save_checkpoint( 
trainer.consolidate_optimizer() # TODO(SS): we dont need if no_save_optimizer_state + extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + if getattr(epoch_itr, "sharded_checkpoint", False): + local_state_dict = extra_state["train_iterator"] + all_state_dicts = dist_utils.all_gather_list( + local_state_dict, + max_size=getattr(trainer.cfg.common, "all_gather_list_size", 16384), + group=trainer.data_parallel_process_group, + ) + extra_state["train_iterator"] = all_state_dicts + if not trainer.should_save_checkpoint_on_current_rank: return @@ -96,9 +109,9 @@ def is_better(a, b): "checkpoint_last{}.pt".format(suffix) ] = not cfg.no_last_checkpoints - extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} - if hasattr(save_checkpoint, "best"): - extra_state.update({"best": save_checkpoint.best}) + # extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} + # if hasattr(save_checkpoint, "best"): + # extra_state.update({"best": save_checkpoint.best}) checkpoints = [ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond @@ -256,7 +269,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): # restore iterator from checkpoint itr_state = extra_state["train_iterator"] epoch_itr = trainer.get_train_iterator( - epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + epoch=itr_state.get("epoch", 1), load_dataset=True, **passthrough_args ) epoch_itr.load_state_dict(itr_state) else: diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py deleted file mode 100644 index aa3dba31e2..0000000000 --- a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. 
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from examples.simultaneous_translation.utils.latency import LatencyTraining -from fairseq.criterions import register_criterion -from fairseq.criterions.label_smoothed_cross_entropy import ( - LabelSmoothedCrossEntropyCriterion, -) - - -@register_criterion("latency_augmented_label_smoothed_cross_entropy") -class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( - LabelSmoothedCrossEntropyCriterion -): - def __init__( - self, - task, - sentence_avg, - label_smoothing, - ignore_prefix_size, - report_accuracy, - latency_weight_avg, - latency_weight_avg_type, - latency_weight_var, - latency_weight_var_type, - mass_preservation, - average_method, - ): - super().__init__( - task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy - ) - self.eps = label_smoothing - self.latency_weight_avg = latency_weight_avg - self.latency_weight_avg_type = latency_weight_avg_type - self.latency_weight_var = latency_weight_var - self.latency_weight_var_type = latency_weight_var_type - self.mass_preservation = mass_preservation - self.average_method = average_method - self.latency_train = LatencyTraining( - self.latency_weight_avg, - self.latency_weight_var, - self.latency_weight_avg_type, - self.latency_weight_var_type, - self.mass_preservation, - self.average_method, - ) - - @staticmethod - def add_args(parser): - super( - LatencyAugmentedLabelSmoothedCrossEntropyCriterion, - LatencyAugmentedLabelSmoothedCrossEntropyCriterion, - ).add_args(parser) - # fmt: off - - """Add criterion-specific arguments to the parser.""" - parser.add_argument( - "--label-smoothing", - default=0.0, - type=float, - metavar="D", - help="epsilon for label smoothing, 0 means no label smoothing", - ) - parser.add_argument( - "--ignore_prefix_size", - default=0, - type=int, - help="ignore first N tokens", - ) - parser.add_argument( - "--report-accuracy", - 
default=False, - type=bool, - help="report accuracy metric", - ) - parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D', - help="Average loss weight") - parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D', - help="Variance loss weight") - parser.add_argument("--latency-weight-avg-type", default="differentiable_average_lagging", - help="Statistics for Average loss type") - parser.add_argument("--latency-weight-var-type", default="variance_delay", - help="Statistics for variance loss type") - parser.add_argument("--average-method", default="weighted_average", - help="Average loss type") - # fmt: on - - def compute_loss(self, model, net_output, sample, reduce=True): - # Compute cross entropy loss first - loss, nll_loss = super().compute_loss(model, net_output, sample, reduce) - - # Obtain the expected alignment - attn_list = [item["alpha"] for item in net_output[-1]["attn_list"]] - - target_padding_mask = model.get_targets(sample, net_output).eq(self.padding_idx) - - source_padding_mask = net_output[-1].get("encoder_padding_mask", None) - - # Get latency loss - latency_loss = self.latency_train.loss( - attn_list, source_padding_mask, target_padding_mask - ) - - loss += latency_loss - - return loss, nll_loss diff --git a/fairseq/criterions/squad.py b/fairseq/criterions/squad.py new file mode 100644 index 0000000000..ca78eca8bd --- /dev/null +++ b/fairseq/criterions/squad.py @@ -0,0 +1,122 @@ +import math +import logging +from dataclasses import dataclass, field +from omegaconf import II + +import os +import torch +import torch.nn.functional as F + +from fairseq import metrics, utils + +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from unilm.data.squad import SquadResult, compute_predictions_logits, squad_evaluate + +logger = logging.getLogger(__name__) + +@dataclass +class SquadConfig(FairseqDataclass): + n_best_size: int = field( + default=20, 
metadata={"help": "The number of n-best predictions"} + ) + max_answer_length: int = field( + default=30, metadata={"help": "The maximum length of the generated answer"} + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "Squad2.0"} + ) + save_dir: str = II("common.save_dir") + +@register_criterion('squad', dataclass=SquadConfig) +class SquadCriterion(FairseqCriterion): + + def __init__(self, cfg, task): + super().__init__(task) + self.cfg = cfg + self.head_name = 'question_answering_head' + + def forward(self, model, sample, reduce=True): + features, _ = model( + **sample['net_input'], + features_only=True, + classification_head_name=None, + ) + p_mask = sample['targets']['p_mask'] + if self.training: + start_positions = sample['targets']['starts'] + end_positions = sample['targets']['ends'] + loss = model.classification_heads[self.head_name].forward(features, start_positions, end_positions, p_mask) + else: + loss = torch.zeros(1, dtype=torch.float, device=features.device, requires_grad=True) + outputs = model.classification_heads[self.head_name].forward(features, p_mask=p_mask) + + + sample_size = sample['nsentences'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['nsentences'], + 'sample_size': sample_size, + } + if not self.training: + logging_output['start_logits'] = outputs[0].detach() + logging_output['end_logits'] = outputs[1].detach() + logging_output['index'] = sample['id'] + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs): + loss = sum(log.get('loss', 0) for log in logging_outputs) + ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) + nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) + sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + + metrics.log_scalar('loss', loss / sample_size / math.log(2)) + metrics.log_scalar('ntokens', 
ntokens) + metrics.log_scalar('nsentences', nsentences) + metrics.log_scalar('sample_size', sample_size) + + def context_metrics(self, logging_outputs): + if self.training: + return + all_results = [] + task = self.task + for log in logging_outputs: + start_logits = log['start_logits'] + end_logits = log['end_logits'] + indices = log['index'] + for i in range(start_logits.size(0)): + index = int(indices[i]) + unique_id = task.eval_features[index].unique_id + result = SquadResult(unique_id, + start_logits[i].float().cpu().tolist(), + end_logits[i].float().cpu().tolist(), + ) + all_results.append(result) + + output_prediction_file = os.path.join(self.cfg.save_dir, "predictions.json") + output_nbest_file = os.path.join(self.cfg.save_dir, "nbest_predictions.json") + output_null_log_odds_file = os.path.join(self.cfg.save_dir, "null_odds.json") + if self.cfg.version_2_with_negative: + output_null_log_odds_file = os.path.join(self.cfg.save_dir, "null_odds.json") + else: + output_null_log_odds_file = None + predictions, null_scores = compute_predictions_logits( + task.eval_examples, + task.eval_features, + all_results, + n_best_size=self.cfg.n_best_size, + max_answer_length=self.cfg.max_answer_length, + do_lower_case=False, + output_prediction_file=output_prediction_file, + output_nbest_file=output_nbest_file, + output_null_log_odds_file=output_null_log_odds_file, + verbose_logging=False, + version_2_with_negative=self.cfg.version_2_with_negative, + null_score_diff_threshold=0.0, + tokenizer=task.tokenizer, + ) + + eval_result = squad_evaluate(task.eval_examples, predictions, null_scores) + logger.info(eval_result) \ No newline at end of file diff --git a/fairseq/data/squad/__init__.py b/fairseq/data/squad/__init__.py new file mode 100755 index 0000000000..b2a6723788 --- /dev/null +++ b/fairseq/data/squad/__init__.py @@ -0,0 +1,3 @@ +from .squad_extractor import SquadExample, SquadFeature, read_squad_examples, squad_convert_examples_to_features +from .basic_tokenizer 
import BasicTokenizer +from .squad_metrics import SquadResult, compute_predictions_logits, squad_evaluate \ No newline at end of file diff --git a/fairseq/data/squad/basic_tokenizer.py b/fairseq/data/squad/basic_tokenizer.py new file mode 100755 index 0000000000..eb3668bac5 --- /dev/null +++ b/fairseq/data/squad/basic_tokenizer.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other 
it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if 
start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False \ No newline at end of file diff --git a/fairseq/data/squad/squad_extractor.py b/fairseq/data/squad/squad_extractor.py new file mode 100755 index 0000000000..6c1312a026 --- /dev/null +++ b/fairseq/data/squad/squad_extractor.py @@ -0,0 +1,396 @@ +# Copyright 2020 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os, collections +import pickle +import logging + +import numpy as np +import six + + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + +class SquadExample(object): + """A single training/test example for simple sequence classification. + + For examples without an answer, the start and end position are -1. 
+ """ + + def __init__(self, + qas_id, + question_text, + doc_tokens, + orig_answer_text=None, + start_position=None, + end_position=None, + is_impossible=False, + answers=[]): + self.qas_id = qas_id + self.question_text = question_text + self.doc_tokens = doc_tokens + self.orig_answer_text = orig_answer_text + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + self.answers = answers + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = "" + s += "qas_id: %s" % (str(self.qas_id)) + s += ", question_text: %s" % ( + str(self.question_text)) + s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) + if self.start_position: + s += ", start_position: %d" % (self.start_position) + if self.start_position: + s += ", end_position: %d" % (self.end_position) + if self.start_position: + s += ", is_impossible: %r" % (self.is_impossible) + if self.orig_answer_text: + s += ", ori_answer_text: %s" % (self.orig_answer_text) + s += ", answer_text: %s" % (' '.join(self.doc_tokens[self.start_position: self.end_position + 1])) + return s + +class SquadFeature(object): + """A single set of features of data.""" + + def __init__(self, + unique_id, + example_index, + doc_span_index, + tokens, + token_to_orig_map, + token_is_max_context, + input_ids, + p_mask, + doc_offset, + start_position=None, + end_position=None, + is_impossible=None): + self.unique_id = unique_id + self.example_index = example_index + self.doc_span_index = doc_span_index + self.tokens = tokens + self.p_mask = p_mask + self.doc_offset = doc_offset + self.token_to_orig_map = token_to_orig_map + self.token_is_max_context = token_is_max_context + self.input_ids = input_ids + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + + +def read_squad_examples(input_file, is_training, version_2_with_negative): + """Read a SQuAD json file into a list of SquadExample.""" + with 
open(input_file, "r") as reader: + input_data = json.load(reader)["data"] + + def is_whitespace(c): + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + return True + return False + + examples = [] + for entry in input_data: + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + doc_tokens = [] + char_to_word_offset = [] + prev_is_whitespace = True + for c in paragraph_text: + if is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + char_to_word_offset.append(len(doc_tokens) - 1) + + for qa in paragraph["qas"]: + qas_id = qa["id"] + question_text = qa["question"] + start_position = None + end_position = None + orig_answer_text = None + is_impossible = qa.get("is_impossible", False) + answers = [] + if is_training: + if (len(qa["answers"]) != 1) and (not is_impossible): + raise ValueError( + "For training, each question should have exactly 1 answer.") + if not is_impossible: + answer = qa["answers"][0] + orig_answer_text = answer["text"] + answer_offset = answer["answer_start"] + answer_length = len(orig_answer_text) + start_position = char_to_word_offset[answer_offset] + end_position = char_to_word_offset[answer_offset + answer_length - 1] + # Only add answers where the text can be exactly recovered from the + # document. If this CAN'T happen it's likely due to weird Unicode + # stuff so we will just skip the example. + # + # Note that this means for training mode, every example is NOT + # guaranteed to be preserved. + actual_text = " ".join( + doc_tokens[start_position:(end_position + 1)]) + cleaned_answer_text = " ".join( + whitespace_tokenize(orig_answer_text)) + if actual_text.find(cleaned_answer_text) == -1: + print("Could not find answer: '%s' vs. 
'%s'", + actual_text, cleaned_answer_text) + continue + else: + start_position = -1 + end_position = -1 + orig_answer_text = "" + elif not is_impossible: + answers = qa["answers"] + + example = SquadExample( + qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_position, + end_position=end_position, + is_impossible=is_impossible, + answers=answers) + examples.append(example) + return examples + +def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, + orig_answer_text): + """Returns tokenized answer spans that better match the annotated answer.""" + + # The SQuAD annotations are character based. We first project them to + # whitespace-tokenized words. But then after WordPiece tokenization, we can + # often find a "better match". For example: + # + # Question: What year was John Smith born? + # Context: The leader was John Smith (1895-1943). + # Answer: 1895 + # + # The original whitespace-tokenized answer will be "(1895-1943).". However + # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match + # the exact answer, 1895. + # + # However, this is not always possible. Consider the following: + # + # Question: What country is the top exporter of electornics? + # Context: The Japanese electronics industry is the lagest in the world. + # Answer: Japan + # + # In this case, the annotator chose "Japan" as a character sub-span of + # the word "Japanese". Since our WordPiece tokenizer does not split + # "Japanese", we just use "Japanese" as the annotation. This is fairly rare + # in SQuAD, but does happen. 
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) + + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) + if text_span == tok_answer_text: + return (new_start, new_end) + + return (input_start, input_end) + + +def _check_is_max_context(doc_spans, cur_span_index, position): + """Check if this is the 'max context' doc span for the token.""" + + # Because of the sliding window approach taken to scoring documents, a single + # token can appear in multiple documents. E.g. + # Doc: the man went to the store and bought a gallon of milk + # Span A: the man went to the + # Span B: to the store and bought + # Span C: and bought a gallon of + # ... + # + # Now the word 'bought' will have two scores from spans B and C. We only + # want to consider the score with "maximum context", which we define as + # the *minimum* of its left and right context (the *sum* of left and + # right context will always be the same, of course). + # + # In the example the maximum context for 'bought' would be span C since + # it has 1 left context and 3 right context, while span B has 4 left context + # and 0 right context. 
+ best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + +def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, + max_query_length, is_training, + cls_token='[CLS]', sep_token='[SEP]', additional_seq=True): + features = [] + unique_id = 1000000000 + for (example_index, example) in enumerate(examples): + query_tokens = tokenizer.tokenize(example.question_text) + if len(query_tokens) > max_query_length: + query_tokens = query_tokens[0:max_query_length] + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + tok_start_position = None + tok_end_position = None + if is_training and example.is_impossible: + tok_start_position = -1 + tok_end_position = -1 + if is_training and not example.is_impossible: + tok_start_position = orig_to_tok_index[example.start_position] + if example.end_position < len(example.doc_tokens) - 1: + tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 + else: + tok_end_position = len(all_doc_tokens) - 1 + (tok_start_position, tok_end_position) = _improve_answer_span( + all_doc_tokens, tok_start_position, tok_end_position, tokenizer, + example.orig_answer_text) + + max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 + + # We can have documents that are longer than the maximum 
sequence length. + # To deal with this we do a sliding window approach, where we take chunks + # of the up to our max length with a stride of `doc_stride`. + _DocSpan = collections.namedtuple( # pylint: disable=invalid-name + "DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, doc_stride) + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = [] + p_mask = [] + token_to_orig_map = {} + token_is_max_context = {} + tokens.append(cls_token) + p_mask.append(0) + for token in query_tokens: + tokens.append(token) + p_mask.append(1) + tokens.append(sep_token) + p_mask.append(1) + if additional_seq: + tokens.append(sep_token) + p_mask.append(1) + doc_offset = len(tokens) + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] + + is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index) + token_is_max_context[len(tokens)] = is_max_context + tokens.append(all_doc_tokens[split_token_index]) + p_mask.append(0) + tokens.append(sep_token) + p_mask.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + assert len(p_mask) == len(input_ids) + start_position = None + end_position = None + if is_training and not example.is_impossible: + # For training, if our document chunk does not contain an annotation + # we throw it out, since there is nothing to predict. 
+ doc_start = doc_span.start + doc_end = doc_span.start + doc_span.length - 1 + out_of_span = False + if not (tok_start_position >= doc_start and tok_end_position <= doc_end): + out_of_span = True + if out_of_span: + start_position = 0 + end_position = 0 + else: + start_position = tok_start_position - doc_start + doc_offset + end_position = tok_end_position - doc_start + doc_offset + + if is_training and example.is_impossible: + start_position = 0 + end_position = 0 + + if example_index < 10: + print("*** Example ***") + print("unique_id: %s" % (unique_id)) + print("example_index: %s" % (example_index)) + print("doc_span_index: %s" % (doc_span_index)) + print("tokens: %s" % " ".join( + [x for x in tokens])) + print("token_to_orig_map: %s" % " ".join( + ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) + print("token_is_max_context: %s" % " ".join([ + "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) + ])) + print("input_ids: %s" % " ".join([str(x) for x in input_ids])) + if is_training and example.is_impossible: + print("impossible example") + if is_training and not example.is_impossible: + answer_text = " ".join(tokens[start_position:(end_position + 1)]) + print("start_position: %d" % (start_position)) + print("end_position: %d" % (end_position)) + print("answer: %s" % (answer_text)) + + feature = SquadFeature( + unique_id=unique_id, + example_index=example_index, + doc_span_index=doc_span_index, + tokens=tokens, + token_to_orig_map=token_to_orig_map, + token_is_max_context=token_is_max_context, + input_ids=input_ids, + p_mask=p_mask, + doc_offset=doc_offset, + start_position=start_position, + end_position=end_position, + is_impossible=example.is_impossible) + features.append(feature) + unique_id += 1 + return features diff --git a/fairseq/data/squad/squad_metrics.py b/fairseq/data/squad/squad_metrics.py new file mode 100755 index 0000000000..12716d3e86 --- /dev/null +++ b/fairseq/data/squad/squad_metrics.py @@ -0,0 +1,601 @@ +# 
Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to +update `find_best_threshold` scripts for SQuAD V2.0 + +In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an +additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted +probability that a question is unanswerable. +""" + + +import collections +import json +import math +import re +import string +import logging + +logger = logging.getLogger(__name__) + +from . import BasicTokenizer + +#todo, check xl-net implementation https://github.com/google-research/albert/blob/master/squad_utils.py + + +class SquadResult: + """ + Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. + Args: + unique_id: The unique identifier corresponding to that example. 
+ start_logits: The logits corresponding to the start of the answer + end_logits: The logits corresponding to the end of the answer + """ + + def __init__(self, unique_id, start_logits, end_logits): + self.start_logits = start_logits + self.end_logits = end_logits + self.unique_id = unique_id + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +def compute_exact(a_gold, a_pred): + return int(normalize_answer(a_gold) == normalize_answer(a_pred)) + + +def compute_f1(a_gold, a_pred): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def get_raw_scores(examples, preds): + """ + Computes the exact and f1 scores from the examples and the model predictions + """ + exact_scores = {} + f1_scores = {} + + for example in examples: + qas_id = example.qas_id + gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])] + + if not gold_answers: + # For unanswerable questions, only correct answer is empty string + gold_answers = [""] + + if 
qas_id not in preds: + print("Missing prediction for %s" % qas_id) + continue + + prediction = preds[qas_id] + exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers) + f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers) + + return exact_scores, f1_scores + + +def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): + new_scores = {} + for qid, s in scores.items(): + pred_na = na_probs[qid] > na_prob_thresh + if pred_na: + new_scores[qid] = float(not qid_to_has_ans[qid]) + else: + new_scores[qid] = s + return new_scores + + +def make_eval_dict(exact_scores, f1_scores, qid_list=None): + if not qid_list: + total = len(exact_scores) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores.values()) / total), + ("f1", 100.0 * sum(f1_scores.values()) / total), + ("total", total), + ] + ) + else: + total = len(qid_list) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), + ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), + ("total", total), + ] + ) + + +def merge_eval(main_eval, new_eval, prefix): + for k in new_eval: + main_eval["%s_%s" % (prefix, k)] = new_eval[k] + + +def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + + has_ans_score, has_ans_cnt = 0, 0 + for qid in qid_list: + if not qid_to_has_ans[qid]: + continue + has_ans_cnt += 1 + + if qid not in scores: + continue + has_ans_score += scores[qid] + + return 100.0 * best_score / 
len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt + + +def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + main_eval["has_ans_exact"] = has_ans_exact + main_eval["has_ans_f1"] = has_ans_f1 + + +def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for _, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + return 100.0 * best_score / len(scores), best_thresh + + +def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) + + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + + +def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0): + qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples} + has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer] + no_answer_qids = [qas_id for qas_id, has_answer in 
qas_id_to_has_answer.items() if not has_answer] + + if not no_answer_probs: + no_answer_probs = {k: 0.0 for k in preds} + + exact, f1 = get_raw_scores(examples, preds) + + exact_threshold = apply_no_ans_threshold( + exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold + ) + f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold) + + evaluation = make_eval_dict(exact_threshold, f1_threshold) + + if has_answer_qids: + has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids) + merge_eval(evaluation, has_ans_eval, "HasAns") + + if no_answer_qids: + no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids) + merge_eval(evaluation, no_ans_eval, "NoAns") + + if no_answer_probs: + find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer) + + return evaluation + + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". 
+ # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + return orig_text + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + logger.info("Couldn't map start position") + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + logger.info("Couldn't map end position") + return orig_text + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text + + +def _get_best_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def compute_predictions_logits( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + verbose_logging, + version_2_with_negative, + null_score_diff_threshold, + tokenizer, +): + """Write final 
predictions to the json file and log-odds of null if needed.""" + if output_prediction_file: + logger.info(f"Writing predictions to: {output_prediction_file}") + if output_nbest_file: + logger.info(f"Writing nbest to: {output_nbest_file}") + if output_null_log_odds_file and version_2_with_negative: + logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}") + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"] + ) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() + all_null_scores = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + min_null_feature_index = 0 # the paragraph slice with min null score + null_start_logit = 0 # the start logit at the slice with min null score + null_end_logit = 0 # the end logit at the slice with min null score + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.unique_id] + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + # if we could have irrelevant answers, get the min score of irrelevant + if version_2_with_negative: + feature_null_score = result.start_logits[0] + result.end_logits[0] + if feature_null_score < score_null: + score_null = feature_null_score + min_null_feature_index = feature_index 
+ null_start_logit = result.start_logits[0] + null_end_logit = result.end_logits[0] + doc_offset = feature.doc_offset + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= len(feature.tokens) or start_index < doc_offset: + continue + if end_index >= len(feature.tokens) or end_index < doc_offset: + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index], + ) + ) + if version_2_with_negative: + prelim_predictions.append( + _PrelimPrediction( + feature_index=min_null_feature_index, + start_index=0, + end_index=0, + start_logit=null_start_logit, + end_logit=null_end_logit, + ) + ) + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"] + ) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + if pred.start_index > 0: # this is a non-null prediction + tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 
1)] + + tok_text = tokenizer.convert_tokens_to_string(tok_tokens) + + # tok_text = " ".join(tok_tokens) + # + # # De-tokenize WordPieces that have been split off. + # tok_text = tok_text.replace(" ##", "") + # tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit)) + # if we didn't include the empty option in the n-best, include it + if version_2_with_negative: + if "" not in seen_predictions: + nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit)) + + # In very rare edge cases we could only have single null prediction. + # So we just create a nonce prediction in this case to avoid failure. + if len(nbest) == 1: + nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. 
+ if not nbest: + nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + assert len(nbest) >= 1, "No valid predictions" + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + if entry.text: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + + assert len(nbest_json) >= 1, "No valid predictions" + + if not version_2_with_negative: + all_predictions[example.qas_id] = nbest_json[0]["text"] + else: + # predict "" iff the null score - the score of best non-null > threshold + score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit) + scores_diff_json[example.qas_id] = score_diff + all_predictions[example.qas_id] = best_non_null_entry.text + all_null_scores[example.qas_id] = score_diff + all_nbest_json[example.qas_id] = nbest_json + + if output_prediction_file: + with open(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + if output_nbest_file: + with open(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + if output_null_log_odds_file and version_2_with_negative: + with open(output_null_log_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions, all_null_scores diff --git a/fairseq/models/squad.py b/fairseq/models/squad.py new file mode 100644 index 0000000000..0d38d9d720 --- /dev/null +++ b/fairseq/models/squad.py @@ -0,0 +1,98 @@ +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from 
fairseq.modules import ( + LayerNorm, +) + +class PoolerLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The config used by the model, will be used to grab the :obj:`hidden_size` of the model. + """ + + def __init__(self, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, 1) + self.dense.weight.data.normal_(mean=0.0, std=0.02) + self.dense.bias.data.zero_() + + def forward( + self, hidden_states, p_mask = None + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + :obj:`torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + x.masked_fill_(p_mask, float('-inf')) + return x + + + +class SQuADHead(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.start_logits = PoolerLogits(hidden_size) + self.end_logits = PoolerLogits(hidden_size) + + def forward( + self, + hidden_states, + start_positions=None, + end_positions=None, + p_mask = None, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Positions of the first token for the labeled span. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Positions of the last token for the labeled span. + is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Whether the question has a possible answer in the paragraph or not. 
+ p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + end_logits = self.end_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + def loss_fct(logits, targets): + return F.nll_loss( + F.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ), + targets.view(-1), + reduction='sum', + ) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) * 0.5 + return total_loss + else: + return start_logits, end_logits \ No newline at end of file diff --git a/fairseq/tasks/glue.py b/fairseq/tasks/glue.py new file mode 100644 index 0000000000..202f701d2a --- /dev/null +++ b/fairseq/tasks/glue.py @@ -0,0 +1,65 @@ +import logging +import os + +from dataclasses import dataclass, field +from typing import Optional + +from fairseq.tasks import FairseqDataclass, FairseqTask, register_task +from fairseq.tasks.sentence_prediction import SentencePredictionConfig, SentencePredictionTask +from fairseq.data import Dictionary +from omegaconf import II + +logger = logging.getLogger(__name__) + +@dataclass +class GlueConfig(SentencePredictionConfig): + required_batch_size_multiple: int = II("dataset.required_batch_size_multiple") + +@register_task("glue", dataclass=GlueConfig) +class GlueTask(SentencePredictionTask): + """ + Sentence (or sentence pair) prediction (classification or regression) task. 
+ + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @classmethod + def load_dictionary(cls, filename, extra_mask_tokens=False, required_batch_size_multiple=1): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + + if extra_mask_tokens: + dictionary.add_symbol("") + for i in range(100): + dictionary.add_symbol(f"") + + dictionary.pad_to_multiple_(required_batch_size_multiple) + + return dictionary + + @classmethod + def setup_task(cls, cfg, **kwargs): + assert cfg.num_classes > 0, "Must set task.num_classes" + + # load data dictionary + data_dict = cls.load_dictionary( + os.path.join(cfg.data, "input0", "dict.txt"), + extra_mask_tokens=True, + required_batch_size_multiple=cfg.required_batch_size_multiple, + ) + logger.info("[input] dictionary: {} types".format(len(data_dict))) + + # load label dictionary + if not cfg.regression_target: + label_dict = cls.load_dictionary( + os.path.join(cfg.data, "label", "dict.txt"), + ) + logger.info("[label] dictionary: {} types".format(len(label_dict))) + else: + label_dict = data_dict + return cls(cfg, data_dict, label_dict) \ No newline at end of file diff --git a/fairseq/tasks/squad.py b/fairseq/tasks/squad.py new file mode 100644 index 0000000000..64bf3fa472 --- /dev/null +++ b/fairseq/tasks/squad.py @@ -0,0 +1,206 @@ +import os +import pickle +import torch +import numpy as np +from argparse import Namespace + +from fairseq.data import ( + data_utils, + Dictionary, + encoders, + BaseWrapperDataset, + IdDataset, + NumSamplesDataset, + NumelDataset, + NestedDictionaryDataset, + SortDataset, + NumelDataset, + RightPadDataset, + RawLabelDataset, + FairseqDataset, +) + +from fairseq.tasks import register_task, FairseqDataclass, FairseqTask +from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE +from dataclasses import dataclass, field +from omegaconf import II, MISSING + +@dataclass +class 
SquadConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + seed: int = II("common.seed") + spm_model: str = field( + default="", + metadata={ + "help": "sentencepice model to tokenize the data" + }, + ) + required_batch_size_multiple: int = II("dataset.required_batch_size_multiple") + max_positions: int = field( + default=512, + metadata={"help": "max tokens per example"}, + ) + +@register_task('squad', dataclass=SquadConfig) +class SQuADTask(FairseqTask): + + def __init__(self, args, dictionary): + super().__init__(args) + self.args = args + self.dictionary = dictionary + self.seed = args.seed + self.tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=args.spm_model)) + assert self.tokenizer is not None + # self.dictionary.add_symbol('[MASK]') + + @classmethod + def load_dictionary(cls, filename, extra_mask_tokens=False, required_batch_size_multiple=1): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + + if extra_mask_tokens: + dictionary.add_symbol("") + for i in range(100): + dictionary.add_symbol(f"") + + dictionary.pad_to_multiple_(required_batch_size_multiple) + + return dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = cls.load_dictionary( + os.path.join(args.data, 'dict.txt'), + extra_mask_tokens=True, + required_batch_size_multiple=args.required_batch_size_multiple + ) + print('| Dictionary: {} types'.format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, combine=False, **kwargs): + features_file_path = os.path.join(self.args.data, "{}_features.pkl".format(split)) + examples_file_path = os.path.join(self.args.data, "{}_examples.pkl".format(split)) + + if os.path.exists(features_file_path) and os.path.exists(examples_file_path): + examples = pickle.load(open(examples_file_path, 'rb')) + features = pickle.load(open(features_file_path, 'rb')) + else: 
+ raise FileNotFoundError("cannot find {} or {}".format(features_file_path, examples_file_path)) + + if split == 'valid': + # save for eval + self.eval_examples = examples + self.eval_features = features + + src_tokens = RawArrayDataset([torch.from_numpy(np.array(f.input_ids)) for f in features]) + p_mask = RawArrayDataset([torch.from_numpy(np.array(f.p_mask)).bool() for f in features]) + if split == 'train': + starts = RawLabelDataset([int(f.start_position) for f in features]) + ends = RawLabelDataset([int(f.end_position) for f in features]) + is_impossible = RawLabelDataset([int(f.is_impossible) for f in features]) + else: + starts = ends = is_impossible = None + #sizes = np.array([len(f.input_ids) for f in features]) + + ''' + Input format: question here ? Passage
+ ''' + dataset = NestedDictionaryDataset( + { + 'id': IdDataset(), + 'net_input': { + 'src_tokens': RightPadDataset( + src_tokens, + pad_idx=self.dictionary.pad(), + ), + 'src_lengths': NumelDataset(src_tokens, reduce=False), + }, + 'targets': { + 'starts': starts, + 'ends': ends, + 'is_impossible': is_impossible, + 'p_mask': RightPadDataset(p_mask, pad_idx=1), + }, + 'nsentences': NumSamplesDataset(), + 'ntokens': NumelDataset(src_tokens, reduce=True), + }, + sizes=[src_tokens.sizes], + ) + + if split == 'train': + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_tokens)) + dataset = SortDataset( + dataset, + sort_order=[shuffle], + ) + + print('| Loaded {} with {} samples'.format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, args): + from fairseq import models + model = models.build_model(args, self) + + model.register_question_answering_head( + 'question_answering_head', + num_classes=2, + ) + return model + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + criterion.context_metrics(logging_outputs) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + def max_positions(self): + return self.args.max_positions + + +class RawArrayDataset(FairseqDataset): + + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + if hasattr(dataset, 'sizes'): + self._sizes = dataset.sizes + else: + try: + self._sizes = np.array([len(x) for x in self.dataset]) + except: + self._sizes = np.array([1 for x in self.dataset]) + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, 'collater'): + return self.dataset.collater(samples) + else: + return default_collate(samples) + + @property + def 
sizes(self): + return self._sizes + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) \ No newline at end of file diff --git a/fairseq/trainer.py b/fairseq/trainer.py index f42d31402e..6eec8766fd 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -603,7 +603,11 @@ def load_checkpoint( if extra_state is not None: itr_state = extra_state["train_iterator"] - epoch = itr_state["epoch"] + if type(itr_state) == list: + # assert len(itr_state) == self.data_parallel_world_size + itr_state = itr_state[self.data_parallel_rank] + extra_state["train_iterator"] = itr_state + epoch = itr_state.get("epoch", 1) if "previous_training_time" in extra_state: self._previous_training_time = extra_state["previous_training_time"] @@ -611,7 +615,7 @@ def load_checkpoint( self.lr_step(epoch) - if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: + if itr_state.get("version", 1) >= 2 and itr_state.get("iterations_in_epoch", 0) == 0: # reset meters at start of epoch reset_meters = True @@ -1065,8 +1069,8 @@ def valid_step(self, sample, raise_oom=False): ) # log validation stats - logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) - return logging_output + # logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) + return logging_outputs def zero_grad(self): self.optimizer.zero_grad() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 248f7cbed1..600473ce49 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -205,6 +205,9 @@ def is_expert_param(p): train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + if getattr(epoch_itr, "should_close_after_finished", False): + epoch_itr.close() + # ioPath implementation to wait for all asynchronous file writes to complete. 
if cfg.checkpoint.write_checkpoints_asynchronously: logger.info( @@ -491,10 +494,14 @@ def validate( # create a new root metrics aggregator so validation metrics # don't pollute other aggregators (e.g., train meters) with metrics.aggregate(new_root=True) as agg: + logging_outputs = [] for i, sample in enumerate(progress): if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps: break - trainer.valid_step(sample) + # trainer.valid_step(sample) + inner_logging_outputs = trainer.valid_step(sample) + logging_outputs.extend(inner_logging_outputs) + task.reduce_metrics(logging_outputs, trainer.get_criterion()) # log validation stats stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) diff --git a/setup.py b/setup.py index 3670ff3cfc..accb57a3e0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,9 @@ from setuptools import setup, find_packages, Extension from setuptools import Extension, find_packages, setup +import site +site.ENABLE_USER_SITE = True if sys.version_info < (3, 6): sys.exit("Sorry, Python >= 3.6 is required for fairseq.") From bf7c38d5a867f56ba9d836305b6e7405e8c30946 Mon Sep 17 00:00:00 2001 From: shumingma Date: Wed, 16 Nov 2022 08:56:51 -0800 Subject: [PATCH 06/63] fx --- fairseq/criterions/squad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/squad.py b/fairseq/criterions/squad.py index ca78eca8bd..c6c7a76724 100644 --- a/fairseq/criterions/squad.py +++ b/fairseq/criterions/squad.py @@ -11,7 +11,7 @@ from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass -from unilm.data.squad import SquadResult, compute_predictions_logits, squad_evaluate +from fairseq.data.squad import SquadResult, compute_predictions_logits, squad_evaluate logger = logging.getLogger(__name__) From 5605b44f973b573fa09ff9acad5b6826c9bb75e3 Mon Sep 17 00:00:00 2001 From: shumingma Date: Wed, 
16 Nov 2022 09:10:51 -0800 Subject: [PATCH 07/63] fx --- fairseq/tasks/glue.py | 259 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 240 insertions(+), 19 deletions(-) diff --git a/fairseq/tasks/glue.py b/fairseq/tasks/glue.py index 202f701d2a..8e6b988067 100644 --- a/fairseq/tasks/glue.py +++ b/fairseq/tasks/glue.py @@ -1,22 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import logging import os -from dataclasses import dataclass, field -from typing import Optional +import numpy as np +from fairseq import utils +from fairseq.data import ( + ConcatSentencesDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + OffsetTokensDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + RollDataset, + SortDataset, + StripTokenDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task -from fairseq.tasks import FairseqDataclass, FairseqTask, register_task -from fairseq.tasks.sentence_prediction import SentencePredictionConfig, SentencePredictionTask -from fairseq.data import Dictionary -from omegaconf import II logger = logging.getLogger(__name__) -@dataclass -class GlueConfig(SentencePredictionConfig): - required_batch_size_multiple: int = II("dataset.required_batch_size_multiple") -@register_task("glue", dataclass=GlueConfig) -class GlueTask(SentencePredictionTask): +@register_task("glue") +class GlueTask(LegacyFairseqTask): """ Sentence (or sentence pair) prediction (classification or regression) task. 
@@ -24,6 +40,62 @@ class GlueTask(SentencePredictionTask): dictionary (Dictionary): the dictionary for the input of the task """ + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", metavar="FILE", help="file prefix for data") + parser.add_argument( + "--num-classes", + type=int, + default=-1, + help="number of classes or regression targets", + ) + parser.add_argument( + "--init-token", + type=int, + default=None, + help="add token at the beginning of each batch item", + ) + parser.add_argument( + "--separator-token", + type=int, + default=None, + help="add separator token between inputs", + ) + parser.add_argument("--regression-target", action="store_true", default=False) + parser.add_argument("--no-shuffle", action="store_true", default=False) + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + parser.add_argument( + "--add-prev-output-tokens", + action="store_true", + default=False, + help="add prev_output_tokens to sample, used for encoder-decoder arch", + ) + + def __init__(self, args, data_dictionary, label_dictionary): + super().__init__(args) + self.dictionary = data_dictionary + self._label_dictionary = label_dictionary + if not hasattr(args, "max_positions"): + self._max_positions = ( + args.max_source_positions, + args.max_target_positions, + ) + else: + self._max_positions = args.max_positions + args.tokens_per_sample = self._max_positions + @classmethod def load_dictionary(cls, filename, extra_mask_tokens=False, required_batch_size_multiple=1): """Load the dictionary from the filename @@ -41,25 +113,174 @@ def load_dictionary(cls, filename, 
extra_mask_tokens=False, required_batch_size_ dictionary.pad_to_multiple_(required_batch_size_multiple) return dictionary - + @classmethod - def setup_task(cls, cfg, **kwargs): - assert cfg.num_classes > 0, "Must set task.num_classes" + def setup_task(cls, args, **kwargs): + assert args.num_classes > 0, "Must set --num-classes" # load data dictionary data_dict = cls.load_dictionary( - os.path.join(cfg.data, "input0", "dict.txt"), + args, + os.path.join(args.data, "input0", "dict.txt"), extra_mask_tokens=True, - required_batch_size_multiple=cfg.required_batch_size_multiple, + required_batch_size_multiple=args.required_batch_size_multiple, ) logger.info("[input] dictionary: {} types".format(len(data_dict))) # load label dictionary - if not cfg.regression_target: + if not args.regression_target: label_dict = cls.load_dictionary( - os.path.join(cfg.data, "label", "dict.txt"), + args, + os.path.join(args.data, "label", "dict.txt"), ) logger.info("[label] dictionary: {} types".format(len(label_dict))) else: label_dict = data_dict - return cls(cfg, data_dict, label_dict) \ No newline at end of file + return cls(args, data_dict, label_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + def get_path(key, split): + return os.path.join(self.args.data, key, split) + + def make_dataset(key, dictionary): + split_path = get_path(key, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + dictionary, + self.args.dataset_impl, + combine=combine, + ) + return dataset + + input0 = make_dataset("input0", self.source_dictionary) + assert input0 is not None, "could not find dataset: {}".format( + get_path("input0", split) + ) + input1 = make_dataset("input1", self.source_dictionary) + + if self.args.init_token is not None: + input0 = PrependTokenDataset(input0, self.args.init_token) + + if input1 is None: + src_tokens = input0 + else: + if self.args.separator_token is not None: + input1 = 
PrependTokenDataset(input1, self.args.separator_token) + + src_tokens = ConcatSentencesDataset(input0, input1) + + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_tokens)) + + src_tokens = maybe_shorten_dataset( + src_tokens, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.max_positions(), + self.args.seed, + ) + + dataset = { + "id": IdDataset(), + "net_input": { + "src_tokens": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset(src_tokens, reduce=False), + }, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), + } + + if self.args.add_prev_output_tokens: + prev_tokens_dataset = RightPadDataset( + RollDataset(src_tokens, 1), + pad_idx=self.dictionary.pad(), + ) + dataset["net_input"].update( + prev_output_tokens=prev_tokens_dataset, + ) + + if not self.args.regression_target: + label_dataset = make_dataset("label", self.label_dictionary) + if label_dataset is not None: + dataset.update( + target=OffsetTokensDataset( + StripTokenDataset( + label_dataset, + id_to_strip=self.label_dictionary.eos(), + ), + offset=-self.label_dictionary.nspecial, + ) + ) + else: + label_path = "{0}.label".format(get_path("label", split)) + if os.path.exists(label_path): + + def parse_regression_target(i, line): + values = line.split() + assert ( + len(values) == self.args.num_classes + ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' + return [float(x) for x in values] + + with open(label_path) as h: + dataset.update( + target=RawLabelDataset( + [ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ] + ) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[src_tokens.sizes], + ) + + if self.args.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], 
+ ) + + logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, args): + from fairseq import models + + model = models.build_model(args, self) + + model.register_classification_head( + getattr(args, "classification_head_name", "sentence_classification_head"), + num_classes=self.args.num_classes, + ) + + return model + + def max_positions(self): + return self._max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + @property + def label_dictionary(self): + return self._label_dictionary From 978a6ce9855f01054ee54884f4a1ad7e701414a0 Mon Sep 17 00:00:00 2001 From: Shuming Ma Date: Fri, 7 Apr 2023 05:01:35 -0700 Subject: [PATCH 08/63] megatron checkpoint --- fairseq/model_parallel/megatron_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 8ab4657f73..4316ccfe1d 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -50,11 +50,11 @@ def _aggregate_model_parallel_grad_norm(total_norm): aggregate_norm_fn=_aggregate_model_parallel_grad_norm, ) - def save_checkpoint(self, filename, extra_state): + def save_checkpoint(self, filename, extra_state, **kwargs): """Save all training state in a checkpoint file.""" extra_state['rng_tracker_states'] \ = get_cuda_rng_tracker().get_states() - super().save_checkpoint(filename, extra_state) + super().save_checkpoint(filename, extra_state, **kwargs) def load_checkpoint( self, From 5b52116588f9551a751ba8be2ac5b00e17949059 Mon Sep 17 00:00:00 2001 From: Shuming Ma Date: Tue, 11 Apr 2023 06:14:37 -0700 Subject: [PATCH 09/63] bf16 --- fairseq/options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/options.py b/fairseq/options.py 
index b79443a177..5ff2b1208f 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -186,8 +186,8 @@ def parse_args_and_arch( args.bf16 = True args.tpu = getattr(args, "tpu", False) args.bf16 = getattr(args, "bf16", False) - if args.bf16: - args.tpu = True + # if args.bf16: + # args.tpu = True if args.tpu and args.fp16: raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") From a4ee20517e55717f7031bd2d681936eee7d852bf Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 15:46:56 +0000 Subject: [PATCH 10/63] test push --- CODE_OF_CONDUCT.md | 77 ---------------------------------------------- 1 file changed, 77 deletions(-) delete mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md deleted file mode 100644 index a0cbeaab76..0000000000 --- a/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,77 +0,0 @@ -# Code of Conduct - -## Our Pledge - -In the interest of fostering an open and welcoming environment, we as -contributors and maintainers pledge to make participation in our project and -our community a harassment-free experience for everyone, regardless of age, body -size, disability, ethnicity, sex characteristics, gender identity and expression, -level of experience, education, socio-economic status, nationality, personal -appearance, race, religion, or sexual identity and orientation. 
- -## Our Standards - -Examples of behavior that contributes to creating a positive environment -include: - -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members - -Examples of unacceptable behavior by participants include: - -* The use of sexualized language or imagery and unwelcome sexual attention or - advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic - address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Our Responsibilities - -Project maintainers are responsible for clarifying the standards of acceptable -behavior and are expected to take appropriate and fair corrective action in -response to any instances of unacceptable behavior. - -Project maintainers have the right and responsibility to remove, edit, or -reject comments, commits, code, wiki edits, issues, and other contributions -that are not aligned to this Code of Conduct, or to ban temporarily or -permanently any contributor for other behaviors that they deem inappropriate, -threatening, offensive, or harmful. - -## Scope - -This Code of Conduct applies within all project spaces, and it also applies when -an individual is representing the project or its community in public spaces. -Examples of representing a project or community include using an official -project e-mail address, posting via an official social media account, or acting -as an appointed representative at an online or offline event. Representation of -a project may be further defined and clarified by project maintainers. 
- -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported by contacting the project team at . All -complaints will be reviewed and investigated and will result in a response that -is deemed necessary and appropriate to the circumstances. The project team is -obligated to maintain confidentiality with regard to the reporter of an incident. -Further details of specific enforcement policies may be posted separately. - -Project maintainers who do not follow or enforce the Code of Conduct in good -faith may face temporary or permanent repercussions as determined by other -members of the project's leadership. - -## Attribution - -This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, -available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see -https://www.contributor-covenant.org/faq - From f2c15d9d196f08af8ecb52cbf2e428faa9f992a7 Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 18:54:24 +0000 Subject: [PATCH 11/63] refactor by removing function inside function --- .../legacy_distributed_data_parallel.py | 177 ++++++++++-------- 1 file changed, 95 insertions(+), 82 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index a447c2a09e..a3df77e61e 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -14,6 +14,10 @@ training with `--update-freq`. 
""" +# function for + + + from collections import OrderedDict from contextlib import contextmanager @@ -23,6 +27,16 @@ from fairseq.distributed import utils +def start_pdb_on_rank_zero(): + rank = torch.distributed.get_rank() + if rank == 0: + import pdb + pdb.set_trace() + else: + import time + time.sleep(1e6) + + class LegacyDistributedDataParallel(nn.Module): """Implements distributed data parallelism at the module level. @@ -41,6 +55,7 @@ class LegacyDistributedDataParallel(nn.Module): def __init__(self, module, process_group, buffer_size=2 ** 28): super().__init__() + self.module = module self.process_group = process_group self.world_size = utils.get_world_size(self.process_group) @@ -73,99 +88,97 @@ def no_sync(self): def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) + def all_reduce_params(self, params): + buffer = self.buffer + nonzero_buffer = False + if len(params) > 1: + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) + nonzero_buffer = True + else: + buffer[offset : offset + sz].zero_() + offset += sz + else: + # we only have a single grad to all-reduce + p = params[0] + if p.grad is not None: + buffer = p.grad.data + nonzero_buffer = True + elif p.numel() <= self.buffer.numel(): + buffer = buffer[: p.numel()] + buffer.zero_() + else: + buffer = torch.zeros_like(p) + + if nonzero_buffer: + buffer.div_(self.world_size) + + utils.all_reduce(buffer, self.process_group) + + # copy all-reduced grads back into their original place + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) + else: + p.grad = buffer[offset : offset + sz].view_as(p).clone() + offset += sz + + def all_reduce_grads(self): """ This function must be called explicitly after backward to reduce gradients. There is no automatic hook like c10d. 
""" - def all_reduce_params(params): - buffer = self.buffer - nonzero_buffer = False - if len(params) > 1: - offset = 0 - for p in params: - sz = p.numel() - if p.grad is not None: - buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) - nonzero_buffer = True - else: - buffer[offset : offset + sz].zero_() - offset += sz - else: - # we only have a single grad to all-reduce - p = params[0] - if p.grad is not None: - buffer = p.grad.data - nonzero_buffer = True - elif p.numel() <= self.buffer.numel(): - buffer = buffer[: p.numel()] - buffer.zero_() - else: - buffer = torch.zeros_like(p) - - if nonzero_buffer: - buffer.div_(self.world_size) + # This function only needs to be called once + if self.accumulate_grads: + return - utils.all_reduce(buffer, self.process_group) + if self.buffer is None: + self.buffer = next(self.module.parameters()).new(self.buffer_size) - # copy all-reduced grads back into their original place + for params in self.per_device_params: + # All-reduce the gradients in buckets offset = 0 - for p in params: - sz = p.numel() - if p.grad is not None: - p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) - else: - p.grad = buffer[offset : offset + sz].view_as(p).clone() - offset += sz + buffered_params = [] + for param in params: + if not param.requires_grad: + continue - def reduction_fn(): - # This function only needs to be called once - if self.accumulate_grads: - return - - if self.buffer is None: - self.buffer = next(self.module.parameters()).new(self.buffer_size) - - for params in self.per_device_params: - # All-reduce the gradients in buckets - offset = 0 - buffered_params = [] - for param in params: - if not param.requires_grad: - continue - - if hasattr(param, 'base_expert'): - # Skip gradient sync for unshared parameters - continue - - if hasattr(param, 'expert'): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.data.div_(self.world_size) - continue + if hasattr(param, 'base_expert'): + # Skip 
gradient sync for unshared parameters + continue + + if hasattr(param, 'expert'): if param.grad is None: param.grad = torch.zeros_like(param) - if param.grad.requires_grad: - raise RuntimeError( - "DistributedDataParallel only works " - "with gradients that don't require " - "grad" - ) - sz = param.numel() - if sz > self.buffer.numel(): - # all-reduce big params directly - all_reduce_params([param]) else: - if offset + sz > self.buffer.numel(): - all_reduce_params(buffered_params) - offset = 0 - buffered_params.clear() - buffered_params.append(param) - offset += sz - - if len(buffered_params) > 0: - all_reduce_params(buffered_params) + param.grad.data.div_(self.world_size) + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + self.all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + self.all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz - reduction_fn() + if len(buffered_params) > 0: + self.all_reduce_params(buffered_params) From 51409372564590bb725659836bb5324b83bb9e25 Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 20:17:02 +0000 Subject: [PATCH 12/63] factor our buf and params --- .../legacy_distributed_data_parallel.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index a3df77e61e..2129a5a809 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -14,10 +14,6 @@ training with `--update-freq`. 
""" -# function for - - - from collections import OrderedDict from contextlib import contextmanager @@ -77,6 +73,10 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): paramlists[device] += [param] self.per_device_params = list(paramlists.values()) + #start_pdb_on_rank_zero() + + + @contextmanager def no_sync(self): """A context manager to disable gradient synchronization.""" @@ -88,8 +88,8 @@ def no_sync(self): def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) - def all_reduce_params(self, params): - buffer = self.buffer + def all_reduce_params(self, params, curr_buffer): + buffer = curr_buffer nonzero_buffer = False if len(params) > 1: offset = 0 @@ -107,7 +107,7 @@ def all_reduce_params(self, params): if p.grad is not None: buffer = p.grad.data nonzero_buffer = True - elif p.numel() <= self.buffer.numel(): + elif p.numel() <= curr_buffer.numel(): buffer = buffer[: p.numel()] buffer.zero_() else: @@ -129,6 +129,7 @@ def all_reduce_params(self, params): offset += sz + def all_reduce_grads(self): """ This function must be called explicitly after backward to reduce @@ -142,7 +143,12 @@ def all_reduce_grads(self): if self.buffer is None: self.buffer = next(self.module.parameters()).new(self.buffer_size) - for params in self.per_device_params: + self._all_reduce_grads(self.per_device_params, self.buffer) + + + def _all_reduce_grads(self, current_params, curr_buffer): + + for params in current_params: # All-reduce the gradients in buckets offset = 0 buffered_params = [] @@ -169,16 +175,16 @@ def all_reduce_grads(self): "grad" ) sz = param.numel() - if sz > self.buffer.numel(): + if sz > curr_buffer.numel(): # all-reduce big params directly - self.all_reduce_params([param]) + self.all_reduce_params([param], curr_buffer) else: - if offset + sz > self.buffer.numel(): - self.all_reduce_params(buffered_params) + if offset + sz > curr_buffer.numel(): + self.all_reduce_params(buffered_params, curr_buffer) offset = 0 
buffered_params.clear() buffered_params.append(param) offset += sz if len(buffered_params) > 0: - self.all_reduce_params(buffered_params) + self.all_reduce_params(buffered_params, curr_buffer) From d31dff0de231a3da254ae113e2a07378589344cd Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 20:48:08 +0000 Subject: [PATCH 13/63] factor out pgs --- .../legacy_distributed_data_parallel.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 2129a5a809..6a8b8c0220 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -71,11 +71,14 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): if paramlists.get(device) is None: paramlists[device] = [] paramlists[device] += [param] - self.per_device_params = list(paramlists.values()) - - #start_pdb_on_rank_zero() + # split into expert and normal params + per_device_params = list(paramlists.values()) + self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] + self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] + #start_pdb_on_rank_zero() + #print('hi') @contextmanager def no_sync(self): @@ -88,7 +91,7 @@ def no_sync(self): def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) - def all_reduce_params(self, params, curr_buffer): + def all_reduce_params(self, params, curr_buffer, curr_process_group, curr_world_size): buffer = curr_buffer nonzero_buffer = False if len(params) > 1: @@ -114,9 +117,9 @@ def all_reduce_params(self, params, curr_buffer): buffer = torch.zeros_like(p) if nonzero_buffer: - buffer.div_(self.world_size) + buffer.div_(curr_world_size) - utils.all_reduce(buffer, self.process_group) + utils.all_reduce(buffer, curr_process_group) # copy 
all-reduced grads back into their original place offset = 0 @@ -143,10 +146,11 @@ def all_reduce_grads(self): if self.buffer is None: self.buffer = next(self.module.parameters()).new(self.buffer_size) - self._all_reduce_grads(self.per_device_params, self.buffer) + self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, self.world_size) + self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.process_group, self.world_size) - def _all_reduce_grads(self, current_params, curr_buffer): + def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): for params in current_params: # All-reduce the gradients in buckets @@ -164,7 +168,7 @@ def _all_reduce_grads(self, current_params, curr_buffer): if param.grad is None: param.grad = torch.zeros_like(param) else: - param.grad.data.div_(self.world_size) + param.grad.data.div_(curr_world_size) continue if param.grad is None: param.grad = torch.zeros_like(param) @@ -177,14 +181,14 @@ def _all_reduce_grads(self, current_params, curr_buffer): sz = param.numel() if sz > curr_buffer.numel(): # all-reduce big params directly - self.all_reduce_params([param], curr_buffer) + self.all_reduce_params([param], curr_buffer, curr_process_group, curr_world_size) else: if offset + sz > curr_buffer.numel(): - self.all_reduce_params(buffered_params, curr_buffer) + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group, curr_world_size) offset = 0 buffered_params.clear() buffered_params.append(param) offset += sz if len(buffered_params) > 0: - self.all_reduce_params(buffered_params, curr_buffer) + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group, curr_world_size) From 82b808dccc8e369305614e667d435daf2601fc2e Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 21:14:37 +0000 Subject: [PATCH 14/63] harmonize grad division --- .../legacy_distributed_data_parallel.py | 46 +++++++++---------- 1 file changed, 23 
insertions(+), 23 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 6a8b8c0220..f5715adb14 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -77,6 +77,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] + assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) + #start_pdb_on_rank_zero() #print('hi') @@ -91,16 +93,14 @@ def no_sync(self): def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) - def all_reduce_params(self, params, curr_buffer, curr_process_group, curr_world_size): + def all_reduce_params(self, params, curr_buffer, curr_process_group): buffer = curr_buffer - nonzero_buffer = False if len(params) > 1: offset = 0 for p in params: sz = p.numel() if p.grad is not None: buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) - nonzero_buffer = True else: buffer[offset : offset + sz].zero_() offset += sz @@ -109,16 +109,12 @@ def all_reduce_params(self, params, curr_buffer, curr_process_group, curr_world_ p = params[0] if p.grad is not None: buffer = p.grad.data - nonzero_buffer = True elif p.numel() <= curr_buffer.numel(): buffer = buffer[: p.numel()] buffer.zero_() else: buffer = torch.zeros_like(p) - if nonzero_buffer: - buffer.div_(curr_world_size) - utils.all_reduce(buffer, curr_process_group) # copy all-reduced grads back into their original place @@ -147,11 +143,21 @@ def all_reduce_grads(self): self.buffer = next(self.module.parameters()).new(self.buffer_size) self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, self.world_size) - 
self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.process_group, self.world_size) + self._div_all_grads_by_worldsize(self.per_device_params_expert, self.world_size) - def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): + def _div_all_grads_by_worldsize(self, current_params, curr_world_size): + for params in current_params: + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.data.div_(curr_world_size) + continue + def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): for params in current_params: # All-reduce the gradients in buckets offset = 0 @@ -159,19 +165,13 @@ def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, cur for param in params: if not param.requires_grad: continue - - if hasattr(param, 'base_expert'): - # Skip gradient sync for unshared parameters - continue - - if hasattr(param, 'expert'): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.data.div_(curr_world_size) - continue if param.grad is None: param.grad = torch.zeros_like(param) + else: + param.grad.data.div_(curr_world_size) + + + if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " @@ -181,14 +181,14 @@ def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, cur sz = param.numel() if sz > curr_buffer.numel(): # all-reduce big params directly - self.all_reduce_params([param], curr_buffer, curr_process_group, curr_world_size) + self.all_reduce_params([param], curr_buffer, curr_process_group) else: if offset + sz > curr_buffer.numel(): - self.all_reduce_params(buffered_params, curr_buffer, curr_process_group, curr_world_size) + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group) offset = 0 buffered_params.clear() buffered_params.append(param) offset += sz if 
len(buffered_params) > 0: - self.all_reduce_params(buffered_params, curr_buffer, curr_process_group, curr_world_size) + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group) From 8079646ba4ebdf3dd31b2284ee96a53aa98dfe8f Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 21:36:36 +0000 Subject: [PATCH 15/63] reduce moe by size 1 pg --- .../legacy_distributed_data_parallel.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index f5715adb14..f17dbd4df1 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -18,6 +18,7 @@ from contextlib import contextmanager import torch +import torch.distributed as dist from torch import nn from fairseq.distributed import utils @@ -54,7 +55,6 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.module = module self.process_group = process_group - self.world_size = utils.get_world_size(self.process_group) # Never use a bigger buffer than the number of model params self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) @@ -79,6 +79,15 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) + # assign local pg + my_rank = dist.get_rank() + world_size = dist.get_world_size() + for rank in range(world_size): + members = [rank] + pg = dist.new_group(members) + if rank == my_rank: + self.local_pg = pg + #start_pdb_on_rank_zero() #print('hi') @@ -142,22 +151,14 @@ def all_reduce_grads(self): if self.buffer is None: self.buffer = next(self.module.parameters()).new(self.buffer_size) - self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, self.world_size) - 
self._div_all_grads_by_worldsize(self.per_device_params_expert, self.world_size) - + # reduce normal params + self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group) + # reduce expert params + self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg) - def _div_all_grads_by_worldsize(self, current_params, curr_world_size): - for params in current_params: - for param in params: - if not param.requires_grad: - continue - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.data.div_(curr_world_size) - continue - def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): + def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group): + curr_world_size = dist.get_world_size(curr_process_group) for params in current_params: # All-reduce the gradients in buckets offset = 0 @@ -170,8 +171,6 @@ def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, cur else: param.grad.data.div_(curr_world_size) - - if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " From 2d41d29b8b3a78b279a22ec4ae022d6bdcb45abe Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Fri, 21 Apr 2023 21:51:39 +0000 Subject: [PATCH 16/63] always div by largest world size --- fairseq/distributed/legacy_distributed_data_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index f17dbd4df1..6a538ddd8d 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -152,13 +152,13 @@ def all_reduce_grads(self): self.buffer = next(self.module.parameters()).new(self.buffer_size) # reduce normal params - self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group) + curr_world_size = 
dist.get_world_size(self.process_group) + self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, curr_world_size) # reduce expert params - self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg) + self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) - def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group): - curr_world_size = dist.get_world_size(curr_process_group) + def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): for params in current_params: # All-reduce the gradients in buckets offset = 0 From 5c0b09f2f94f4e187b029ab7692a8952768c0192 Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Sat, 22 Apr 2023 00:52:54 +0000 Subject: [PATCH 17/63] set moe pg from torchscale --- .../legacy_distributed_data_parallel.py | 14 +++++++++----- fairseq_cli/train.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 6a538ddd8d..5f5048250c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -23,6 +23,7 @@ from fairseq.distributed import utils +from torchscale.component.xmoe.global_groups import get_moe_group def start_pdb_on_rank_zero(): rank = torch.distributed.get_rank() @@ -34,6 +35,9 @@ def start_pdb_on_rank_zero(): time.sleep(1e6) +NUM_EXPERTS = 4 + + class LegacyDistributedDataParallel(nn.Module): """Implements distributed data parallelism at the module level. 
@@ -82,11 +86,11 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # assign local pg my_rank = dist.get_rank() world_size = dist.get_world_size() - for rank in range(world_size): - members = [rank] - pg = dist.new_group(members) - if rank == my_rank: - self.local_pg = pg + + + assert world_size % NUM_EXPERTS == 0 or NUM_EXPERTS % world_size == 0 + assert hasattr(get_moe_group, "_moe_groups") # need to init groups first + _, self.local_pg = get_moe_group(NUM_EXPERTS) #start_pdb_on_rank_zero() #print('hi') diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 600473ce49..050ba38d39 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -90,8 +90,8 @@ def main(cfg: FairseqConfig) -> None: task = tasks.setup_task(cfg.task) assert cfg.criterion, "Please specify criterion to train a model" - if getattr(cfg.model, "moe_freq", 0) > 0 and getattr(cfg.model, "moe_expert_count", 0) < distributed_utils.get_global_world_size(): - assert cfg.distributed_training.ddp_backend == 'fully_sharded', 'num_experts < num_gpus only supported by FSDP' + #if getattr(cfg.model, "moe_freq", 0) > 0 and getattr(cfg.model, "moe_expert_count", 0) < distributed_utils.get_global_world_size(): + # assert cfg.distributed_training.ddp_backend == 'fully_sharded', 'num_experts < num_gpus only supported by FSDP' # Build model and criterion if cfg.distributed_training.ddp_backend == "fully_sharded": From c69ff2a4c654039b14bb58d89b0b3b95ef5f2549 Mon Sep 17 00:00:00 2001 From: johan bjorck Date: Mon, 24 Apr 2023 17:30:47 +0000 Subject: [PATCH 18/63] num experts optional --- fairseq/distributed/legacy_distributed_data_parallel.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 5f5048250c..d74f88ae9c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ 
-35,8 +35,6 @@ def start_pdb_on_rank_zero(): time.sleep(1e6) -NUM_EXPERTS = 4 - class LegacyDistributedDataParallel(nn.Module): """Implements distributed data parallelism at the module level. @@ -84,13 +82,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) # assign local pg - my_rank = dist.get_rank() - world_size = dist.get_world_size() - - - assert world_size % NUM_EXPERTS == 0 or NUM_EXPERTS % world_size == 0 assert hasattr(get_moe_group, "_moe_groups") # need to init groups first - _, self.local_pg = get_moe_group(NUM_EXPERTS) + _, self.local_pg = get_moe_group() #start_pdb_on_rank_zero() #print('hi') From ba856723a57179e6aa7524934ea4b097104f1dc1 Mon Sep 17 00:00:00 2001 From: Shuming Ma Date: Sun, 30 Apr 2023 06:39:56 -0700 Subject: [PATCH 19/63] fx import torchscale error --- fairseq/distributed/legacy_distributed_data_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index d74f88ae9c..906064bfdf 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -23,7 +23,10 @@ from fairseq.distributed import utils -from torchscale.component.xmoe.global_groups import get_moe_group +try: + from torchscale.component.xmoe.global_groups import get_moe_group +except ModuleNotFoundError: + get_moe_group = None def start_pdb_on_rank_zero(): rank = torch.distributed.get_rank() From 96aad9d2dcb2bd884193b862c719023830fcd0f7 Mon Sep 17 00:00:00 2001 From: Shaohan Huang Date: Thu, 4 May 2023 13:36:54 +0800 Subject: [PATCH 20/63] Update search.py --- fairseq/search.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairseq/search.py b/fairseq/search.py index d5ea68b4ce..5f716548e9 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -690,6 
+690,8 @@ def step( if self.sampling_topp > 0: # only sample from the smallest set of words whose cumulative probability mass exceeds p + if lprobs.max() < -50: + lprobs += (-50 - lprobs.max()) probs, top_indices = self._sample_topp(lprobs) elif self.sampling_topk > 0: # only sample from top-k candidates From a1c2c69de21eee001fde8cb76345d030a3fb2564 Mon Sep 17 00:00:00 2001 From: Shaohan Huang Date: Tue, 27 Jun 2023 14:33:17 +0800 Subject: [PATCH 21/63] Update trainer.py, remove sample length assert --- fairseq/trainer.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 6eec8766fd..0151162ee2 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -745,21 +745,21 @@ def train_step(self, samples, raise_oom=False): sample, is_dummy_batch = self._prepare_sample(sample) # MoE training with --batch-size or --max-sentences set - if self.is_moe and getattr(self.cfg.dataset, 'batch_size', None) is not None: - try: - fixed_src_seq_length = getattr(self.cfg.task, 'tokens_per_sample', None) or self.cfg.task.max_source_positions - assert sample['net_input']['src_tokens'].shape[1] == fixed_src_seq_length - except: - logger.warning(str(sample.keys())) - logger.warning(str(sample['net_input'].keys())) - logger.warning(is_dummy_batch) - logger.warning( - "wrong seq len {} on rank {}".format( - sample['net_input']['src_tokens'].shape[1], - torch.distributed.get_rank(), - ) - ) - raise AssertionError + # if self.is_moe and getattr(self.cfg.dataset, 'batch_size', None) is not None: + # try: + # fixed_src_seq_length = getattr(self.cfg.task, 'tokens_per_sample', None) or self.cfg.task.max_source_positions + # assert sample['net_input']['src_tokens'].shape[1] == fixed_src_seq_length + # except: + # logger.warning(str(sample.keys())) + # logger.warning(str(sample['net_input'].keys())) + # logger.warning(is_dummy_batch) + # logger.warning( + # "wrong seq len {} on rank {}".format( + # 
sample['net_input']['src_tokens'].shape[1], + # torch.distributed.get_rank(), + # ) + # ) + # raise AssertionError def maybe_no_sync(): """ From 291a26a83a40c62ea50987ddfa1242fcb946b484 Mon Sep 17 00:00:00 2001 From: Shaohan Huang Date: Thu, 24 Aug 2023 00:08:47 +0800 Subject: [PATCH 22/63] Update utils.py --- fairseq/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/utils.py b/fairseq/utils.py index a321c75129..ab0f7908e2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -555,6 +555,7 @@ def get_available_activation_fns() -> List: "gelu_accurate", "tanh", "linear", + "silu", ] From b169faa51a3198783f0fd7b499d70c689f03c4d7 Mon Sep 17 00:00:00 2001 From: Shaohan Huang Date: Mon, 23 Oct 2023 12:16:21 +0800 Subject: [PATCH 23/63] support torchrun --- fairseq/distributed/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 030de5f1b3..275d392df4 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -78,6 +78,7 @@ def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig): cfg.distributed_init_method = "env://" cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) cfg.distributed_rank = int(os.environ["RANK"]) + cfg.device_id = int(os.environ['LOCAL_RANK']) # support torchrun # processes are created by torch.distributed.launch cfg.distributed_no_spawn = True From c3e77b6ef01520424309ddc0c1a284dd83c780f1 Mon Sep 17 00:00:00 2001 From: Li Dong Date: Sat, 2 Dec 2023 16:00:52 +0800 Subject: [PATCH 24/63] dataloader: resume job with more #GPUs --- fairseq/trainer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0151162ee2..11a9e6036d 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -485,6 +485,7 @@ def load_checkpoint( reset_lr_scheduler=False, optimizer_overrides=None, reset_meters=False, + reset_dataloader=False, ): """ Load all training state from a 
checkpoint file. @@ -605,7 +606,13 @@ def load_checkpoint( itr_state = extra_state["train_iterator"] if type(itr_state) == list: # assert len(itr_state) == self.data_parallel_world_size - itr_state = itr_state[self.data_parallel_rank] + if len(itr_state) != self.data_parallel_world_size: + logger.info(f"Reload dataloader: #itr_state {len(itr_state)} != #dp_world_size {self.data_parallel_world_size}") + assert reset_dataloader, "please reset the dataloader (--reset-dataloader)" + # li: no worry, because --reset_dataloader, extra_state will be reset + itr_state = itr_state[self.data_parallel_rank % len(itr_state)] + else: + itr_state = itr_state[self.data_parallel_rank] extra_state["train_iterator"] = itr_state epoch = itr_state.get("epoch", 1) From 2905f8c7d5609b80b45d9756a25c66122391f084 Mon Sep 17 00:00:00 2001 From: Li Dong Date: Sat, 2 Dec 2023 16:02:00 +0800 Subject: [PATCH 25/63] Update checkpoint_utils.py --- fairseq/checkpoint_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 9d3735d364..6d4460c126 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -255,6 +255,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): reset_lr_scheduler, optimizer_overrides, reset_meters=reset_meters, + reset_dataloader=reset_dataloader, ) if ( From 960a0bfe6b811e2c3afd786b54abc04614f7f2b9 Mon Sep 17 00:00:00 2001 From: Shaohan Huang Date: Tue, 5 Dec 2023 10:19:20 +0800 Subject: [PATCH 26/63] fx megatron trainer load ckpt --- fairseq/model_parallel/megatron_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 4316ccfe1d..75581429cd 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -63,8 +63,9 @@ def load_checkpoint( reset_lr_scheduler=False, optimizer_overrides=None, 
reset_meters=False, + reset_dataloader=False, ): - extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters) + extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters, reset_dataloader=reset_dataloader) if extra_state is not None and 'rng_tracker_states' in extra_state: get_cuda_rng_tracker().set_states( extra_state['rng_tracker_states']) From 037244967ae294ec27281de92ce1570cacd682fb Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Wed, 29 May 2024 16:21:34 +0800 Subject: [PATCH 27/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index d6a6823aea..43936124a7 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -152,19 +152,7 @@ class MoECriterionConfig(FairseqDataclass): class MoECriterion(FairseqCriterion): - moe_logging_keys = [ - "overflow_expert1", # average % of overflowed tokens from 1st expert - "overflow_expert2", # average % of overflowed tokens from 2nd expert - "entropy_gating", # average entropy of the gating distribution - "expert1_balance_top", # average cumulative % of tokens processed by the most used 20% 1st experts - "expert1_balance_bottom", # average cumulative % of tokens processed by the least used 20% 1st experts - "unused_expert1_count", # average number of 1st experts which process no tokens - "expert2_balance_top", # average cumulative % of tokens processed by the most used 20% 2nd experts - "expert2_balance_bottom", # average cumulative % of tokens processed by the least used 20% 2nd experts - 
"unused_expert2_count", # average number of 2nd experts which process no tokens - "all_to_all_cpu_time_ms", # CPU time spent in all to all calls in milliseconds - "all_to_all_cuda_time_ms", # CUDA ttime spent in all to all calls in milliseconds - ] + moe_logging_keys = [] def __init__(self, task, moe_gate_loss_wt, moe_gate_loss_combine_method, moe_gate_loss_transform, sentence_avg): super().__init__(task) self.gate_loss_weight = moe_gate_loss_wt From c0173dc1ddad4e22d33094bb9f53d95fceb4de90 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Thu, 30 May 2024 21:35:48 +0800 Subject: [PATCH 28/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 906064bfdf..5f9dccc5c2 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -85,8 +85,9 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) # assign local pg - assert hasattr(get_moe_group, "_moe_groups") # need to init groups first - _, self.local_pg = get_moe_group() + ####### # by mega # ################ + # assert hasattr(get_moe_group, "_moe_groups") # need to init groups first + # _, self.local_pg = get_moe_group() #start_pdb_on_rank_zero() #print('hi') From 67b73ea10b2911bb47735f17679be0c42b3575bd Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Thu, 30 May 2024 22:06:31 +0800 Subject: [PATCH 29/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 5f9dccc5c2..d0daecfa77 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -86,8 +86,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # assign local pg ####### # by mega # ################ - # assert hasattr(get_moe_group, "_moe_groups") # need to init groups first - # _, self.local_pg = get_moe_group() + assert hasattr(get_moe_group, "_moe_groups") # need to init groups first + _, self.local_pg = get_moe_group() #start_pdb_on_rank_zero() #print('hi') From 6dc3b97e472e5bb23f95efb244e87aa08893379c Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Fri, 31 May 2024 00:43:27 +0800 Subject: [PATCH 30/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index d0daecfa77..5562af1bd5 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -157,6 +157,7 @@ def all_reduce_grads(self): self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, curr_world_size) # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) + print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From 8f63e4952d36b55c200645661c1756ed76dd0b81 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Fri, 31 May 2024 02:48:30 +0800 Subject: [PATCH 31/63] Update legacy_distributed_data_parallel.py --- 
fairseq/distributed/legacy_distributed_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 5562af1bd5..29b4785302 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -157,7 +157,7 @@ def all_reduce_grads(self): self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, curr_world_size) # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) - print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") + # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From 3ba1de29db1bce6096b100df6ee24885174148ab Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Fri, 31 May 2024 21:45:36 +0800 Subject: [PATCH 32/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 29b4785302..082bafa504 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -158,6 +158,7 @@ def all_reduce_grads(self): # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") + torch.cuda.empty_cache() def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From e9ada9c872855633bec689418e9a7f3e051ad9b7 Mon Sep 17 
00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 00:36:38 +0800 Subject: [PATCH 33/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 082bafa504..311af5ffc1 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -20,7 +20,7 @@ import torch import torch.distributed as dist from torch import nn - +import gc from fairseq.distributed import utils try: @@ -159,6 +159,7 @@ def all_reduce_grads(self): self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") torch.cuda.empty_cache() + gc.collect() def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From f788465bf2fc2f16af2c384cb4ddc088953a3f49 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 00:55:37 +0800 Subject: [PATCH 34/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 311af5ffc1..59b53c3f1c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -158,8 +158,6 @@ def all_reduce_grads(self): # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") - torch.cuda.empty_cache() - 
gc.collect() def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From ab5a1cdd04692f3447a347bceda6af47ae702ed4 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 01:54:16 +0800 Subject: [PATCH 35/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 43936124a7..bb32d18aca 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -184,6 +184,7 @@ def compute_loss(self, model, sample, reduce=True): if l_aux is not None: gate_loss += l_aux gate_count += 1 + assert 0 if self.gate_loss_combine_method == "average": gate_loss = gate_loss / gate_count if self.gate_loss_transform == "neg_log": From a7e65fa27c8583a6af79e4e27b008c1de6e00e95 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 02:25:23 +0800 Subject: [PATCH 36/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 59 +++++++++++++++++++------ 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index bb32d18aca..bab9a634a4 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -18,6 +18,10 @@ from omegaconf import II from fairseq.modules.moe import MOELayer +from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss +from megablocks.layers.arguments import Arguments as dMoEArgs +from megablocks.layers.arguments import from_megatron + class FairseqCriterion(_Loss): def __init__(self, task): super().__init__() @@ -160,6 +164,30 @@ def __init__(self, task, moe_gate_loss_wt, moe_gate_loss_combine_method, moe_gat self.gate_loss_transform = moe_gate_loss_transform 
self.sentence_avg = sentence_avg + #### dmoe args + if args.moe_top1_expert: + moe_top_k = 1 + else: + moe_top_k = 2 + init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) + self.dmoe_args = dMoEArgs( + hidden_size=args.decoder_embed_dim, + ffn_hidden_size=args.ffn_dim, + moe_num_experts=args.decoder_ffn_embed_dim, + moe_capacity_factor=1.25, + moe_top_k=moe_top_k, + init_method=init_method, + moe_expert_model_parallelism=True, + memory_optimized_mlp=True, + expert_parallel_group=None, + device=torch.cuda.current_device(), + mlp_type='mlp', + mlp_impl='sparse', + fp16=True, + bf16=False, + moe_loss_weight=self.gate_loss_weight, + ) + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -178,19 +206,24 @@ def forward(self, model, sample, reduce=True): def compute_loss(self, model, sample, reduce=True): net_output, inner_loss, sample_size, logging_output = self.compute_inner_loss(model, sample) - gate_loss = 0.0 - gate_count = 0 - for l_aux in net_output[1]["l_aux"]: - if l_aux is not None: - gate_loss += l_aux - gate_count += 1 - assert 0 - if self.gate_loss_combine_method == "average": - gate_loss = gate_loss / gate_count - if self.gate_loss_transform == "neg_log": - gate_loss = - torch.log(gate_loss) - gate_loss = sample_size * gate_loss - loss = inner_loss + self.gate_loss_weight * gate_loss + + # gate_loss = 0.0 + # gate_count = 0 + # for l_aux in net_output[1]["l_aux"]: + # if l_aux is not None: + # gate_loss += l_aux + # gate_count += 1 + + # if self.gate_loss_combine_method == "average": + # gate_loss = gate_loss / gate_count + # if self.gate_loss_transform == "neg_log": + # gate_loss = - torch.log(gate_loss) + + total_load_balancing_loss = batched_load_balancing_loss(self.dmoe_args) + clear_load_balancing_loss() + + gate_loss = sample_size * total_load_balancing_loss + loss = inner_loss + gate_loss return loss, inner_loss, gate_loss, self.get_moe_metadata(model), sample_size, logging_output def 
compute_inner_loss(self, model, sample): From 7666f2ba7c78f7d9e0d656baa5d0438b49114316 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 02:52:13 +0800 Subject: [PATCH 37/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 47 +++++++++++++------------ 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index bab9a634a4..f1601db154 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -165,28 +165,28 @@ def __init__(self, task, moe_gate_loss_wt, moe_gate_loss_combine_method, moe_gat self.sentence_avg = sentence_avg #### dmoe args - if args.moe_top1_expert: - moe_top_k = 1 - else: - moe_top_k = 2 - init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - self.dmoe_args = dMoEArgs( - hidden_size=args.decoder_embed_dim, - ffn_hidden_size=args.ffn_dim, - moe_num_experts=args.decoder_ffn_embed_dim, - moe_capacity_factor=1.25, - moe_top_k=moe_top_k, - init_method=init_method, - moe_expert_model_parallelism=True, - memory_optimized_mlp=True, - expert_parallel_group=None, - device=torch.cuda.current_device(), - mlp_type='mlp', - mlp_impl='sparse', - fp16=True, - bf16=False, - moe_loss_weight=self.gate_loss_weight, - ) + # if args.moe_top1_expert: + # moe_top_k = 1 + # else: + # moe_top_k = 2 + # init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) + # self.dmoe_args = dMoEArgs( + # hidden_size=args.decoder_embed_dim, + # ffn_hidden_size=args.ffn_dim, + # moe_num_experts=args.decoder_ffn_embed_dim, + # moe_capacity_factor=1.25, + # moe_top_k=moe_top_k, + # init_method=init_method, + # moe_expert_model_parallelism=True, + # memory_optimized_mlp=True, + # expert_parallel_group=None, + # device=torch.cuda.current_device(), + # mlp_type='mlp', + # mlp_impl='sparse', + # fp16=True, + # bf16=False, + # moe_loss_weight=self.gate_loss_weight, + 
# ) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -219,7 +219,8 @@ def compute_loss(self, model, sample, reduce=True): # if self.gate_loss_transform == "neg_log": # gate_loss = - torch.log(gate_loss) - total_load_balancing_loss = batched_load_balancing_loss(self.dmoe_args) + dmoe_args = net_output[1]["dmoe_args_list"][0] + total_load_balancing_loss = batched_load_balancing_loss(dmoe_args) clear_load_balancing_loss() gate_loss = sample_size * total_load_balancing_loss From d056ae578f0da478504e4e15f43ef8dc7661d2ee Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 03:22:57 +0800 Subject: [PATCH 38/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index f1601db154..15e39acd49 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -18,7 +18,7 @@ from omegaconf import II from fairseq.modules.moe import MOELayer -from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss +from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss, average_losses_across_data_parallel_group from megablocks.layers.arguments import Arguments as dMoEArgs from megablocks.layers.arguments import from_megatron @@ -222,6 +222,8 @@ def compute_loss(self, model, sample, reduce=True): dmoe_args = net_output[1]["dmoe_args_list"][0] total_load_balancing_loss = batched_load_balancing_loss(dmoe_args) clear_load_balancing_loss() + + total_load_balancing_loss = average_losses_across_data_parallel_group([total_load_balancing_loss]) gate_loss = sample_size * total_load_balancing_loss loss = inner_loss + gate_loss From 793a5c4ee3e95752082627e89acdaf9af28a326f Mon Sep 17 00:00:00 2001 From: Xun Wu 
<138114252+yushuiwx@users.noreply.github.com> Date: Sat, 1 Jun 2024 03:30:32 +0800 Subject: [PATCH 39/63] Update fairseq_criterion.py --- fairseq/criterions/fairseq_criterion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index 15e39acd49..4dcf930394 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -223,7 +223,7 @@ def compute_loss(self, model, sample, reduce=True): total_load_balancing_loss = batched_load_balancing_loss(dmoe_args) clear_load_balancing_loss() - total_load_balancing_loss = average_losses_across_data_parallel_group([total_load_balancing_loss]) + # total_load_balancing_loss = average_losses_across_data_parallel_group([total_load_balancing_loss]) gate_loss = sample_size * total_load_balancing_loss loss = inner_loss + gate_loss From a25d55eb764f45a500fbc73583ff813e4426e6f0 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 00:50:38 +0800 Subject: [PATCH 40/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 59b53c3f1c..24cdb109f3 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -158,6 +158,7 @@ def all_reduce_grads(self): # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") + print("self.per_device_params_expert", self.per_device_params_expert) def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From 758d10151d10a6d7630d0e5f481a75e567cf829e Mon 
Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 00:59:02 +0800 Subject: [PATCH 41/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 24cdb109f3..3156afa26c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -79,8 +79,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # split into expert and normal params per_device_params = list(paramlists.values()) - self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] - self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] + self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] + self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) From 9a74df76e04353a8284e5be50dbdcff7612e6144 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:03:28 +0800 Subject: [PATCH 42/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 3156afa26c..eca0bc2a80 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -72,6 +72,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # make 
per-device lists of parameters paramlists = OrderedDict() for param in self.module.parameters(): + print(param) device = param.device if paramlists.get(device) is None: paramlists[device] = [] @@ -81,7 +82,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): per_device_params = list(paramlists.values()) self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] - + print("self.per_device_params_expert", self.per_device_params_expert) + assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) # assign local pg @@ -158,7 +160,6 @@ def all_reduce_grads(self): # reduce expert params self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") - print("self.per_device_params_expert", self.per_device_params_expert) def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From c9446936146e1c8c90156334e46b9ded381ee171 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:09:42 +0800 Subject: [PATCH 43/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index eca0bc2a80..ad95f00b5d 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -71,8 +71,9 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # make per-device lists of parameters paramlists = OrderedDict() + print("module", module) for param in self.module.parameters(): - 
print(param) + # print(param) device = param.device if paramlists.get(device) is None: paramlists[device] = [] From bbe6b56b0d0f614dfc08a833a55c6051c29f4f5a Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:38:25 +0800 Subject: [PATCH 44/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index ad95f00b5d..fa03acdcd3 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -84,6 +84,9 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] print("self.per_device_params_expert", self.per_device_params_expert) + for t in per_device_params: + for k in t: + print(k) assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) From 3d9507e62584faf4481b8eaac327f0df8a7f437a Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:49:08 +0800 Subject: [PATCH 45/63] Update legacy_distributed_data_parallel.py --- .../legacy_distributed_data_parallel.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index fa03acdcd3..e906f9886c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -71,22 +71,37 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # make per-device lists of 
parameters paramlists = OrderedDict() - print("module", module) - for param in self.module.parameters(): - # print(param) + + + + # for param in self.module.parameters(): + # device = param.device + # if paramlists.get(device) is None: + # paramlists[device] = [] + # paramlists[device] += [param] + + # # split into expert and normal params + # per_device_params = list(paramlists.values()) + # self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] + # self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] + + + # 按设备分组参数,并保留参数名称 + for name, param in self.module.named_parameters(): device = param.device if paramlists.get(device) is None: paramlists[device] = [] - paramlists[device] += [param] - - # split into expert and normal params + paramlists[device].append((name, param)) + + # 获取按设备分组的参数列表 per_device_params = list(paramlists.values()) - self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] - self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] + + # 分割专家参数和普通参数 + self.per_device_params_normal = [[param for name, param in t if 'experts' not in name] for t in per_device_params] + self.per_device_params_expert = [[param for name, param in t if 'experts' in name] for t in per_device_params] + + # 打印专家参数(可选) print("self.per_device_params_expert", self.per_device_params_expert) - for t in per_device_params: - for k in t: - print(k) assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) From 14533babb986c32fafe490df0a65df80daf762c4 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 05:27:45 +0800 Subject: [PATCH 46/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff 
--git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index e906f9886c..bd656d51e3 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -101,8 +101,11 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_expert = [[param for name, param in t if 'experts' in name] for t in per_device_params] # 打印专家参数(可选) - print("self.per_device_params_expert", self.per_device_params_expert) - + print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) + for t in self.per_device_params_expert: + for name, param in t: + print(name) + assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) # assign local pg From f3ee1f5c2476e37c9fdd0444eb67334ea9f9e040 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 06:13:28 +0800 Subject: [PATCH 47/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index bd656d51e3..127b021eab 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -95,6 +95,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # 获取按设备分组的参数列表 per_device_params = list(paramlists.values()) + print("len(per_device_params)", len(per_device_params)) # 分割专家参数和普通参数 self.per_device_params_normal = [[param for name, param in t if 'experts' not in name] for t in per_device_params] From cabf010ba8496b5ab8f4fc32fbd27068f52ce49b Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 
Jun 2024 06:19:25 +0800 Subject: [PATCH 48/63] Update legacy_distributed_data_parallel.py --- .../distributed/legacy_distributed_data_parallel.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 127b021eab..681ca500b0 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -74,12 +74,13 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): - # for param in self.module.parameters(): - # device = param.device - # if paramlists.get(device) is None: - # paramlists[device] = [] - # paramlists[device] += [param] - + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + print("len(per_device_params)", len(per_device_params)) + # # split into expert and normal params # per_device_params = list(paramlists.values()) # self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] From 7063aaab49d47e6a634dc17d82f8a55c81fcd210 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 06:21:51 +0800 Subject: [PATCH 49/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 681ca500b0..f41b95ea78 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -79,10 +79,12 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): if paramlists.get(device) is None: paramlists[device] = [] paramlists[device] += [param] + + # split 
into expert and normal params + per_device_params = list(paramlists.values()) + print("len(per_device_params)", len(per_device_params)) - # # split into expert and normal params - # per_device_params = list(paramlists.values()) # self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] # self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] From 12bdff7a6b84f58cc648d184d45831839555aa8e Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 06:43:36 +0800 Subject: [PATCH 50/63] Update legacy_distributed_data_parallel.py --- .../legacy_distributed_data_parallel.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index f41b95ea78..9d59b8a234 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -74,16 +74,16 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): - for param in self.module.parameters(): - device = param.device - if paramlists.get(device) is None: - paramlists[device] = [] - paramlists[device] += [param] - - # split into expert and normal params - per_device_params = list(paramlists.values()) + # for param in self.module.parameters(): + # device = param.device + # if paramlists.get(device) is None: + # paramlists[device] = [] + # paramlists[device] += [param] + + # # split into expert and normal params + # per_device_params = list(paramlists.values()) - print("len(per_device_params)", len(per_device_params)) + # print("len(per_device_params)", len(per_device_params)) # self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] # self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in 
per_device_params] From b44bd89d9db99cecb0d25875d2c75ff254391db7 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 06:44:36 +0800 Subject: [PATCH 51/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 9d59b8a234..8dbbbe4dc8 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -107,8 +107,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # 打印专家参数(可选) print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) for t in self.per_device_params_expert: - for name, param in t: - print(name) + print(t) assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) From 1ac4823d0b86392dbb6bbfd111e860c436100482 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 07:08:30 +0800 Subject: [PATCH 52/63] Update legacy_distributed_data_parallel.py --- .../legacy_distributed_data_parallel.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 8dbbbe4dc8..4dbe2be2ee 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -74,40 +74,40 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): - # for param in self.module.parameters(): - # device = param.device - # if paramlists.get(device) is None: - # paramlists[device] = [] - # paramlists[device] += [param] + for param in 
self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] - # # split into expert and normal params - # per_device_params = list(paramlists.values()) + # split into expert and normal params + per_device_params = list(paramlists.values()) - # print("len(per_device_params)", len(per_device_params)) + print("len(per_device_params)", len(per_device_params)) - # self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] - # self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] + self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] + self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] - # 按设备分组参数,并保留参数名称 - for name, param in self.module.named_parameters(): - device = param.device - if paramlists.get(device) is None: - paramlists[device] = [] - paramlists[device].append((name, param)) + # # 按设备分组参数,并保留参数名称 + # for name, param in self.module.named_parameters(): + # device = param.device + # if paramlists.get(device) is None: + # paramlists[device] = [] + # paramlists[device].append((name, param)) - # 获取按设备分组的参数列表 - per_device_params = list(paramlists.values()) - print("len(per_device_params)", len(per_device_params)) + # # 获取按设备分组的参数列表 + # per_device_params = list(paramlists.values()) + # print("len(per_device_params)", len(per_device_params)) - # 分割专家参数和普通参数 - self.per_device_params_normal = [[param for name, param in t if 'experts' not in name] for t in per_device_params] - self.per_device_params_expert = [[param for name, param in t if 'experts' in name] for t in per_device_params] + # # 分割专家参数和普通参数 + # self.per_device_params_normal = [[param for name, param in t if 'experts' not in name] for t in per_device_params] + # self.per_device_params_expert = [[param for name, param in t if 'experts' 
in name] for t in per_device_params] - # 打印专家参数(可选) - print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) - for t in self.per_device_params_expert: - print(t) + # # 打印专家参数(可选) + # print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) + # for t in self.per_device_params_expert: + # print(t) assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) From 6359b7dc13db61c8126f0764be2df55938e4f0b0 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 07:12:20 +0800 Subject: [PATCH 53/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 4dbe2be2ee..338a78428a 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -87,7 +87,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] - + print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) # # 按设备分组参数,并保留参数名称 # for name, param in self.module.named_parameters(): From 58912a22c1add4061cdd78c0a1dbac5c89c5d8a2 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 21:39:16 +0800 Subject: [PATCH 54/63] Update legacy_distributed_data_parallel.py --- 
fairseq/distributed/legacy_distributed_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 338a78428a..242c519bfa 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -87,7 +87,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] - print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) + print("self.per_device_params_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), "self.per_device_params_normal", len(self.per_device_params_normal), len(self.per_device_params_normal[0])) # # 按设备分组参数,并保留参数名称 # for name, param in self.module.named_parameters(): From ee270edefa48ffa20fe6ee3de1a34acb39c5ddfc Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 21:48:08 +0800 Subject: [PATCH 55/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 242c519bfa..a9e0ca3575 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -85,8 +85,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): print("len(per_device_params)", len(per_device_params)) - self.per_device_params_normal = [[k for k in t if not hasattr(k, 'experts')] for t in 
per_device_params] - self.per_device_params_expert = [[k for k in t if hasattr(k, 'experts')] for t in per_device_params] + self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] + self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] print("self.per_device_params_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), "self.per_device_params_normal", len(self.per_device_params_normal), len(self.per_device_params_normal[0])) # # 按设备分组参数,并保留参数名称 From f921d445db2fb9b8ff3d4b22f1dad084b70eaded Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 21:48:32 +0800 Subject: [PATCH 56/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index a9e0ca3575..f510f6a8eb 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -83,7 +83,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # split into expert and normal params per_device_params = list(paramlists.values()) - print("len(per_device_params)", len(per_device_params)) + print("len(per_device_params)", len(per_device_params), len(per_device_params[0])) self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] From 9a78abe4bc2ba3103aaff3a0e4fc569106adb97f Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 22:42:04 +0800 Subject: [PATCH 57/63] Update legacy_distributed_data_parallel.py --- .../legacy_distributed_data_parallel.py | 
60 ++++++++++++------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index f510f6a8eb..e7efc73ec3 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -88,26 +88,6 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): self.per_device_params_normal = [[k for k in t if not hasattr(k, 'expert')] for t in per_device_params] self.per_device_params_expert = [[k for k in t if hasattr(k, 'expert')] for t in per_device_params] print("self.per_device_params_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), "self.per_device_params_normal", len(self.per_device_params_normal), len(self.per_device_params_normal[0])) - - # # 按设备分组参数,并保留参数名称 - # for name, param in self.module.named_parameters(): - # device = param.device - # if paramlists.get(device) is None: - # paramlists[device] = [] - # paramlists[device].append((name, param)) - - # # 获取按设备分组的参数列表 - # per_device_params = list(paramlists.values()) - # print("len(per_device_params)", len(per_device_params)) - - # # 分割专家参数和普通参数 - # self.per_device_params_normal = [[param for name, param in t if 'experts' not in name] for t in per_device_params] - # self.per_device_params_expert = [[param for name, param in t if 'experts' in name] for t in per_device_params] - - # # 打印专家参数(可选) - # print("self.per_device_params_expert", len(self.per_device_params_expert), "self.per_device_params_normal", len(self.per_device_params_normal)) - # for t in self.per_device_params_expert: - # print(t) assert all([len([k for k in t if hasattr(k, 'base_expert')]) == 0 for t in per_device_params]) @@ -183,8 +163,9 @@ def all_reduce_grads(self): curr_world_size = dist.get_world_size(self.process_group) self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, curr_world_size) 
# reduce expert params - self._all_reduce_grads(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) - # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") + print("begin _all_reduce_grads_for_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), self.buffer, self.local_pg, curr_world_size) + self._all_reduce_grads_for_expert(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) + print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): @@ -220,3 +201,38 @@ def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, cur if len(buffered_params) > 0: self.all_reduce_params(buffered_params, curr_buffer, curr_process_group) + + + def _all_reduce_grads_for_expert(self, current_params, curr_buffer, curr_process_group, curr_world_size): + for params in current_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.data.div_(curr_world_size) + + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > curr_buffer.numel(): + # all-reduce big params directly + self.all_reduce_params([param], curr_buffer, curr_process_group) + else: + if offset + sz > curr_buffer.numel(): + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz + + if len(buffered_params) > 0: + self.all_reduce_params(buffered_params, curr_buffer, curr_process_group) From 802d6e9e459b7963c34ab26ee1439a6e932a1821 Mon Sep 17 00:00:00 2001 From: Xun Wu 
<138114252+yushuiwx@users.noreply.github.com> Date: Sun, 16 Jun 2024 23:19:55 +0800 Subject: [PATCH 58/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index e7efc73ec3..9524a47879 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -163,9 +163,9 @@ def all_reduce_grads(self): curr_world_size = dist.get_world_size(self.process_group) self._all_reduce_grads(self.per_device_params_normal, self.buffer, self.process_group, curr_world_size) # reduce expert params - print("begin _all_reduce_grads_for_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), self.buffer, self.local_pg, curr_world_size) + # print("begin _all_reduce_grads_for_expert", len(self.per_device_params_expert), len(self.per_device_params_expert[0]), self.buffer, self.local_pg, curr_world_size) self._all_reduce_grads_for_expert(self.per_device_params_expert, self.buffer, self.local_pg, curr_world_size) - print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") + # print("curr_world_size", curr_world_size, "_all_reduce_grads for experts finished.") def _all_reduce_grads(self, current_params, curr_buffer, curr_process_group, curr_world_size): From ddfc38b293843bf22b6f177cff3b2b9f4c4f2049 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Mon, 17 Jun 2024 23:23:27 +0800 Subject: [PATCH 59/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 
9524a47879..07b11f90a5 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -210,12 +210,14 @@ def _all_reduce_grads_for_expert(self, current_params, curr_buffer, curr_process buffered_params = [] for param in params: if not param.requires_grad: + assert 0 continue if param.grad is None: + assert 0 param.grad = torch.zeros_like(param) else: param.grad.data.div_(curr_world_size) - + print(param.grad) if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " From 89e778f4b05355d7a95f753cb842e6a18c43cfc1 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Mon, 17 Jun 2024 23:33:01 +0800 Subject: [PATCH 60/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 07b11f90a5..3ed245ce79 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -213,6 +213,7 @@ def _all_reduce_grads_for_expert(self, current_params, curr_buffer, curr_process assert 0 continue if param.grad is None: + assert param.requires_grad assert 0 param.grad = torch.zeros_like(param) else: From 495cd079dcf401b6b807e2c385f0efd340fe30f2 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Tue, 18 Jun 2024 08:58:39 +0800 Subject: [PATCH 61/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 3ed245ce79..c9472b385b 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ 
b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -218,7 +218,7 @@ def _all_reduce_grads_for_expert(self, current_params, curr_buffer, curr_process param.grad = torch.zeros_like(param) else: param.grad.data.div_(curr_world_size) - print(param.grad) + print(param.grad.sum()) if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " From f513b25ff89550105ee612382998314ce68f6f11 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Wed, 19 Jun 2024 05:15:06 +0800 Subject: [PATCH 62/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index c9472b385b..431f58bf44 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -93,8 +93,8 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): # assign local pg ####### # by mega # ################ - assert hasattr(get_moe_group, "_moe_groups") # need to init groups first - _, self.local_pg = get_moe_group() + # assert hasattr(get_moe_group, "_moe_groups") # need to init groups first + # _, self.local_pg = get_moe_group() #start_pdb_on_rank_zero() #print('hi') From ba4d46904606dc169ece2ab61a6a0f52168f40b7 Mon Sep 17 00:00:00 2001 From: Xun Wu <138114252+yushuiwx@users.noreply.github.com> Date: Wed, 19 Jun 2024 05:22:33 +0800 Subject: [PATCH 63/63] Update legacy_distributed_data_parallel.py --- fairseq/distributed/legacy_distributed_data_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index 431f58bf44..cc1b59d271 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ 
b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -95,6 +95,7 @@ def __init__(self, module, process_group, buffer_size=2 ** 28): ####### # by mega # ################ # assert hasattr(get_moe_group, "_moe_groups") # need to init groups first # _, self.local_pg = get_moe_group() + self.local_pg = None #start_pdb_on_rank_zero() #print('hi')