From 4c7cb114d17127140bc17e7cbe6a91905e879b6a Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 4 Jun 2024 11:22:29 -0400 Subject: [PATCH 01/28] inital tp commits --- megatron/model/gpt2_model.py | 1 + megatron/model/rwkv/v6/rwkv.py | 60 +++++++++++++++++++++------- megatron/neox_arguments/arguments.py | 8 +--- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9e643874a..e866ffb40 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -258,6 +258,7 @@ def init_specs(self): LayerSpec( RWKVResidualLayerPipe, neox_args=self.neox_args, + init_method=self.init_method, layer_number=i, ) ) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 5d4e0d144..eaeec4ad7 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch.nn import functional as F from torch.utils.cpp_extension import load - +from megatron import mpu class WKV(torch.autograd.Function): """ @@ -179,7 +179,7 @@ def __init__(self, neox_args, layer_number): self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) - self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -228,15 +228,19 @@ def forward(self, x): return self.jit_func_2(x, g) -class RWKV_ChannelMix(nn.Module): +class ParallelRWKV_ChannelMix(nn.Module): """ Channel Mix layer. The ffn in RWKV """ - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number + + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) with torch.no_grad(): # fancy init of time_mix @@ -247,21 +251,46 @@ def __init__(self, neox_args, layer_number): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) - self.receptance = nn.Linear( - neox_args.hidden_size, neox_args.hidden_size, bias=False - ) - self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) - + #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.key = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_ffn, + gather_output=False, + init_method=init_method, + bias=False, + ) + #self.receptance = nn.Linear( + # neox_args.hidden_size, neox_args.hidden_size, bias=False + #) + self.receptance = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size, + gather_output=True, + init_method=init_method, + bias=False + ) + #self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + self.value = mpu.RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.dim_ffn, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=init_method, + parallel_output=False, + bias=False + ) def forward(self, x): xx = 
self.time_shift(x) - x xk = x + xx * self.time_maa_k xr = x + xx * self.time_maa_r - k = self.key(xk) + k, _ = self.key(xk) k = torch.relu(k) ** 2 - kv = self.value(k) - return torch.sigmoid(self.receptance(xr)) * kv + kv, _ = self.value(k) + receptance, _ = self.receptance(xr) + return torch.sigmoid(receptance) * kv class RWKVResidualLayer(nn.Module): @@ -269,7 +298,7 @@ class RWKVResidualLayer(nn.Module): RWKV layer definition """ - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args self.layer_number = layer_number @@ -288,6 +317,7 @@ def __init__(self, neox_args, layer_number): self.num_attention_heads = neox_args.num_attention_heads assert neox_args.dim_att % self.num_attention_heads == 0 + self.init_method = init_method if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p=neox_args.attention_dropout) @@ -296,7 +326,7 @@ def __init__(self, neox_args, layer_number): self.att = RWKV_TimeMix(neox_args, layer_number) - self.ffn = RWKV_ChannelMix(neox_args, layer_number) + self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method) if neox_args.attention_dropout > 0: self.drop0 = nn.Dropout(p=neox_args.attention_dropout) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index ff4f4bc21..3dda6489d 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1066,17 +1066,11 @@ def calculate_derived(self): if isinstance(self.zero_stage, int): assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba" assert ( - self.hidden_dropout == 0.0, + self.hidden_dropout != 0.0, ), "Mamba does not yet have dropout implemented" if "rwkv" in self.attention_config: - assert ( - not self.is_pipe_parallel and self.model_parallel_size == 1 - ), "RWKV not currently compatible with parallelism" if isinstance(self.zero_stage, int): assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV" - assert ( - self.hidden_dropout == 0.0, - ), "RWKV does not yet have dropout implemented" # Sparsity config if self.sparsity_config is None: From 46904d5a2cdf3579ae09d499a2324db316388148 Mon Sep 17 00:00:00 2001 From: jahatef Date: Wed, 19 Jun 2024 21:15:05 +0000 Subject: [PATCH 02/28] setup --- configs/local_setup.yml | 4 ++++ megatron/model/rwkv/v6/rwkv.py | 1 + 2 files changed, 5 insertions(+) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index d031a2ad8..3bf17ca3d 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -22,6 +22,10 @@ "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, + + # "launcher": "openmpi", + #"deepspeed_mpi": true, + "tensorboard_dir": "tensorboard", "log_dir": "logs", "use_wandb": True, diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index eaeec4ad7..b2a261842 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -223,6 +223,7 @@ def forward(self, x): H = self.neox_args.num_attention_heads r, k, v, g, w = self.jit_func(x) + print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) return self.jit_func_2(x, g) From e2933ef060fcf1d041137018e6f12b7b2b140a4b Mon Sep 17 00:00:00 2001 From: jahatef Date: Wed, 25 Sep 2024 17:23:24 +0000 Subject: [PATCH 03/28] configs --- configs/rwkv/1.5B.yml | 103 
++++++++++++++++++++++++++++++++++++++++++ configs/rwkv/7B.yml | 102 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 configs/rwkv/1.5B.yml create mode 100644 configs/rwkv/7B.yml diff --git a/configs/rwkv/1.5B.yml b/configs/rwkv/1.5B.yml new file mode 100644 index 000000000..0d97a7861 --- /dev/null +++ b/configs/rwkv/1.5B.yml @@ -0,0 +1,103 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 2, + + "num_layers": 24, + "hidden_size": 2048, + "num_attention_heads": 32, # head_size = dim_att / num_attention_heads. + # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 1, + + "attention_config": [[["rwkv"], 24]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 1, + "lr_decay_iters": 1, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + "seed": 1234, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/configs/rwkv/7B.yml b/configs/rwkv/7B.yml new file mode 100644 index 000000000..7e999d250 --- /dev/null +++ b/configs/rwkv/7B.yml @@ -0,0 +1,102 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 64, # head_size = dim_att / num_attention_heads. 
+ # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 8, + + "attention_config": [[["rwkv"], 32]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. training settings + "train_iters": 500, + "lr_decay_iters": 500, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} From 43d641da86f8a32a7ce9baca55c85bf4520e8216 Mon Sep 17 00:00:00 2001 From: jahatef Date: Thu, 3 Oct 2024 17:58:21 +0000 Subject: [PATCH 04/28] time mixing tp --- megatron/model/rwkv/v6/rwkv.py | 82 +++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 15 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 8d039f341..923de0835 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -8,6 +8,7 @@ from torch.nn import functional as F from torch.utils.cpp_extension import load from megatron import mpu +from mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region class WKV(torch.autograd.Function): """ @@ -104,7 +105,7 @@ class RWKV_TimeMix(nn.Module): TODO: fix jit compiling. 
""" - def __init__(self, neox_args, layer_number): + def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number @@ -172,14 +173,62 @@ def __init__(self, neox_args, layer_number): ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.receptance = nn.Linear( - neox_args.hidden_size, neox_args.dim_att, bias=False - ) - self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - - self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) - self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) - self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column + #self.receptance = nn.Linear( + # neox_args.hidden_size, neox_args.dim_att, bias=False + #) + self.receptance = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.key = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + #self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.value = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + #self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) + self.output = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.dim_att, + output_size=neox_args.hidden_size, + gather_output=False, + init_method=init_method, + bias=False, + ) + #self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column + self.gate = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + gather_output=False, + init_method=init_method, + bias=False, + ) + self.gate = mpu.RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.dim_att, + input_is_parallel=True, + init_method=init_method, + parallel_output=False, + bias=False + ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -200,10 +249,10 @@ def jit_func(self, x): xr = x + xx * (self.time_maa_r + mr) xg = x + xx * (self.time_maa_g + mg) - r = self.receptance(xr) - k = self.key(xk) - v = self.value(xv) - g = F.silu(self.gate(xg)) + r, _ = self.receptance(xr) + k, _ = self.key(xk) + v, _ = self.value(xv) + g, _ = F.silu(self.gate(xg)) ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww @@ -215,7 +264,9 @@ def jit_func_2(self, x, g): x = x.view(B * T, C) x = self.ln_x(x).view(B, T, C) - x = self.output(x * g) + print(f"shape of x: {x.size()}, shape of g: {g.size()}") + x, _ = self.output(x * g) + print(f"new shape of x: {x.size()}") return x def forward(self, x): @@ -225,6 +276,7 @@ def forward(self, x): r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) + x = reduce_from_model_parallel_region(x) return self.jit_func_2(x, g) @@ -334,7 +386,7 @@ 
def __init__(self, neox_args, init_method, layer_number): self.ln1 = nn.LayerNorm(neox_args.hidden_size) self.ln2 = nn.LayerNorm(neox_args.hidden_size) - self.att = RWKV_TimeMix(neox_args, layer_number) + self.att = RWKV_TimeMix(neox_args, layer_number, init_method=init_method) self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method) From de02f37b69cec27b6042bdfcaebce41ecb5cc7cc Mon Sep 17 00:00:00 2001 From: jahatef Date: Fri, 11 Oct 2024 02:21:30 +0000 Subject: [PATCH 05/28] time-mixing --- megatron/model/rwkv/v6/rwkv.py | 46 +++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 923de0835..e96f1bd2b 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -8,7 +8,7 @@ from torch.nn import functional as F from torch.utils.cpp_extension import load from megatron import mpu -from mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region +from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region class WKV(torch.autograd.Function): """ @@ -207,7 +207,7 @@ def __init__(self, neox_args, layer_number, init_method): neox_args=neox_args, input_size=neox_args.dim_att, output_size=neox_args.hidden_size, - gather_output=False, + gather_output=True, init_method=init_method, bias=False, ) @@ -216,19 +216,19 @@ def __init__(self, neox_args, layer_number, init_method): neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.dim_att, - gather_output=False, + gather_output=True, init_method=init_method, bias=False, ) - self.gate = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.dim_att, - input_is_parallel=True, - init_method=init_method, - parallel_output=False, - bias=False - ) + #self.gate = mpu.RowParallelLinear( + # neox_args=neox_args, + # input_size=neox_args.hidden_size, + # output_size=neox_args.dim_att, + # input_is_parallel=True, + # init_method=init_method, + # parallel_output=False, + # bias=False + # ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -252,7 +252,8 @@ def jit_func(self, x): r, _ = self.receptance(xr) k, _ = self.key(xk) v, _ = self.value(xv) - g, _ = F.silu(self.gate(xg)) + gated, _ = self.gate(xg) + g = F.silu(gated) ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww @@ -271,12 +272,15 @@ def jit_func_2(self, x, g): def forward(self, x): B, T, C = x.size() + C_tp = C//mpu.get_model_parallel_world_size() H = self.neox_args.num_attention_heads + H_tp = H//mpu.get_model_parallel_world_size() r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") - x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) - x = reduce_from_model_parallel_region(x) + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa) + x = gather_from_model_parallel_region(x) + print(f"size of x after kernel: {x.size()}") return self.jit_func_2(x, g) @@ -304,11 +308,11 @@ def __init__(self, neox_args, layer_number, init_method): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + #self.key = 
nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, - output_size=neox_args.dim_ffn, + output_size=neox_args.ffn_dim, gather_output=False, init_method=init_method, bias=False, @@ -324,10 +328,10 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False ) - #self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + #self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) self.value = mpu.RowParallelLinear( neox_args=neox_args, - input_size=neox_args.dim_ffn, + input_size=neox_args.ffn_dim, output_size=neox_args.hidden_size, input_is_parallel=True, init_method=init_method, @@ -350,7 +354,7 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ - + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args @@ -446,4 +450,6 @@ def forward(self, args): assert len(args) == 2 hidden_states, mask = args neox_args = self.neox_args + if self.layer_number == 0: + hidden_states = hidden_states.transpose(0,1) return super().forward(hidden_states), mask From dd441b6252a131e25151b147670c3092007e075d Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Sat, 12 Oct 2024 02:57:47 -0400 Subject: [PATCH 06/28] time mixing debugging --- megatron/model/gpt2_model.py | 1 + megatron/model/rwkv/v6/rwkv.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 1b6aa9b54..ddf025a1d 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -74,6 +74,7 @@ def cross_entropy(output, labels, _fp16=False): else: losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels) loss_mask = loss_mask.view(-1) + print(f"model output shape: {output.size()}, loss shape: {losses.size()}") loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() return loss diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index e96f1bd2b..1d46d11bf 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -8,7 +8,7 @@ from torch.nn import functional as F from torch.utils.cpp_extension import load from megatron import mpu -from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region +from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region class WKV(torch.autograd.Function): """ @@ -237,6 +237,7 @@ def jit_func(self, x): B, T, C = x.size() xx = self.time_shift(x) - x + print(x[0,:,1],xx[0,:,1]) xxx = x + xx * self.time_maa_x xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) @@ -255,8 +256,10 @@ def jit_func(self, x): gated, _ = self.gate(xg) g = F.silu(gated) + print(f"size of ww matmuls: {self.time_decay_w1.size()}, {self.time_decay_w2.size()}") ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww + w = scatter_to_model_parallel_region(w) return r, k, v, g, w @@ -273,14 +276,18 @@ def jit_func_2(self, x, g): def forward(self, x): B, T, C = x.size() C_tp = C//mpu.get_model_parallel_world_size() - H = self.neox_args.num_attention_heads + H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() H_tp = H//mpu.get_model_parallel_world_size() + self.time_faaaa = self.time_faaaa[:self.neox_args.num_attention_heads//2,:] + #self.time_faaaa = 
scatter_to_model_parallel_region(self.time_faaaa) r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa) - x = gather_from_model_parallel_region(x) print(f"size of x after kernel: {x.size()}") + x = gather_from_model_parallel_region(x) + print(f"size of x after allgather: {x.size()}") return self.jit_func_2(x, g) @@ -347,7 +354,9 @@ def forward(self, x): k = torch.relu(k) ** 2 kv, _ = self.value(k) receptance, _ = self.receptance(xr) - return torch.sigmoid(receptance) * kv + retVal = torch.sigmoid(receptance) * kv + print(f"channel mix output size: {retVal.size()}") + return retVal class RWKVResidualLayer(nn.Module): @@ -452,4 +461,8 @@ def forward(self, args): neox_args = self.neox_args if self.layer_number == 0: hidden_states = hidden_states.transpose(0,1) - return super().forward(hidden_states), mask + hidden_states = super().forward(hidden_states) + if self.layer_number == self.neox_args.num_layers-1: + hidden_states = hidden_states.transpose(0,1) + print(f"output of model from residual layer pipe: {hidden_states.size()}") + return hidden_states, mask From a4186706d967106a5703f57b77530eab8829400e Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Sun, 13 Oct 2024 11:45:05 -0400 Subject: [PATCH 07/28] reset time_faaaa --- megatron/model/rwkv/v6/rwkv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 1d46d11bf..970613f27 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -279,7 +279,7 @@ def forward(self, x): H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() H_tp = H//mpu.get_model_parallel_world_size() - self.time_faaaa = self.time_faaaa[:self.neox_args.num_attention_heads//2,:] + #self.time_faaaa = self.time_faaaa[:self.neox_args.num_attention_heads//2,:] #self.time_faaaa = scatter_to_model_parallel_region(self.time_faaaa) r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") From 540d85658c64cf2f06c23c794a22faa99043a0d3 Mon Sep 17 00:00:00 2001 From: AI-WAIFU <67525070+AI-WAIFU@users.noreply.github.com> Date: Tue, 8 Oct 2024 20:25:59 +0100 Subject: [PATCH 08/28] Add additional asserts and update post training readme (#1300) * add asserts and fix post training readme * precommit --------- Co-authored-by: Quentin Anthony --- megatron/training.py | 10 ++++++++++ post-training/README.md | 2 -- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index 5976ae6a7..277f127c3 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -406,6 +406,9 @@ def get_batch(neox_args, data_iterator): datatype=datatype, ) elif neox_args.train_impl == "kto": + assert ( + neox_args.train_micro_batch_size_per_gpu > 1 + ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1." 
tup = _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -459,6 +462,13 @@ def get_batch(neox_args, data_iterator): def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" + + assert neox_args.train_impl not in [ + "kto", + "dpo", + "rm", + ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0" + # Items and their type. keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 diff --git a/post-training/README.md b/post-training/README.md index 1ba5cde2f..930ad0e31 100644 --- a/post-training/README.md +++ b/post-training/README.md @@ -34,7 +34,6 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis ## SFT data ```bash -python post-training/llama_dpo_data.py python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages @@ -42,7 +41,6 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/lla ## KTO data ```bash -python post-training/llama_dpo_data.py python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward From 12aac35439eb78349709790b8619d60637fd44b9 Mon Sep 17 00:00:00 2001 From: AI-WAIFU <67525070+AI-WAIFU@users.noreply.github.com> Date: Tue, 8 Oct 2024 20:27:43 +0100 Subject: [PATCH 09/28] Fix failling tests (#1301) * fix typo * fix neoxargs usage test * skip conversion test due to multiprocessing issue * precommit --------- Co-authored-by: Quentin Anthony --- configs/neox_arguments.md | 23 ++++++++++++++++++++ megatron/neox_arguments/neox_args.py | 19 ++++++++++++++-- megatron/training.py | 2 +- tests/neox_args/test_neoxargs_usage.py | 4 +++- tests/unit/test_format_conversion_scripts.py | 4 ++++ 5 files changed, 48 insertions(+), 4 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 7dec66da2..686974181 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -843,6 +843,29 @@ Model Arguments +- **dim_att**: int + + Default = None + + Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size. 
+ + + +- **head_size**: int + + Default = None + + Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads. + + + +- **ffn_dim**: int + + Default = None + + Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor. + + ## NeoXArgsOptimizer Optimizer Arguments diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index ac313a3bb..c877c6c78 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -21,7 +21,7 @@ from template import NeoXArgsTemplate try: - from typing import List, Literal, Union, Optional + from typing import List, Literal, Union, Optional, Any except ImportError: from typing_extensions import List, Literal, Union, Optional @@ -502,6 +502,21 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + dim_att: int = None + """ + Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size. + """ + + head_size: int = None + """ + Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads. + """ + + ffn_dim: int = None + """ + Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor. + """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -673,7 +688,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): Custom metadata to attach to the created Comet Experiment. """ - comet_experiment = None + comet_experiment: Any = None """ Initialized comet experiment object used to log data """ diff --git a/megatron/training.py b/megatron/training.py index 277f127c3..1965faea8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -586,7 +586,7 @@ def forward_step( return model.eval_batch(data_iterator, return_logits=return_logits) # Get the batch. - if neox_args.memory_profiling and neox_args.it: + if neox_args.memory_profiling and neox_args.iteration: torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() diff --git a/tests/neox_args/test_neoxargs_usage.py b/tests/neox_args/test_neoxargs_usage.py index 176151c2a..5f8ba7bd2 100644 --- a/tests/neox_args/test_neoxargs_usage.py +++ b/tests/neox_args/test_neoxargs_usage.py @@ -66,7 +66,9 @@ def test_neoxargs_usage(): # find args matches matches = list( - re.findall(r"(?<=args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents) + re.findall( + r"(?<=neox_args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents + ) ) if len(matches) == 0: continue diff --git a/tests/unit/test_format_conversion_scripts.py b/tests/unit/test_format_conversion_scripts.py index e0801434c..6935e480a 100644 --- a/tests/unit/test_format_conversion_scripts.py +++ b/tests/unit/test_format_conversion_scripts.py @@ -4,8 +4,12 @@ from megatron.neox_arguments.neox_args import NeoXArgsTokenizer +@pytest.mark.skip( + reason="Conversion test is skipped until we fix the CUDA + torch multiprocessing issue." 
+) def test_gpt_neox_to_huggingface(monkeypatch, tmpdir, tmp_path): # Generate random GPT-NEOX model, check we can convert to hf format + model_dir = str(tmpdir) input_args = ["train.py", "tests/config/test_setup.yml"] deepspeed_main_args = simulate_deepy_env(monkeypatch, input_args) From 97c7915fa997864cec44b49d24bc96a9059981f0 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 4 Jun 2024 11:22:29 -0400 Subject: [PATCH 10/28] inital tp commits --- megatron/model/rwkv/v6/rwkv.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 970613f27..fa0eaa53f 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -297,11 +297,16 @@ class ParallelRWKV_ChannelMix(nn.Module): Channel Mix layer. The ffn in RWKV """ + def __init__(self, neox_args, layer_number, init_method): def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + + world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) @@ -363,7 +368,7 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ - + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args From 5f89ed84d0c74804a6c448be3c96e36a1283b0b8 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 5 Nov 2024 17:20:14 -0500 Subject: [PATCH 11/28] merge --- megatron/model/rwkv/v6/rwkv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index fa0eaa53f..77521b9aa 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -284,7 +284,7 @@ def forward(self, x): r, k, v, g, w = self.jit_func(x) print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") - x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=self.time_faaaa) + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H) ) print(f"size of x after kernel: {x.size()}") x = gather_from_model_parallel_region(x) print(f"size of x after allgather: {x.size()}") @@ -297,7 +297,6 @@ class ParallelRWKV_ChannelMix(nn.Module): Channel Mix layer. 
The ffn in RWKV """ - def __init__(self, neox_args, layer_number, init_method): def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args @@ -377,8 +376,9 @@ def __init__(self, neox_args, init_method, layer_number): self.bf16 = neox_args.precision == "bfloat16" assert ( neox_args.intermediate_size == None or neox_args.expansion_factor == None - ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections" - if not hasattr(neox_args, "dim_att"): + ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv" + if not neox_args.dim_att: + print("replacing dim_att") neox_args.dim_att = neox_args.hidden_size if neox_args.intermediate_size: neox_args.ffn_dim = neox_args.intermediate_size From 91cb7590531fde52b68bc952571b6bc8ca0b5ea8 Mon Sep 17 00:00:00 2001 From: Ganesh Ravichandran Date: Thu, 17 Oct 2024 17:52:49 -0400 Subject: [PATCH 12/28] Add ERROR logging prefix and sort the prefixes alphabetically (#1308) * Add ERROR logging prefix and sort alphabetically * fix comment --- megatron/neox_arguments/arguments.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index b108ebe6b..17d2e2cfb 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -50,16 +50,19 @@ ATTENTION_TYPE_CHOICES, ) -### Logging colors ### +### ANSI escape codes ### +END = "\033[0m" GREEN = "\033[92m" RED = "\033[91m" YELLOW = "\033[93m" -END = "\033[0m" -SUCCESS = f"{GREEN} [SUCCESS] {END}" -OKAY = f"{GREEN}[OKAY]{END}" -WARNING = f"{YELLOW}[WARNING]{END}" + +### Formatted logging prefixes ### +ERROR = f"{RED}[ERROR]{END} " FAIL = f"{RED}[FAIL]{END}" INFO = "[INFO]" +OKAY = f"{GREEN}[OKAY]{END}" +SUCCESS = f"{GREEN} [SUCCESS] {END}" +WARNING = f"{YELLOW}[WARNING]{END}" # ZERO defaults by deespeed # These values should not be changed unless defaults in deepspeed are changed From 49b263a2815f7ff0f39b1dfbdf0b94901182959a Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 4 Jun 2024 11:22:29 -0400 Subject: [PATCH 13/28] inital tp commits --- megatron/model/rwkv/v6/rwkv.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 77521b9aa..47e06bd0b 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -302,10 +302,6 @@ def __init__(self, neox_args, layer_number, init_method): self.neox_args = neox_args self.layer_number = layer_number - world_size = mpu.get_model_parallel_world_size() - self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) - - world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) From 48de6823f18849d9574904eff5b6bef6f48886e6 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Wed, 6 Nov 2024 17:57:08 -0500 Subject: [PATCH 14/28] cleanup --- configs/rwkv/1.5B.yml | 8 +-- configs/rwkv/430M.yml | 103 +++++++++++++++++++++++++++++++++ megatron/model/gpt2_model.py | 1 - megatron/model/rwkv/v6/rwkv.py | 39 ++----------- 4 files changed, 112 insertions(+), 39 deletions(-) create mode 100644 configs/rwkv/430M.yml diff --git a/configs/rwkv/1.5B.yml b/configs/rwkv/1.5B.yml index 0d97a7861..473bde88e 100644 --- a/configs/rwkv/1.5B.yml +++ b/configs/rwkv/1.5B.yml @@ -1,7 +1,7 @@ { # Parallelism is not yet supported for rwkv "pipe_parallel_size": 1, - 
"model_parallel_size": 2, + "model_parallel_size": 1, "num_layers": 24, "hidden_size": 2048, @@ -12,7 +12,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "train_micro_batch_size_per_gpu": 1, + "train_micro_batch_size_per_gpu": 4, "attention_config": [[["rwkv"], 24]], @@ -86,8 +86,8 @@ }, # misc. training settings - "train_iters": 1, - "lr_decay_iters": 1, + "train_iters": 320000, + "lr_decay_iters": 320000, "distributed_backend": "nccl", "lr_decay_style": "constant", "warmup": 0.01, diff --git a/configs/rwkv/430M.yml b/configs/rwkv/430M.yml new file mode 100644 index 000000000..1b3a62dfd --- /dev/null +++ b/configs/rwkv/430M.yml @@ -0,0 +1,103 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 24, + "hidden_size": 1024, + "num_attention_heads": 16, # head_size = dim_att / num_attention_heads. + # head_size is 64 for all rwkv models + "seq_length": 4096, + "max_position_embeddings": 4096, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 4, + + "attention_config": [[["rwkv"], 24]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + "seed": 1234, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index ddf025a1d..1b6aa9b54 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -74,7 +74,6 @@ def cross_entropy(output, labels, _fp16=False): else: losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels) loss_mask = loss_mask.view(-1) - print(f"model output shape: {output.size()}, loss shape: {losses.size()}") loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() return loss diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 47e06bd0b..3018063af 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -173,9 +173,6 @@ def __init__(self, neox_args, layer_number, init_method): ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - #self.receptance = nn.Linear( - # neox_args.hidden_size, neox_args.dim_att, bias=False - #) self.receptance = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -184,7 +181,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -193,7 +189,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) self.value = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -202,7 +197,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) self.output = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.dim_att, @@ -211,7 +205,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column self.gate = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -220,15 +213,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.gate = mpu.RowParallelLinear( - # neox_args=neox_args, - # input_size=neox_args.hidden_size, - # output_size=neox_args.dim_att, - # input_is_parallel=True, - # init_method=init_method, - # parallel_output=False, - # bias=False - # ) self.ln_x = nn.GroupNorm( neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) ) @@ -237,7 +221,6 @@ def jit_func(self, x): B, T, C = x.size() xx = self.time_shift(x) - x - print(x[0,:,1],xx[0,:,1]) xxx = x + xx * self.time_maa_x xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) @@ -256,7 +239,6 @@ def jit_func(self, x): gated, _ = self.gate(xg) g = F.silu(gated) - print(f"size of ww matmuls: {self.time_decay_w1.size()}, {self.time_decay_w2.size()}") ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 w = self.time_decay + ww w = scatter_to_model_parallel_region(w) @@ -268,9 +250,8 @@ def jit_func_2(self, x, g): 
x = x.view(B * T, C) x = self.ln_x(x).view(B, T, C) - print(f"shape of x: {x.size()}, shape of g: {g.size()}") x, _ = self.output(x * g) - print(f"new shape of x: {x.size()}") + return x def forward(self, x): @@ -279,15 +260,11 @@ def forward(self, x): H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() H_tp = H//mpu.get_model_parallel_world_size() - #self.time_faaaa = self.time_faaaa[:self.neox_args.num_attention_heads//2,:] - #self.time_faaaa = scatter_to_model_parallel_region(self.time_faaaa) r, k, v, g, w = self.jit_func(x) - print(f"shape of r: {r.size()}, k: {k.size()}, v: {v.size()}, g: {g.size()}, w: {w.size()}, H: {H}, B: {B}, T: {T}, C: {C}, time_faaaa: {self.time_faaaa.size()}, \n") x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H) ) - print(f"size of x after kernel: {x.size()}") + x = gather_from_model_parallel_region(x) - print(f"size of x after allgather: {x.size()}") return self.jit_func_2(x, g) @@ -315,7 +292,6 @@ def __init__(self, neox_args, layer_number, init_method): self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) - #self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False) self.key = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -324,9 +300,7 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False, ) - #self.receptance = nn.Linear( - # neox_args.hidden_size, neox_args.hidden_size, bias=False - #) + self.receptance = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -335,7 +309,6 @@ def __init__(self, neox_args, layer_number, init_method): init_method=init_method, bias=False ) - #self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False) self.value = mpu.RowParallelLinear( neox_args=neox_args, input_size=neox_args.ffn_dim, @@ -345,6 +318,7 @@ def __init__(self, neox_args, layer_number, init_method): parallel_output=False, bias=False ) + def forward(self, x): xx = self.time_shift(x) - x xk = x + xx * self.time_maa_k @@ -355,7 +329,7 @@ def forward(self, x): kv, _ = self.value(k) receptance, _ = self.receptance(xr) retVal = torch.sigmoid(receptance) * kv - print(f"channel mix output size: {retVal.size()}") + return retVal @@ -374,7 +348,6 @@ def __init__(self, neox_args, init_method, layer_number): neox_args.intermediate_size == None or neox_args.expansion_factor == None ), "Must pass either the absolute intermediate size or the relative expansion factor for rwkv" if not neox_args.dim_att: - print("replacing dim_att") neox_args.dim_att = neox_args.hidden_size if neox_args.intermediate_size: neox_args.ffn_dim = neox_args.intermediate_size @@ -450,7 +423,6 @@ def forward(self, x): return x - class RWKVResidualLayerPipe(RWKVResidualLayer): """ RWKV Pipeline Layer @@ -465,5 +437,4 @@ def forward(self, args): hidden_states = super().forward(hidden_states) if self.layer_number == self.neox_args.num_layers-1: hidden_states = hidden_states.transpose(0,1) - print(f"output of model from residual layer pipe: {hidden_states.size()}") return hidden_states, mask From c6fac961ffca78e61c12ab7e31e3e6c0d177552d Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Wed, 6 Nov 2024 18:08:10 -0500 Subject: [PATCH 15/28] cleanup --- configs/rwkv/430M.yml | 2 +- megatron/model/rwkv/v6/rwkv.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git 
a/configs/rwkv/430M.yml b/configs/rwkv/430M.yml index 1b3a62dfd..a42e1796b 100644 --- a/configs/rwkv/430M.yml +++ b/configs/rwkv/430M.yml @@ -12,7 +12,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-5, - "train_micro_batch_size_per_gpu": 4, + "train_micro_batch_size_per_gpu": 1, "attention_config": [[["rwkv"], 24]], diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 3018063af..101e44226 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -258,7 +258,6 @@ def forward(self, x): B, T, C = x.size() C_tp = C//mpu.get_model_parallel_world_size() H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() - H_tp = H//mpu.get_model_parallel_world_size() r, k, v, g, w = self.jit_func(x) From 5a259c0b47b11493e00a72aa8d0f0b3008e66b40 Mon Sep 17 00:00:00 2001 From: Jacob Hatef <74274091+jahatef@users.noreply.github.com> Date: Wed, 6 Nov 2024 18:11:47 -0500 Subject: [PATCH 16/28] Update local_setup.yml --- configs/local_setup.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index 63d570a6f..b8ec4b06a 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -22,10 +22,6 @@ "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, - - # "launcher": "openmpi", - #"deepspeed_mpi": true, - "tensorboard_dir": "tensorboard", "log_dir": "logs", } From c2d6c852b164be48779774cc9d67baef2e747ac0 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Sun, 10 Nov 2024 15:36:46 -0500 Subject: [PATCH 17/28] add Triton FLA --- megatron/model/rwkv/v6/rwkv.py | 72 ++++++++++++++++++---------- megatron/neox_arguments/neox_args.py | 5 ++ requirements/requirements-rwkv.txt | 1 + 3 files changed, 54 insertions(+), 24 deletions(-) create mode 100644 requirements/requirements-rwkv.txt diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 101e44226..deee63194 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -9,6 +9,15 @@ from torch.utils.cpp_extension import load from megatron import mpu from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region +try: + from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6, native_recurrent_rwkv6 + import einops +except ModuleNotFoundError: + print( + "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \ + or directly from https://github.com/TorchRWKV/flash-linear-attention/tree/stable, or use CUDA kernels." + ) + pass class WKV(torch.autograd.Function): """ @@ -96,6 +105,18 @@ def backward(ctx, gy): def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): return WKV.apply(B, T, C, H, r, k, v, w, u) +@torch.compiler.disable(recursive=True) +# torch.compiler introduces errors in numerical precision (torch 2.4) +def RUN_FLA_CHUNK(B, T, C, H, r, k, v, w, u, h=None, scale=1.0, chunk_size=32): + r = r.view(B,T,H,-1).transpose(1,2) + k = k.view(B,T,H,-1).transpose(1,2) + v = v.view(B,T,H,-1).transpose(1,2) + # u can be 3d or 2d (B, H, -1) or just (H, -1) to save VRAM + w = -torch.exp(w.view(B,T,H,-1).transpose(1,2)) + # change to scale=-1.0 when using fp16, this will apply scale to r and k. 
+ o, final_state = chunk_rwkv6(r, k, v, w, u=u, scale=scale, initial_state=h, + output_final_state=False, chunk_size=chunk_size) #initial_state=None and output_final_state=False for rwkv6 + return o.transpose(1,2).reshape(B,T,C), final_state # RWKV6 time mix class RWKV_TimeMix(nn.Module): @@ -260,9 +281,11 @@ def forward(self, x): H = self.neox_args.num_attention_heads//mpu.get_model_parallel_world_size() r, k, v, g, w = self.jit_func(x) - - x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H) ) - + if self.neox_args.rwkv_fla: + x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H)) + else: + x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H)) + x = gather_from_model_parallel_region(x) return self.jit_func_2(x, g) @@ -382,27 +405,28 @@ def __init__(self, neox_args, init_method, layer_number): self.drop1 = nn.Dropout(p=neox_args.hidden_dropout) if layer_number == 0: - global wkv_cuda - """ - Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. - """ - wkv_cuda = load( - name="wkv6", - sources=[ - "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", - f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={self.neox_args.head_size}", - f"-D_T_={self.neox_args.seq_length}", - ], - ) + if not self.neox_args.rwkv_fla: + global wkv_cuda + """ + Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. + """ + wkv_cuda = load( + name="wkv6", + sources=[ + "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", + f"-D_T_={self.neox_args.seq_length}", + ], + ) def forward(self, x): neox_args = self.neox_args diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index c877c6c78..c64f67d32 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -277,6 +277,11 @@ class NeoXArgsModel(NeoXArgsTemplate): } """ + rwkv_fla: bool = False + """ + Whether to use the Flash Linear Attention implementation of the RWKV kernel, or the CUDA kernel version. + """ + num_unique_layers: int = None """ Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0. 
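For orientation, the `rwkv_fla` field added above is a regular NeoXArgs model option, so it would be toggled from a YAML config alongside the other RWKV settings in this series. A minimal, hypothetical sketch (the file name and surrounding keys are only illustrative, borrowed from the 430M config added earlier in this series):

```yaml
{
  # hypothetical fragment of a config such as configs/rwkv/430M.yml
  "attention_config": [[["rwkv"], 24]],
  # assumed usage: route time mixing through the Triton chunk_rwkv6 kernel
  # instead of building the CUDA wkv6 kernel at runtime
  "rwkv_fla": true,
}
```

Note that the try/except import guard at the top of rwkv.py only prints a warning, so the flash-linear-attention package pinned just below in requirements/requirements-rwkv.txt would still need to be installed for this path to actually run.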
diff --git a/requirements/requirements-rwkv.txt b/requirements/requirements-rwkv.txt new file mode 100644 index 000000000..f193cf288 --- /dev/null +++ b/requirements/requirements-rwkv.txt @@ -0,0 +1 @@ +rwkv-fla>=0.1.202410200535 \ No newline at end of file From bdb3658a3cf0d612a584e5c99f6873864936315b Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 12 Nov 2024 16:14:05 -0500 Subject: [PATCH 18/28] change version of rwkv-fla --- megatron/model/rwkv/v6/rwkv.py | 4 ++-- requirements/requirements-rwkv.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index deee63194..0d77278bc 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -10,12 +10,12 @@ from megatron import mpu from megatron.mpu import gather_from_model_parallel_region, reduce_from_model_parallel_region, scatter_to_model_parallel_region try: - from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6, native_recurrent_rwkv6 + from fla.ops.rwkv6 import chunk_rwkv6 import einops except ModuleNotFoundError: print( "Unable to import RWKV FLA kernels. Install them from our requirements/requirements-rwkv.txt, \ - or directly from https://github.com/TorchRWKV/flash-linear-attention/tree/stable, or use CUDA kernels." + or directly from https://github.com/sustcsonglin/flash-linear-attention.git, or use CUDA kernels." ) pass diff --git a/requirements/requirements-rwkv.txt b/requirements/requirements-rwkv.txt index f193cf288..38c786d5b 100644 --- a/requirements/requirements-rwkv.txt +++ b/requirements/requirements-rwkv.txt @@ -1 +1 @@ -rwkv-fla>=0.1.202410200535 \ No newline at end of file +git+https://github.com/sustcsonglin/flash-linear-attention \ No newline at end of file From ff7f328daa94837a318bf6bd0501be0b24ff538d Mon Sep 17 00:00:00 2001 From: tiandeyu-cs <54715756+tiandeyu-cs@users.noreply.github.com> Date: Thu, 14 Nov 2024 06:36:48 +0800 Subject: [PATCH 19/28] fix a GQA issue (#1314) (#1315) - do not create a fake head dim and split the 'mixed_x_layer' into QKV layers directly. --- megatron/model/transformer.py | 60 +++++++++++------------------------ 1 file changed, 18 insertions(+), 42 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index d112a7461..42dbdfeeb 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -763,51 +763,16 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) - # First: reshape so we have seqlen, batch, and num. query heads each as separate dims - # Final dim is not exactly head dim: the first (head dim) dims are query heads, - # The last (head dim * ratio of kv to q heads) each are the "k/v heads" - # (right now we treat like we have same num. heads, but smaller head dim) - - # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))] - new_qkv_shape = ( - mixed_x_layer.shape[0], - mixed_x_layer.shape[1], - self.num_attention_heads_per_partition, - int( - self.hidden_size_per_attention_head - * ( - 1 - + 2 - * ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - ) - ), - ) - mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape) - - # Next: split our fake head dim. 
(last dim) so that the first (head dim) dimensions go to Q, - # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately + # split the last dim, so that the first (q head * head dim) dimensions go to Q, + # the last smaller 2 * (kv head * head dim) each divided between K and V separately split_sizes = ( - self.hidden_size_per_attention_head, - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), - int( - ( - self.num_kv_heads_per_partition - / self.num_attention_heads_per_partition - ) - * self.hidden_size_per_attention_head - ), + self.num_attention_heads_per_partition + * self.hidden_size_per_attention_head, + self.num_kv_heads_per_partition * self.hidden_size_per_attention_head, + self.num_kv_heads_per_partition * self.hidden_size_per_attention_head, ) - # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))] + # [sq, b, ((np + 2 * kvp) * hn)] --> 1 x [sq, b, np * hn] , 2 x [sq, b, kvp * hn] (query_layer, key_layer, value_layer) = [ x.contiguous() for x in torch.split( @@ -817,6 +782,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): ) ] + # reshape Q to proper output shape (last dim = correct full "real" head size again) + # [sq, b, np * hn] --> [sq, b, np, hn] + new_query_shape = ( + query_layer.size(0), + query_layer.size(1), + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + + query_layer = query_layer.view(*new_query_shape) + # reshape K/V to proper output shape (last dim = correct full "real" head size again) # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn] new_kv_shape = ( From 1350b2c27390eb2f78ec61c5e334f322bf061199 Mon Sep 17 00:00:00 2001 From: tiandeyu-cs <54715756+tiandeyu-cs@users.noreply.github.com> Date: Thu, 14 Nov 2024 06:45:18 +0800 Subject: [PATCH 20/28] fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed (#1309) * fix 'intermediate_size' in Llama configuration files after the 'mlp_type' option was removed * config adjustments for llama and gated activations * pre-commit --------- Co-authored-by: jahatef Co-authored-by: Quentin Anthony --- configs/llama/13B.yml | 2 ++ configs/llama/30B.yml | 2 ++ configs/llama/65B.yml | 2 ++ configs/llama/7B.yml | 2 ++ configs/llama/train_config.yml | 2 +- configs/llama2/13B.yml | 1 + configs/llama2/70B.yml | 2 +- configs/llama2/7B.yml | 1 + megatron/model/transformer.py | 5 ++--- 9 files changed, 14 insertions(+), 5 deletions(-) diff --git a/configs/llama/13B.yml b/configs/llama/13B.yml index 162e51719..a7470cae8 100644 --- a/configs/llama/13B.yml +++ b/configs/llama/13B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 40, "hidden_size": 5120, + "intermediate_size": 40960, "num_attention_heads": 40, "seq_length": 2048, "max_position_embeddings": 2048, @@ -16,6 +17,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-6, + "use_bias_in_mlp": False, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": false, diff --git a/configs/llama/30B.yml b/configs/llama/30B.yml index 2c948e40c..234445c77 100644 --- a/configs/llama/30B.yml +++ b/configs/llama/30B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 60, "hidden_size": 6656, + "intermediate_size": 53248, "num_attention_heads": 52, "seq_length": 2048, "max_position_embeddings": 2048, @@ -16,6 +17,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", 
"rms_norm_epsilon": 1.0e-6, + "use_bias_in_mlp": False, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": false, diff --git a/configs/llama/65B.yml b/configs/llama/65B.yml index 4ebd249b9..8ffffe241 100644 --- a/configs/llama/65B.yml +++ b/configs/llama/65B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 80, "hidden_size": 8192, + "intermediate_size": 65536, "num_attention_heads": 64, "seq_length": 2048, "max_position_embeddings": 2048, @@ -16,6 +17,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-6, + "use_bias_in_mlp": False, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": false, diff --git a/configs/llama/7B.yml b/configs/llama/7B.yml index cc21446be..0d7c40b24 100644 --- a/configs/llama/7B.yml +++ b/configs/llama/7B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 32, "hidden_size": 4096, + "intermediate_size": 32768, "num_attention_heads": 32, "seq_length": 2048, "max_position_embeddings": 2048, @@ -16,6 +17,7 @@ "output_layer_parallelism": "column", "norm": "rmsnorm", "rms_norm_epsilon": 1.0e-6, + "use_bias_in_mlp": False, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": false, diff --git a/configs/llama/train_config.yml b/configs/llama/train_config.yml index 7cc5a5968..459332609 100644 --- a/configs/llama/train_config.yml +++ b/configs/llama/train_config.yml @@ -70,5 +70,5 @@ "steps_per_print": 10, "keep_last_n_checkpoints": 4, "wall_clock_breakdown": true, - "mlp_multiple_of": 256, + } diff --git a/configs/llama2/13B.yml b/configs/llama2/13B.yml index 5bf7a4f72..7df5ad3ea 100644 --- a/configs/llama2/13B.yml +++ b/configs/llama2/13B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 40, "hidden_size": 5120, + "intermediate_size": 41472, "num_attention_heads": 40, "seq_length": 4096, "max_position_embeddings": 4096, diff --git a/configs/llama2/70B.yml b/configs/llama2/70B.yml index b628deffe..d175e146e 100644 --- a/configs/llama2/70B.yml +++ b/configs/llama2/70B.yml @@ -6,7 +6,7 @@ # model settings "num_layers": 80, "hidden_size": 8192, - "intermediate_size": 28672, + "intermediate_size": 86016, "num_attention_heads": 64, "num_kv_heads": 8, "seq_length": 4096, diff --git a/configs/llama2/7B.yml b/configs/llama2/7B.yml index eeba99c52..cdb63f02e 100644 --- a/configs/llama2/7B.yml +++ b/configs/llama2/7B.yml @@ -6,6 +6,7 @@ # model settings "num_layers": 32, "hidden_size": 4096, + "intermediate_size": 32768, "num_attention_heads": 32, "seq_length": 4096, "max_position_embeddings": 4096, diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 42dbdfeeb..7627e13b6 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1245,9 +1245,8 @@ def forward(self, x, attention_mask, layer_past=None): with torch.enable_grad() if not self.eval else nullcontext(): if ( - self.activation == "swiglu" - or self.num_experts > 1 - and self.moe_type == "deepspeed" + mlp_bias == None, + self.num_experts > 1 and self.moe_type == "deepspeed", ): # No dropout either assert mlp_bias is None From c4d7a54420b09a6cc2c6efe1e948709cac738a03 Mon Sep 17 00:00:00 2001 From: markNZed Date: Wed, 13 Nov 2024 23:51:32 +0100 Subject: [PATCH 21/28] Python 3.10 support (#1313) * Python 3.10 support In this issue Python 3.10 support was added https://github.com/EleutherAI/gpt-neox/pull/1122 * update wording on torch and python --------- Co-authored-by: Quentin Anthony --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md 
index 0d4e2939f..006f5964f 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA ### Host Setup -First make sure you are in an environment with Python 3.8 with an appropriate version of PyTorch 1.8 or later installed. **Note:** Some of the libraries that GPT-NeoX depends on have not been updated to be compatible with Python 3.10+. Python 3.9 appears to work, but this codebase has been developed and tested for Python 3.8. +This codebase has primarily developed and tested for Python 3.8-3.10, and PyTorch 1.8-2.0. This is not a strict requirement, and other versions and combinations of libraries may work. To install the remaining basic dependencies, run: From ee2f14240f602b18e81f5311e91575fd9d8a556f Mon Sep 17 00:00:00 2001 From: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:03:19 -0500 Subject: [PATCH 22/28] Fix documentation for converting SFT/DPO weights back to HF Llama (#1318) --- post-training/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/post-training/README.md b/post-training/README.md index 930ad0e31..fb7ac8eb4 100644 --- a/post-training/README.md +++ b/post-training/README.md @@ -53,5 +53,5 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/lla python tools/ckpts/convert_neox_to_hf.py --input_dir eleuther-neox/checkpoints/rm/llama3/llama3-8b-instruct/global_step100 --output_dir checkpoints/rm/llama3_hf --config_file checkpoints/rm/llama3/llama3-8b-instruct/global_step100/configs/llama3-8b-rm.yml --precision bf16 --vocab-is-hf-tokenizer --architecture llama --pad-token-id 128002 # SFT/DPO -python tools/ckpts/convert_neox_to_hf.py --input_dir eleuther-neox/checkpoints//llama3/llama3-8b-instruct/global_step100 --output_dir checkpoints//llama3_hf --config_file checkpoints//llama3/llama3-8b-instruct/global_step100/configs/llama3-8b-rm.yml --precision bf16 --vocab-is-hf-tokenizer +python tools/ckpts/convert_neox_to_hf.py --input_dir eleuther-neox/checkpoints//llama3/llama3-8b-instruct/global_step100 --output_dir checkpoints//llama3_hf --config_file checkpoints//llama3/llama3-8b-instruct/global_step100/configs/llama3-8b-rm.yml --precision bf16 --vocab-is-hf-tokenizer --architecture llama ``` From 6e81f0be7851f6a24100f25e9043f8aba254035f Mon Sep 17 00:00:00 2001 From: AI-WAIFU <67525070+AI-WAIFU@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:04:18 +0000 Subject: [PATCH 23/28] fix bug (#1311) --- megatron/neox_arguments/arguments.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 17d2e2cfb..f5e49e319 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -956,12 +956,19 @@ def calculate_derived(self): ) # derive precision - fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" if self.fp16 and self.fp16.get("enabled", False): if self.precision is None: self.update_value("precision", "fp16") else: + fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" assert self.precision == "fp16", fp16_conflict + + if self.bf16 and self.bf16.get("enabled", False): + if self.precision is None: + self.update_value("precision", "bfloat16") + else: + bf16_conflict = "DeepSpeed bf16 field was set but precision conflicts" + assert self.precision == "bfloat16", bf16_conflict if self.precision == "fp16": if 
isinstance(self.fp16, dict) and len(self.fp16) > 0: @@ -971,14 +978,15 @@ def calculate_derived(self): fp16_args = {"type": "fp16", "enabled": True} self.update_value("fp16", fp16_args) elif self.precision == "bfloat16": - bf_config = {"bf16": {"enabled": True}} - # dt_config = {"grad_accum_dtype": "fp32"} - if self.deepspeed_extra_args is None: - self.update_value("deepspeed_extra_args", bf_config) - else: - extra_args = copy.deepcopy(self.deepspeed_extra_args) - extra_args.update(bf_config) - self.update_value("deepspeed_extra_args", extra_args) + if not self.bf16: + bf_config = {"bf16": {"enabled": True}} + # dt_config = {"grad_accum_dtype": "fp32"} + if self.deepspeed_extra_args is None: + self.update_value("deepspeed_extra_args", bf_config) + else: + extra_args = copy.deepcopy(self.deepspeed_extra_args) + extra_args.update(bf_config) + self.update_value("deepspeed_extra_args", extra_args) zero_stage = self.zero_optimization["stage"] if self.data_types is None: From df9541974e8ef1dc422e889fa871f8078f8cf2b2 Mon Sep 17 00:00:00 2001 From: Michael Yu <76673037+michaelc-yu@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:02:06 -0800 Subject: [PATCH 24/28] Add support for dropout in sparse attention (#1312) --- megatron/model/transformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 7627e13b6..5a4586309 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -750,9 +750,13 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): rpe = self.rpe(query_layer.size(0), key_layer.size(0)) else: rpe = None - return self.sparse_attn( + attn_scores = self.sparse_attn( query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe ) + # apply dropout + if self.training: + attn_scores = self.attention_dropout(attn_scores) + return attn_scores def gqa_project(self, hidden_states, attention_mask, layer_past=None): # QKV projection and separation into separate Q/K/V layers for GQA, From d682529cd6e8fd11831c3259424f195cfba5ff0f Mon Sep 17 00:00:00 2001 From: Louis Castricato Date: Fri, 15 Nov 2024 21:48:50 -0500 Subject: [PATCH 25/28] adds pyproject files and tests (#1302) * adds pyproject files and tests * formatting and add dev packages to dev req files * improve req testing --------- Co-authored-by: Quentin Anthony --- requirements/pyproject-apex-pip.toml | 14 ++ requirements/pyproject-comet.toml | 14 ++ requirements/pyproject-flashattention.toml | 14 ++ requirements/pyproject-mamba.toml | 16 +++ requirements/pyproject-neox-dev.toml | 23 +++ requirements/pyproject-onebitadam.toml | 14 ++ requirements/pyproject-s3.toml | 15 ++ requirements/pyproject-sparseattention.toml | 14 ++ requirements/pyproject-tensorboard.toml | 14 ++ requirements/pyproject-transformerengine.toml | 14 ++ requirements/pyproject-wandb.toml | 14 ++ requirements/pyproject.toml | 33 +++++ requirements/requirements-dev.txt | 2 + tests/requirements/test_requirements.py | 131 ++++++++++++++++++ 14 files changed, 332 insertions(+) create mode 100644 requirements/pyproject-apex-pip.toml create mode 100644 requirements/pyproject-comet.toml create mode 100644 requirements/pyproject-flashattention.toml create mode 100644 requirements/pyproject-mamba.toml create mode 100644 requirements/pyproject-neox-dev.toml create mode 100644 requirements/pyproject-onebitadam.toml create mode 100644 requirements/pyproject-s3.toml create mode 100644 requirements/pyproject-sparseattention.toml create mode 
100644 requirements/pyproject-tensorboard.toml create mode 100644 requirements/pyproject-transformerengine.toml create mode 100644 requirements/pyproject-wandb.toml create mode 100644 requirements/pyproject.toml create mode 100644 tests/requirements/test_requirements.py diff --git a/requirements/pyproject-apex-pip.toml b/requirements/pyproject-apex-pip.toml new file mode 100644 index 000000000..df41dc925 --- /dev/null +++ b/requirements/pyproject-apex-pip.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-apex-pip" +version = "0.1.0" +description = "Apex pip requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +pip = "23.3.2" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-comet.toml b/requirements/pyproject-comet.toml new file mode 100644 index 000000000..04422a213 --- /dev/null +++ b/requirements/pyproject-comet.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-comet" +version = "0.1.0" +description = "Comet ML requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +comet_ml = ">=3.45.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-flashattention.toml b/requirements/pyproject-flashattention.toml new file mode 100644 index 000000000..14c7ad112 --- /dev/null +++ b/requirements/pyproject-flashattention.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-flashattention" +version = "0.1.0" +description = "Flash Attention requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +flash-attn = "2.5.6" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-mamba.toml b/requirements/pyproject-mamba.toml new file mode 100644 index 000000000..0f6191662 --- /dev/null +++ b/requirements/pyproject-mamba.toml @@ -0,0 +1,16 @@ +[tool.poetry] +name = "gpt-neox-mamba" +version = "0.1.0" +description = "Mamba requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +causal_conv1d = ">=1.1.0" +einops = "*" +mamba_ssm = ">=1.2.0.post1" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-neox-dev.toml b/requirements/pyproject-neox-dev.toml new file mode 100644 index 000000000..55b00f6ba --- /dev/null +++ b/requirements/pyproject-neox-dev.toml @@ -0,0 +1,23 @@ +[tool.poetry] +name = "gpt-neox-dev" +version = "0.1.0" +description = "Development requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +autopep8 = ">=1.5.6" +clang-format = ">=13.0.1" +pre-commit = ">=2.17.0" +pytest = ">=6.2.3" +pytest-cov = ">=2.11.1" +pytest-forked = ">=1.3.0" +pytest-html = "4.1.1" +pytest-xdist = "*" +toml = ">=0.10.2" +packaging = ">=23.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-onebitadam.toml b/requirements/pyproject-onebitadam.toml new file mode 100644 index 000000000..aeaf33aa6 --- /dev/null +++ b/requirements/pyproject-onebitadam.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-onebitadam" +version = "0.1.0" +description = "OneBitAdam requirements for 
GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +cupy-cuda111 = ">=8.6.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-s3.toml b/requirements/pyproject-s3.toml new file mode 100644 index 000000000..a0cb99aef --- /dev/null +++ b/requirements/pyproject-s3.toml @@ -0,0 +1,15 @@ +[tool.poetry] +name = "gpt-neox-s3" +version = "0.1.0" +description = "S3 requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +boto3 = "*" +hf-transfer = ">=0.1.3" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-sparseattention.toml b/requirements/pyproject-sparseattention.toml new file mode 100644 index 000000000..2864c799b --- /dev/null +++ b/requirements/pyproject-sparseattention.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-sparseattention" +version = "0.1.0" +description = "Sparse Attention requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +triton = "2.1.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-tensorboard.toml b/requirements/pyproject-tensorboard.toml new file mode 100644 index 000000000..79bbfa900 --- /dev/null +++ b/requirements/pyproject-tensorboard.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-tensorboard" +version = "0.1.0" +description = "TensorBoard requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +tensorboard = "2.13.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-transformerengine.toml b/requirements/pyproject-transformerengine.toml new file mode 100644 index 000000000..7c313e0d9 --- /dev/null +++ b/requirements/pyproject-transformerengine.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-transformerengine" +version = "0.1.0" +description = "Transformer Engine requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +transformer-engine = {git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "stable"} + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject-wandb.toml b/requirements/pyproject-wandb.toml new file mode 100644 index 000000000..c5806b341 --- /dev/null +++ b/requirements/pyproject-wandb.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "gpt-neox-wandb" +version = "0.1.0" +description = "Weights & Biases requirements for GPT-NeoX" +authors = ["EleutherAI "] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +wandb = ">=0.10.28" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/pyproject.toml b/requirements/pyproject.toml new file mode 100644 index 000000000..91d6fc1dd --- /dev/null +++ b/requirements/pyproject.toml @@ -0,0 +1,33 @@ +[tool.poetry] +name = "gpt-neox" +version = "2.0.0" +description = "An open-source library for training large-scale language models on GPUs" +authors = ["EleutherAI "] +license = "Apache-2.0" +readme = "README.md" +homepage = "https://www.github.com/eleutherai/gpt-neox" +repository = 
"https://www.github.com/eleutherai/gpt-neox" +documentation = "https://www.github.com/eleutherai/gpt-neox" + +[tool.poetry.dependencies] +python = "^3.8" +deepspeed = {git = "https://github.com/EleutherAI/DeeperSpeed.git", rev = "02e2ebf7dee6aaab3d89094ed470a4609763c742"} +ftfy = "^6.0.1" +huggingface_hub = "^0.11.0" +jinja2 = "3.1.4" +lm_dataformat = {git = "https://github.com/EleutherAI/lm_dataformat.git", rev = "4eec05349977071bf67fc072290b95e31c8dd836"} +lm_eval = ">=0.4.0,<=0.4.1" +mpi4py = "^3.0.3" +numpy = "<2.0" +pybind11 = "^2.6.2" +regex = "*" +sentencepiece = "*" +six = "*" +tiktoken = "^0.1.2" +tokenizers = "^0.12.1" +transformers = "4.38.0" +toml = "*" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 60ff3224f..8dfd5595c 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,8 +1,10 @@ autopep8>=1.5.6 clang-format>=13.0.1 +packaging>=23.0 pre-commit>=2.17.0 pytest>=6.2.3 pytest-cov>=2.11.1 pytest-forked>=1.3.0 pytest-html==4.1.1 pytest-xdist +toml>=0.10.2 diff --git a/tests/requirements/test_requirements.py b/tests/requirements/test_requirements.py new file mode 100644 index 000000000..20e8ad0dd --- /dev/null +++ b/tests/requirements/test_requirements.py @@ -0,0 +1,131 @@ +import pytest +import toml +from pathlib import Path +from typing import Dict, List, Optional +from packaging.version import parse as parse_version, Version +from dataclasses import dataclass + + +@dataclass +class Dependency: + name: str + version: Optional[str] = None + + @classmethod + def from_requirement(cls, requirement: str) -> "Dependency": + """Parse a requirement string into a Dependency object.""" + # Common version specifiers + specifiers = ["==", ">=", ">", "<=", "<"] + name = requirement + version = None + + for spec in specifiers: + if spec in requirement: + name, version = requirement.split(spec, 1) + version = version.strip() + break + + return cls(name.lower().strip(), version) + + def matches_version(self, other_version: str) -> bool: + """Check if this dependency's version matches another version string.""" + if not self.version or not other_version: + return True + + try: + # Convert versions to comparable objects + our_version = parse_version(self.version) + their_version = parse_version(other_version.replace("*", "0")) + return our_version == their_version + except ValueError: + # If versions can't be parsed, fall back to string comparison + return self.version.replace("*", "0") == other_version.replace("*", "0") + + +class DependencyValidator: + def __init__(self, requirements_dir: Path): + self.requirements_dir = requirements_dir + + def parse_requirements(self, file_path: Path) -> List[Dependency]: + """Parse requirements.txt file into a list of Dependencies.""" + try: + with open(file_path, "r") as f: + lines = [ + line.strip() + for line in f + if line.strip() and not line.startswith("#") + ] + return [Dependency.from_requirement(line) for line in lines] + except FileNotFoundError: + raise FileNotFoundError(f"Requirements file not found: {file_path}") + except Exception as e: + raise ValueError(f"Error parsing requirements file {file_path}: {str(e)}") + + def parse_pyproject(self, file_path: Path) -> Dict[str, str]: + """Parse pyproject.toml file and extract dependencies.""" + try: + with open(file_path, "r") as f: + pyproject_data = toml.load(f) + return { + name.lower(): str(version) + for name, version in 
pyproject_data["tool"]["poetry"][ + "dependencies" + ].items() + if name.lower() != "python" # Exclude Python version + } + except FileNotFoundError: + raise FileNotFoundError(f"pyproject.toml file not found: {file_path}") + except Exception as e: + raise ValueError(f"Error parsing pyproject.toml {file_path}: {str(e)}") + + def compare_dependencies( + self, req_deps: List[Dependency], pyproject_deps: Dict[str, str] + ) -> tuple[bool, List[str]]: + """Compare dependencies between requirements.txt and pyproject.toml.""" + mismatches = [] + + for req in req_deps: + if req.name not in pyproject_deps: + mismatches.append( + f"Dependency '{req.name}' not found in pyproject.toml" + ) + continue + + if not req.matches_version(pyproject_deps[req.name]): + mismatches.append( + f"Version mismatch for '{req.name}': " + f"requirements.txt={req.version}, " + f"pyproject.toml={pyproject_deps[req.name]}" + ) + + return len(mismatches) == 0, mismatches + + +def get_corresponding_pyproject(req_file: Path) -> Path: + """Get the corresponding pyproject.toml file for a requirements file.""" + env_name = req_file.stem.split("-")[1] + return req_file.parent / f"pyproject-{env_name}.toml" + + +@pytest.mark.parametrize("req_file", Path("requirements").glob("requirements-*.txt")) +def test_pyproject_matches_requirements(req_file: Path): + """Test that requirements.txt dependencies match pyproject.toml dependencies.""" + validator = DependencyValidator(req_file.parent) + pyproject_file = get_corresponding_pyproject(req_file) + + # Parse both dependency files + req_deps = validator.parse_requirements(req_file) + pyproject_deps = validator.parse_pyproject(pyproject_file) + + # Compare dependencies and get detailed mismatches + is_match, mismatches = validator.compare_dependencies(req_deps, pyproject_deps) + + # Create detailed error message if there are mismatches + if not is_match: + error_msg = "\n".join( + [ + f"\nDependency mismatches found between {req_file} and {pyproject_file}:", + *[f"- {msg}" for msg in mismatches], + ] + ) + pytest.fail(error_msg) From 0bc11d6c15052f7f4b11ba16a9d585e9a30126b8 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Wed, 27 Nov 2024 11:04:07 -0800 Subject: [PATCH 26/28] undo merge error (#1325) --- megatron/model/transformer.py | 5 ++--- megatron/neox_arguments/arguments.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 5a4586309..c670fd4bf 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1248,9 +1248,8 @@ def forward(self, x, attention_mask, layer_past=None): raise KeyError(self.moe_type) with torch.enable_grad() if not self.eval else nullcontext(): - if ( - mlp_bias == None, - self.num_experts > 1 and self.moe_type == "deepspeed", + if mlp_bias == None or ( + self.num_experts > 1 and self.moe_type == "deepspeed" ): # No dropout either assert mlp_bias is None diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index f5e49e319..3b49cea32 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -962,7 +962,7 @@ def calculate_derived(self): else: fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" assert self.precision == "fp16", fp16_conflict - + if self.bf16 and self.bf16.get("enabled", False): if self.precision is None: self.update_value("precision", "bfloat16") From c6db95cbc030df5be42e90d07d5e80a95ec03671 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 4 Jun 
2024 11:22:29 -0400 Subject: [PATCH 27/28] inital tp commits --- megatron/model/rwkv/v6/rwkv.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 0d77278bc..88d99cd86 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -350,9 +350,7 @@ def forward(self, x): k = torch.relu(k) ** 2 kv, _ = self.value(k) receptance, _ = self.receptance(xr) - retVal = torch.sigmoid(receptance) * kv - - return retVal + return torch.sigmoid(receptance) * kv class RWKVResidualLayer(nn.Module): From daac50370e6ad0108d20ec8c52c6fa4c0eb1ec08 Mon Sep 17 00:00:00 2001 From: jahatef Date: Wed, 19 Jun 2024 21:15:05 +0000 Subject: [PATCH 28/28] setup --- configs/local_setup.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index b8ec4b06a..63d570a6f 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -22,6 +22,10 @@ "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, + + # "launcher": "openmpi", + #"deepspeed_mpi": true, + "tensorboard_dir": "tensorboard", "log_dir": "logs", }
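
The GQA fix in patch 19 above comes down to splitting the fused QKV projection by whole head groups instead of carving out a fake per-head dimension first. A minimal shape-only sketch (toy sizes; the variable names are illustrative stand-ins, not the module's real attributes):

import torch

sq, b = 5, 2                      # sequence length, batch
np_, kvp, hn = 8, 2, 16           # query heads, kv heads, head size (per partition)
mixed_x_layer = torch.randn(sq, b, (np_ + 2 * kvp) * hn)

# [sq, b, ((np + 2*kvp) * hn)] --> [sq, b, np*hn], [sq, b, kvp*hn], [sq, b, kvp*hn]
query, key, value = torch.split(
    mixed_x_layer, [np_ * hn, kvp * hn, kvp * hn], dim=-1
)

# restore the real head dimension on each piece
query = query.view(sq, b, np_, hn)   # [sq, b, np, hn]
key = key.view(sq, b, kvp, hn)       # [sq, b, kvp, hn]
value = value.view(sq, b, kvp, hn)   # [sq, b, kvp, hn]
print(query.shape, key.shape, value.shape)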
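
The "undo merge error" fix in patch 26 is also worth spelling out: the pre-fix condition was a parenthesized, comma-separated pair, which Python evaluates as a two-element tuple, and a non-empty tuple is always truthy, so the branch ran unconditionally. A standalone illustration with stand-in values (not the module's real state):

# Illustrative only: why the pre-fix condition always ran.
mlp_bias = object()                       # stand-in for a real bias tensor
num_experts, moe_type = 1, "deepspeed"

buggy = (mlp_bias == None, num_experts > 1 and moe_type == "deepspeed",)
print(type(buggy), bool(buggy))           # <class 'tuple'> True -- always taken

fixed = mlp_bias == None or (num_experts > 1 and moe_type == "deepspeed")
print(fixed)                              # False -- the intended boolean, as in the patch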