diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 9d1f16fbe..da443c4f6 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -72,10 +72,12 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: +def safe_barrier( + group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None +) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -86,9 +88,10 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, + device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 5ddf2ff98..4d31324fe 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -327,7 +327,9 @@ def _preprocess_data( self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool ) -> typing.Generator[None, None, None]: batch_config = context.schedule.batch_config - grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs + grad_output = ( + self._optimizer.grad_scale / batch_config.num_inputs if context.schedule.phase.is_training else None + ) for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py index 1428ed25e..cea5f6ee2 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/autograd.py @@ -60,3 +60,14 @@ def call(*args, **kwargs): def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.Tensor: # noqa return context + + +class AuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa + ctx.grad = torch.full_like(aux_loss, grad) + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa + return grad_output, ctx.grad, None diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61d..050c700c9 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -93,16 +93,17 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 -class CrossEntropyImpl(str, enum.Enum): +class EntropyLossImplementation(enum.StrEnum): auto = "auto" torch = "torch" fused = "fused" triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" +class EntropyLossType(enum.StrEnum): cross_entropy = "cross_entropy" + forward_kl = "forward_kl" + reverse_kl = "reverse_kl" class TargetFormat(enum.StrEnum): diff --git 
a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py deleted file mode 100644 index a12516b5d..000000000 --- a/fast_llm/functional/cross_entropy.py +++ /dev/null @@ -1,359 +0,0 @@ -import torch - -from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward -from fast_llm.utils import Assert - - -def _torch_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A wrapper for the pytorch implementation of cross-entropy. - The cross-entropy kernels themselves are well-optimized, but the need for explicit casting - and separate forward and backward kernels lead to poor performance. - TODO: loss masking only works for with labels format and if the masking index is set to -100. - """ - # Torch compile doesn't understand this. - with torch.set_grad_enabled(grad_output is not None): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) - if target_format == TargetFormat.logits: - if logits_scale_factor != 1.0: - target = target * logits_scale_factor - if teacher_softmax_temperature != 1.0: - target = target / teacher_softmax_temperature - target = torch.softmax(target, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target - ) - else: - loss = ( - torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" - ) - * loss_mask - ).mean() - if grad_output is None: - grad = None - else: - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.detach().to(logits.dtype) - return loss.detach_(), grad - - -@torch.compile -def _fused_softmax_base( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - logits = logits.float() - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - logits_max = torch.max(logits, dim=dim, keepdim=True)[0] - if group is not None: - all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = (logits - logits_max).float() - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) - if group is not None: - all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - return logits_norm, exp_logits, sum_exp_logits - - -@torch.compile -def _fused_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 -) -> torch.Tensor: - _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) - return exp_logits / sum_exp_logits - - -# @torch.compile -def _fused_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - group: ProcessGroup | None = None, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A fused implementation of cross-entropy with torch compile. 
- It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, - but still suboptimal because it needs multiple kernels. - """ - # Do the forward and backward passes all at once, and fused with dtype conversion. - # Way faster and more memory-efficient than the pytorch version. - - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - - if target_format == TargetFormat.logits: - target = _fused_softmax(target, logits_scale_factor / teacher_softmax_temperature, group) - - if target_format == TargetFormat.labels: - target = target.unsqueeze(-1) - loss_mask = target >= 0 - if group is None: - # Keep values within range for scatter and gather ops to work. - target = target * loss_mask - target_mask = None - else: - # Mask the target (fused) - # TODO: Could mask earlier on cpu or overlap with reduce? - vocab_start_index = logits.size(-1) * group.rank() - target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask - else: - # Target should be tensor-parallel already, no further manipulation needed. - target_mask = None - if loss_mask is not None: - loss_mask = loss_mask.unsqueeze(-1) - - if grad_output is None: - grad = None - else: - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - if target_format == TargetFormat.labels: - grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) - ) - else: - grad_base = exp_logits - sum_exp_logits * target - - grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) - if logits_scale_factor != 1.0: - grad *= logits_scale_factor - if loss_mask is not None: - grad *= loss_mask - grad = grad.to(logits.dtype) - - # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) - if target_format == TargetFormat.labels: - predicted_logits = logits_norm.gather(1, target) - if group is not None: - predicted_logits = target_mask * predicted_logits - - all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - else: - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - if group is not None and target_format != TargetFormat.labels: - # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. - # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) - # = log Z - 1/K sum_ranks (sum_i t_i * z_i), where is the global predicted_logits, so without multiplying it by K 1/K there does not cancel out. 
- predicted_logits = predicted_logits * group.size() - - per_sample_loss = sum_exp_logits.log() - predicted_logits - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - - loss = per_sample_loss.mean() - if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.AVG, group=group) - - return loss, grad - - -_CROSS_ENTROPY_IMPLEMENTATIONS = { - CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward, - CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward, - CrossEntropyImpl.triton: triton_cross_entropy_forward_backward, -} - - -def cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - implementation: CrossEntropyImpl = CrossEntropyImpl.fused, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Select the appropriate implementation of cross-entropy. - The triton implementation from the triton submodule is the fastest and recommended one. - It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, - which is faster and has a relatively small memory overhead. - """ - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - assert loss_mask is None - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, CrossEntropyImpl.fused) - return _fused_cross_entropy_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - group, - teacher_softmax_temperature, - ) - else: - return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - teacher_softmax_temperature=teacher_softmax_temperature, - ) - - -def distributed_log_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -): - logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) - - return logits_norm - sum_exp_logits.log() # log_softmax - - -@torch.compile -def _reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - target_format: TargetFormat, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Reverse KL using PyTorch's native kl_div function. - This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. - - Takes: - logits: [BxS, V] or [B, S, V] - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... 
- """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - log_ratio = distributed_log_softmax(logits, group=group) - - student_probs = log_ratio.exp() - log_ratio = log_ratio - teacher_log_probs # In-place: log_ratio = student_log_probs - teacher_log_probs - del teacher_log_probs - # Compute loss terms: student_probs * log_ratio, then sum over vocab - # This is equivalent to kl_div(..., log_target=True) but more memory efficient - loss_terms = (student_probs * log_ratio).sum(dim=-1) - - if loss_mask is not None: - # loss mask is the same on all ranks for TP over vocab. - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() # sums over batch and seq. len. - - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) - # where E_q[log(q/p)] is the expected log ratio under the student distribution - expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True) - if group is not None: - all_reduce(expected, op=ReduceOp.SUM, group=group) - log_ratio = log_ratio - expected - log_ratio = log_ratio * student_probs - del student_probs # Free after use - - if loss_mask is not None: - log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1) - - log_ratio = log_ratio * (grad_output / valid_tokens) - grad = log_ratio.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - -def reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). - This is mode-seeking (vs. mode-covering for forward KL) and useful for: - - Encouraging the model to focus on the modes of the target distribution - - Avoiding probability mass on low-probability regions of the target - - Distillation scenarios where you want sharp, focused predictions - - Key differences from standard cross-entropy: - - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) - - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) - - Takes: - logits: [BxS, V] or [B, S, V], where V is local vocab size - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - - Returns: - loss: Reverse KL divergence loss - grad: Gradients w.r.t. 
logits - """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") - - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? - distillation_loss, distillation_grad = _reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - teacher_softmax_temperature=teacher_softmax_temperature, - group=group, - ) - return distillation_loss, distillation_grad diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py deleted file mode 100644 index c5ae48eba..000000000 --- a/fast_llm/functional/dpo.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch - - -def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): - # Gather log probabilities corresponding to the target tokens - return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - -def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): - return sum( - log_probabilities[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(spans) - for begin, end in sample_spans - ) - - -def compute_dpo_loss( - logits: torch.Tensor, - targets: torch.Tensor, - reference_model_logits: torch.Tensor, - chosen_spans: list[list[tuple[int, int]]], - rejected_spans: list[list[tuple[int, int]]], - beta: float, - grad_output: float | None, -) -> tuple[torch.Tensor, torch.Tensor]: - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_() - reference_model_logits_ = reference_model_logits.float().detach() - - policy_log_probabilities = _get_target_log_probabilities(logits_, targets) - policy_log_ratios = _get_target_log_probability_for_spans( - policy_log_probabilities, chosen_spans - ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - - reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) - reference_log_ratios = _get_target_log_probability_for_spans( - reference_log_probabilities, chosen_spans - ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - - # TODO: ====== Shouldn't the sigmoid be computed independently for each document? 
======= - losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) - - if grad_output is None: - loss = None - else: - loss = losses.mean() - loss.backward(torch.full_like(loss, grad_output)) - loss.detach() - return loss.detach(), logits_.grad.detach().to(logits.dtype) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py new file mode 100644 index 000000000..757832a71 --- /dev/null +++ b/fast_llm/functional/entropy_loss.py @@ -0,0 +1,348 @@ +import torch + +from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat +from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.utils import Assert + + +def _torch_entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + A wrapper for the pytorch implementation of cross-entropy. + The cross-entropy kernels themselves are well-optimized, but the need for explicit casting + and separate forward and backward kernels lead to poor performance. + TODO: loss masking only works for with labels format and if the masking index is set to -100. + """ + # Torch compile doesn't understand this. + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + logits_scaled = logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor + if target_format == TargetFormat.logits: + target_scale = logits_scale_factor / temperature + target = target if target_scale == 1.0 else target * target_scale + else: + Assert.eq(temperature, 1.0) + + if entropy_loss_type == EntropyLossType.cross_entropy: + if target_format == TargetFormat.logits: + target = torch.softmax(target, dim=-1) + loss = torch.nn.functional.cross_entropy( + logits_scaled, target, reduction="mean" if loss_mask is None else "none" + ) + else: + predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) + if target_format == TargetFormat.logits: + target_log_probability = torch.nn.functional.log_softmax(target, dim=-1) + elif target_format == TargetFormat.probabilities: + target_log_probability = target.log() + else: + target_log_probability = ( + torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).add(1.0e-10).log() + ) + if entropy_loss_type == EntropyLossType.forward_kl: + loss = torch.nn.functional.kl_div( + predicted_log_probability, + target_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + loss = torch.nn.functional.kl_div( + target_log_probability, + predicted_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, + ) + else: + raise NotImplementedError(entropy_loss_type) + if loss_mask is not None: + loss = loss.sum(dim=-1) + + if loss_mask is not None: + loss = (loss * loss_mask).mean() + + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.detach().to(logits.dtype) + return loss.detach_(), grad + + +@torch.compile +def _fused_softmax_base( + logits: torch.Tensor, 
logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + logits = logits.float() + if logits_scale_factor != 1.0: + logits = logits * logits_scale_factor + logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + if group is not None: + all_reduce(logits_max, op=ReduceOp.MAX, group=group) + logits_norm = (logits - logits_max).float() + exp_logits = logits_norm.exp() + sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + if group is not None: + all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) + return logits_norm, exp_logits, sum_exp_logits + + +@torch.compile +def _fused_reverse_kl_base( + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + temperature: float = 1.0, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + predicted_log_probability = logits_norm - sum_exp_logits.log() + predicted_probability = exp_logits / sum_exp_logits + + if target_format == TargetFormat.logits: + target_logits_norm, _, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / temperature, group + ) + target_log_probability = target_logits_norm - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) + + # Compute loss terms: student_probs * log_ratio, then sum over vocab + # This is equivalent to kl_div(..., log_target=True) but more memory efficient + log_ratio = predicted_log_probability - target_log_probability + per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1) + if group is not None: + all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) + + if grad_output is None: + grad = None + else: + # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) + # where E_q[log(q/p)] is the expected log ratio under the student distribution + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output + + return per_sample_loss, grad + + +@torch.compile +def _fused_cross_entropy_base( + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + temperature: float = 1.0, + return_kl_loss: bool = False, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + + if target_format == TargetFormat.logits: + target_logits_norm, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / temperature, group + ) + target = exp_logits_targets / sum_exp_target_logits + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) + if return_kl_loss: + if target_format == TargetFormat.logits: + target_log_probability = target_logits_norm - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) + logits_norm = logits_norm - target_log_probability + predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + if group is not None: + # We need to sum the over the tensor-parallel group, + # but this is handled in the final averaging provided we multiply by the group size. 
+ all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + + per_sample_loss = sum_exp_logits.log() - predicted_logits + + if grad_output is None: + grad = None + else: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad = (exp_logits - sum_exp_logits * target) * (grad_output / sum_exp_logits) + + return per_sample_loss, grad + + +@torch.compile +def _fused_cross_entropy_base_from_labels( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + group: ProcessGroup | None = None, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + + target = target.unsqueeze(-1) + + if group is None: + # Keep values within range for scatter and gather ops to work. + target = target * loss_mask.unsqueeze(-1) + target_mask = None + else: + # Mask the target (fused) + # TODO: Could mask earlier on cpu or overlap with reduce? + vocab_start_index = logits.size(-1) * group.rank() + target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) + target = (target - vocab_start_index) * target_mask + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss is the same because P * log(P) == 0. + predicted_logits = logits_norm.gather(1, target) + if group is not None: + predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + per_sample_loss = sum_exp_logits.log() - predicted_logits + + if grad_output is None: + grad = None + else: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad = exp_logits.scatter_add( + 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + ) * (grad_output / sum_exp_logits) + + return per_sample_loss, grad + + +@torch.compile +def _fused_entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + group: ProcessGroup | None = None, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + A fused implementation of cross-entropy with torch compile. + It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, + but still suboptimal because it needs multiple kernels. 
+ """ + grad_output = None if grad_output is None else grad_output / logits.size(0) * logits_scale_factor + if target_format == TargetFormat.labels: + assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) + if loss_mask is None: + loss_mask = target >= 0 + per_sample_loss, grad = _fused_cross_entropy_base_from_labels( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + group, + ) + elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): + per_sample_loss, grad = _fused_cross_entropy_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + per_sample_loss, grad = _fused_reverse_kl_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + ) + else: + raise NotImplementedError(entropy_loss_type) + + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1) + loss = per_sample_loss.mean() + + if grad is not None: + if loss_mask is not None: + grad = grad * loss_mask.unsqueeze(-1) + grad = grad.to(logits.dtype) + + return loss, grad + + +_ENTROPY_LOSS_IMPLEMENTATIONS = { + EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, + EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, + EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, +} + + +def entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + implementation: EntropyLossImplementation = EntropyLossImplementation.fused, + logits_scale_factor: float = 1.0, + temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Select the appropriate implementation of cross-entropy. + The triton implementation from the triton submodule is the fastest and recommended one. + It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, + which is faster and has a relatively small memory overhead. 
+ """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + assert loss_mask is None + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + if group: + Assert.eq(implementation, EntropyLossImplementation.fused) + return _fused_entropy_loss_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature, + ) + else: + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + temperature=temperature, + ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 295cdb74d..709d0c52d 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,8 @@ import torch -from fast_llm.functional.config import TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.utils import Assert @triton_jit() @@ -125,7 +126,8 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -134,6 +136,7 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED + Assert.eq(entropy_loss_type, EntropyLossType.cross_entropy) # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() @@ -163,7 +166,7 @@ def triton_cross_entropy_forward_backward( assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, - target / teacher_softmax_temperature, + target / temperature, loss_mask, grad_logits, losses, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b06f69ee5..fd76d36cb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -92,7 +92,7 @@ def get_layer( peft=peft, ) - def get_distillation_models(self) -> set[str]: + def get_reference_models(self) -> set[str]: return set() @@ -126,8 +126,8 @@ def layer_class(self) -> "type[FixedBlockSequence]": return FixedBlockSequence - def get_distillation_models(self) -> set[str]: - return self.block.get_distillation_models() + def get_reference_models(self) -> set[str]: + return self.block.get_reference_models() @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) @@ -176,10 +176,10 @@ def preprocessing_layers(self) -> dict[str, int]: # The index at which each block first appears. These blocks are used for preprocessing. 
return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} - def get_distillation_models(self) -> set[str]: + def get_reference_models(self) -> set[str]: models = set() for block in self.blocks.values(): - models.update(block.get_distillation_models()) + models.update(block.get_reference_models()) return models @classmethod diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py deleted file mode 100644 index 44c2d2088..000000000 --- a/fast_llm/layers/common/auxiliary_loss.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch - - -class AuxiliaryLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, scores: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa - ctx.grad = torch.full_like(aux_loss, grad) - return scores - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa - return grad_output, ctx.grad, None - - -@torch.compile -def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - return torch.mean(torch.logsumexp(logits, dim=-1) ** 2) - - -def z_loss( - logits: torch.Tensor, - z_loss_factor: float, - training: bool, - grad_scale: float | None = None, - losses: dict | None = None, - loss_name: str | None = None, - logits_scale_factor: float = 1.0, -) -> torch.Tensor: - if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) - if losses is not None and loss_name is not None: - losses[loss_name].append(loss.detach()) - if training and grad_scale is not None: - logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) - - return logits diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index f5abd1f6d..8f6e360fd 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -9,12 +9,11 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig -from fast_llm.layers.language_model.head import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -176,7 +175,7 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metr Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: un-scaled loss for reporting? Average loss over layers? # L2 loss - activation_loss_factor = self._config.activation_distillation_factor + activation_loss_factor = self._config.distillation_loss_weight # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. 
# Handle possible padding by using pre-computed activation mask @@ -249,8 +248,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metr hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None # Logging - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_loss.detach()) + if losses is not None and self._distillation_loss_name in losses: + losses[self._distillation_loss_name].append(activation_loss.detach()) # Per-layer metrics if metrics is not None: metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach() @@ -279,15 +278,15 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.mlp.preprocess(kwargs) # TODO: add layer_index - _activation_distillation_loss_name = "activation_distillation_loss" + _distillation_loss_name = "activation_distillation_loss" def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_definitions = [] - if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + if self._config.distillation_model is not None: loss_definitions.append( LossDef( - name=self._activation_distillation_loss_name, - formatted_name=_format_name(self._activation_distillation_loss_name), + name=self._distillation_loss_name, + formatted_name=self._distillation_loss_name, count=count, ) ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 875be5624..2f5990ccb 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -210,18 +210,13 @@ class DecoderBlockConfig(BlockConfig): desc="Name of the reference model to use for activation-level distillation.", hint=FieldHint.feature, ) - activation_distillation_factor: float = Field( - default=0.0, - desc="Factor to scale the activation-level distillation loss by.", + distillation_loss_weight: float = Field( + default=1.0, + desc="Weight used to scale the activation distillation loss.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - def _validate(self) -> None: - super()._validate() - if self.activation_distillation_factor > 0.0 and self.distillation_model is None: - raise ValueError("Activation distillation requires a distillation_model.") - @property def layer_class(self) -> "type[DecoderBlock]": from fast_llm.layers.decoder.block import DecoderBlock @@ -245,7 +240,5 @@ def get_layer( return_input=return_input, ) - def get_distillation_models(self) -> set[str]: - if self.distillation_model is not None and self.activation_distillation_factor > 0.0: - return {self.distillation_model} - return set() + def get_reference_models(self) -> set[str]: + return set() if self.distillation_model is None else {self.distillation_model} diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 5cc351dac..413a88ed6 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -9,14 +9,15 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped
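The `AuxiliaryLoss` autograd function imported here (and defined in `fast_llm/functional/autograd.py` earlier in this diff) is what lets the router z-loss and the activation-distillation loss contribute gradients without changing the activations passed forward. A minimal sketch of the pattern, assuming only `torch` plus that class; the tensors and the 0.01 coefficient are illustrative:

import torch

from fast_llm.functional.autograd import AuxiliaryLoss

scores = torch.randn(4, 8, requires_grad=True)
# Stand-in for an auxiliary objective computed from the same activations (e.g. a router z-loss).
aux_loss = torch.logsumexp(scores, dim=-1).pow(2).mean()
# Forward is an identity on `scores`; backward additionally feeds 0.01 into `aux_loss`,
# so `scores.grad` accumulates d(main_loss)/d(scores) + 0.01 * d(aux_loss)/d(scores).
output = AuxiliaryLoss.apply(scores, aux_loss, 0.01)
output.sum().backward()

In the router code below, the third argument is `self._config.z_loss_coefficient * grad_scale`, so the auxiliary gradient stays consistent with the scaling of the main loss.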
from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.decoder.mlp.mlp import MLPBase +from fast_llm.layers.language_model.loss.z_loss import z_loss from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -102,14 +103,13 @@ def _forward( # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: - logits = z_loss( - logits, - self._config.z_loss_coefficient, - self.training, - grad_scale=kwargs.get("grad_output"), - losses=losses, - loss_name=MLPLossNames.router_z_loss, - ) + is_training = (grad_scale := kwargs.get("grad_output")) is not None and self.training + if is_training or losses is not None: + loss = z_loss(logits) + if losses is not None: + losses[MLPLossNames.router_z_loss].append(loss.detach()) + if is_training: + logits = AuxiliaryLoss.apply(logits, loss, self._config.z_loss_coefficient * grad_scale) # Apply input_jitter if applicable: if self.training and self._config.jitter_eps > 0.0: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac2892..5f58024e0 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,15 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.loss.config import ( + LanguageModelLabelEntropyLossConfig, + LanguageModelLossConfig, + LanguageModelLossKwargs, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -19,21 +23,21 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): +class LanguageModelKwargs(LanguageModelLossKwargs): token_ids = "token_ids" position_ids = "position_ids" token_map = "token_map" sample_map = "sample_map" embedding_map = "embedding_map" # TODO: These are generic - labels = "labels" phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" loss_mask = "loss_mask" mask_inputs = "mask_inputs" +LM_HEAD_LOSS_NAME = "lm_head_loss" + + @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False @@ -135,44 +139,24 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LanguageModelLossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) - cross_entropy_splits: int | None = Field( - default=None, + # TODO: Option to chose whether to split in batch or sequence dimension? + # (Currently split merged batch and sequence, depends on `sequence_first`) + cross_entropy_splits: int = Field( + default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", - hint=FieldHint.feature, - ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -181,29 +165,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." 
- "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) def get_layer( self, @@ -237,21 +198,18 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()} super()._validate() - assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + assert LM_HEAD_LOSS_NAME not in self.losses @property def max_prediction_distance(self) -> int: return 1 - @property - def enable_dpo(self) -> bool: - return self.dpo_reference_model is not None + def get_reference_models(self) -> set[str]: + return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()} @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) @@ -337,3 +295,6 @@ def layer_class(self) -> "type[LanguageModel]": from fast_llm.layers.language_model.language_model import LanguageModel return LanguageModel + + def get_reference_models(self) -> set[str]: + return self.decoder.get_reference_models() | self.head.get_reference_models() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 9f3b6506f..e8c60ae9c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -7,28 +7,26 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.core.ops import gather_op, split_op +from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -from fast_llm.functional.dpo import compute_dpo_loss +from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( + LM_HEAD_LOSS_NAME, LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, ) from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div, get_unique +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -69,11 +67,6 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and ( - self._config.distillation_model is not None or self._config.dpo_reference_model is not None - ): - raise 
NotImplementedError("Multi-token prediction not supported with distillation or dpo.") - Assert.in_range(prediction_distance, 0, prediction_heads) self._prediction_distance = prediction_distance self._prediction_heads = prediction_heads @@ -84,19 +77,9 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_parallel_logits = self._sequence_parallel and not self._vocab_parallel - if self._config.cross_entropy_splits is not None and self._sequence_parallel: + if self._config.cross_entropy_splits > 1 and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED and torch.cuda.is_available(): - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -112,6 +95,31 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + self._losses = [ + loss_config.get_layer( + distributed_config, + self._get_full_loss_name(name), + self._prediction_distance, + self._prediction_heads, + self._vocab_parallel, + self._config.cross_entropy_splits, + self._config.logits_scale_factor, + self._loss_coefficient, + ) + for name, loss_config in self._config.losses.items() + ] + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? (loss) + return ( + 2 + * (config.forward + 2 * config.backward) + * (input_.global_shape if config.global_ else input_).numel() + * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) + ) + + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -137,8 +145,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. @@ -150,186 +156,112 @@ def forward( # MTP: Return shared_hidden to be used by the next head. return shared_hidden - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # TODO: Add marginal compute? 
(loss) - return ( - 2 - * (config.forward + 2 * config.backward) - * (input_.global_shape if config.global_ else input_).numel() - * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) - ) - def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - targets = self._get_targets(kwargs) - input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) + input_ = input_.detach().requires_grad_(self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) - # Transormers expect normalized outputs for the last transformer layer, + # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) - - grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 - ) - - output_weights = self.output_weights - loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses - ) - - if do_grad: - ln_output.backward(ln_output_grad) - return loss, input_.grad - else: + loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach(), kwargs, losses) + if ln_output_grad is None: return loss, None - - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: - # Loss mask for distillation. (Labels are already masked.) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: - # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None - - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. 
- targets = None - return targets - - def get_output_weights(self) -> list[torch.Tensor]: - return [self.output_weights] + ln_output.backward(ln_output_grad.view_as(ln_output)) + return loss, input_.grad - def _logits_cross_entropy_forward_backward_split( + def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, - weight: torch.Tensor, - grad_output: float, kwargs: dict, - losses: dict | None = None, + all_losses_dict: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + + if not self.training: + logits, _ = self._logits_loss_forward_backward_partial(input_, kwargs, return_logits=True) + # TODO: Make a proper way of returning the model output. + logits = logits.detach() + if kwargs.get("global_logits"): + if self._vocab_parallel: + logits = gather_op(logits, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + logits = gather_op( + logits, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = ( + logits.detach() ) - if targets is None: - # TODO: Make a proper way of returning the model output. - loss = loss.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - loss = gather_op(loss, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - loss = gather_op( - loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss - return None, None + return None, None + + input_ = input_.flatten(0, -2) + + if self._config.cross_entropy_splits == 1: + loss_dict, input_grad = self._logits_loss_forward_backward_partial(input_, kwargs) else: - loss = None - # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._config.cross_entropy_splits - logit_input = input_.flatten(0, -2) - if self.training: - logit_input_grad = torch.empty_like(logit_input) - else: - logit_input_grad = None - split_size = div( - get_unique(target.size(0) for target in targets if target is not None), - self._config.cross_entropy_splits, - ) + input_grad = torch.empty_like(input_) tensors_split = [ - [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + ( + [None] * self._config.cross_entropy_splits + if tensor is None + else tensor.chunk(self._config.cross_entropy_splits) + ) + for tensor in [input_, input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( - logit_input_, - targets_, - weight, - grad_output, + for split_index, (partial_input_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): + partial_loss_dict, grad_ = self._logits_loss_forward_backward_partial( + partial_input_, kwargs, + split_index=split_index, ) # TODO: Avoid copy with explicit out argument. 
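The `cross_entropy_splits` path above amounts to chunking the flattened hidden states, running the per-split loss computation, summing the partial losses, and writing per-chunk gradients into a pre-allocated buffer. A simplified sketch of that accumulation pattern, with a hypothetical `loss_fn` standing in for `_logits_loss_forward_backward_partial` and no parallelism or per-loss dictionary:

import typing

import torch

def chunked_loss(
    input_: torch.Tensor, splits: int, loss_fn: typing.Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor]:
    # `input_` is the flattened (tokens, hidden) tensor; gradients are filled chunk by chunk.
    input_grad = torch.empty_like(input_)
    total = None
    for chunk, grad_chunk in zip(input_.chunk(splits), input_grad.chunk(splits)):
        loss, grad = loss_fn(chunk)
        grad_chunk.copy_(grad)  # `chunk` returns views, so this writes into `input_grad` in place.
        total = loss if total is None else total + loss
    # Each partial loss is already a mean over its own chunk, so average over the number of splits.
    return total / splits, input_grad

The actual head additionally keeps a dict of per-loss values for each split and applies the per-loss weights when combining, as shown in the surrounding hunk.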
- if self.training: - logit_input_grad_.copy_(grad_) - loss = loss_ if loss is None else loss + loss_ - del grad_, loss_ - loss_count = (self._config.cross_entropy_splits or 1) * ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 + input_grad_.copy_(grad_) + if split_index == 0: + loss_dict = partial_loss_dict + else: + Assert.eq(partial_loss_dict.keys(), loss_dict.keys()) + for name in loss_dict: + loss_dict[name] += partial_loss_dict[name] + + total_loss = sum( + (loss_.weight / self._config.cross_entropy_splits) * loss_dict[loss_.name] + for loss_ in self._losses + if loss_.weight != 0.0 and loss_.name in loss_dict ) - if loss_count != 1: - loss.div_(loss_count) + if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._parallel_dim.group) - return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - - def _logits_cross_entropy_forward_backward( + all_reduce(total_loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + + if all_losses_dict is not None: + all_losses_dict[self._total_loss_name].append(total_loss) + if len(self._losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._losses): + for name, loss_value in loss_dict.items(): + if self._config.cross_entropy_splits != 1: + loss_value /= self._config.cross_entropy_splits + if self._sequence_parallel_logits: + # TODO: Async + all_reduce(loss_value, op=ReduceOp.AVG, group=self._parallel_dim.group) + all_losses_dict[name].append(loss_value) + + return total_loss, input_grad + + def _logits_loss_forward_backward_partial( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], - weight: torch.Tensor, - grad_output: float, kwargs: dict, - losses: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - group = self._parallel_dim.group if self._vocab_parallel else None + split_index: int = 0, + return_logits: bool = False, + ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, torch.Tensor | None]: logits, context = output_parallel_linear_forward( input_=input_, - weight=weight, + weight=self.output_weights, bias=None, - group=group, + group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - if self._config.logit_z_loss > 0.0: - logits = z_loss( - logits, - self._config.logit_z_loss, - self.training, - grad_output, - losses, - self._z_loss_name, - logits_scale_factor=self._config.logits_scale_factor, - ) - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] @@ -342,171 +274,46 @@ def _logits_cross_entropy_forward_backward( dims = None self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) - if targets is None: - return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets + if return_logits: + return logits, None - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + losses, grad = {}, None + for loss in self._losses: + # losses are returned unscaled but the grads are already scaled + loss_value, grad_ = loss.forward_backward( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - 
kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + kwargs, + split_index, ) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - lm_loss, lm_grad = None, None - - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) - - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) - else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None - - # TODO: de-allocate earlier. - del logits + losses[loss.name] = loss_value.detach() + if grad_ is not None: + # TODO: Accumulate grads in-place to reduce memory and compute overhead. + grad = grad_ if grad is None else grad + grad_ - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - - # TODO: Return individual losses? 
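The per-loss loop above accumulates gradients term by term. That is valid because the head minimizes a weighted sum of loss terms, so the logits gradient is exactly the sum of the individual terms' gradients (each term's weight enters through its grad_output, which is why the returned loss values stay unscaled). A small self-contained check of that identity, illustration only, using a cross-entropy term and a z-loss-style term with arbitrary weights:

import torch

logits = torch.randn(4, 10, requires_grad=True)
labels = torch.randint(0, 10, (4,))
w_ce, w_z = 1.0, 0.5

# Gradient of the weighted sum, computed in a single backward pass.
loss = w_ce * torch.nn.functional.cross_entropy(logits, labels) + w_z * (
    torch.logsumexp(logits, dim=-1) ** 2
).mean()
loss.backward()
combined_grad = logits.grad.clone()

# The same gradient, accumulated one term at a time as in the loop above.
accumulated = torch.zeros_like(logits)
for weight, term in (
    (w_ce, lambda x: torch.nn.functional.cross_entropy(x, labels)),
    (w_z, lambda x: (torch.logsumexp(x, dim=-1) ** 2).mean()),
):
    x = logits.detach().requires_grad_()
    (weight * term(x)).backward()
    accumulated += x.grad
assert torch.allclose(combined_grad, accumulated, atol=1e-5)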
- loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) - if self.training and losses is not None: - if dpo_loss is not None: - losses[self._dpo_loss_name].append(dpo_loss.detach()) - if self._config.distillation_model is not None and distillation_loss is not None: - losses[self._distillation_loss_name].append(distillation_loss.detach()) - if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) - - return loss, output_parallel_linear_backward(grad, context) if self.training else None - - @functools.cached_property - def _loss_name(self) -> str: - name = "language_model_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _z_loss_name(self) -> str: - name = "z_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_language_model_loss_name(self) -> str: - name = "distillation_language_model_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_loss_name(self) -> str: - name = "distillation_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name + return losses, output_parallel_linear_backward(grad, context) if self.training else None def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] - if self._config.logit_z_loss: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) - if self._config.enable_dpo: - loss_defs.append( - LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) - ) - - if self._config.distillation_model is not None: - loss_defs.append( + return [ + LossDef(name=self._total_loss_name, formatted_name=self._total_loss_name, count=count), + *( LossDef( - name=self._distillation_loss_name, - formatted_name=_format_name(self._distillation_loss_name), + name=loss.name, + formatted_name=loss.name, count=count, + dtype=DataType.float32, ) - ) - if self._config.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=self._distillation_language_model_loss_name, - formatted_name=_format_name(self._distillation_language_model_loss_name), - count=count, - ) - ) + for loss in self._losses + ), + ] - return loss_defs + def _get_full_loss_name(self, name) -> str: + return name if self._prediction_distance == 0 else f"{name}_{self._prediction_distance}" + + @functools.cached_property + def _total_loss_name(self) -> str: + return self._get_full_loss_name(LM_HEAD_LOSS_NAME) @property def heads(self): # For compatibility with MTP. 
return [self] - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: - tensors = [tensor for tensor in tensors if tensor is not None] - if len(tensors) > 1: - return sum(tensors) - elif len(tensors) == 1: - return tensors[0] - else: - raise RuntimeError() diff --git a/fast_llm/layers/language_model/loss/__init__.py b/fast_llm/layers/language_model/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py new file mode 100644 index 000000000..f531a1d46 --- /dev/null +++ b/fast_llm/layers/language_model/loss/config.py @@ -0,0 +1,168 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + + from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss + from fast_llm.layers.language_model.loss.entropy_loss import ( + LanguageModelDistillationLoss, + LanguageModelLabelEntropyLoss, + ) + from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss + + +class LanguageModelLossKwargs(BlockKwargs): + labels = "labels" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + advantages = "advantages" + old_log_probabilities = "old_log_probabilities" + + +@config_class(registry=True) +class LanguageModelLossConfig(Config): + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + def get_layer( + self, + distributed_config: DistributedConfig, + name: str, + prediction_distance: int = 0, + prediction_heads: int = 1, + vocab_parallel: bool = False, + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + ): + return self.loss_class( + self, + distributed_config, + name=name, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + vocab_parallel=vocab_parallel, + num_splits=num_splits, + logits_scale_factor=logits_scale_factor, + weight=weight, + ) + + @property + def loss_class(self) -> "type[LanguageModelLoss]": + raise NotImplementedError() + + def get_reference_models(self) -> set[str]: + return set() + + +@config_class(dynamic_type={LanguageModelLossConfig: "label"}) +class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): + _abstract: typing.ClassVar[bool] = False + + loss_type: EntropyLossType = Field( + default=EntropyLossType.cross_entropy, + desc="Type of loss to use.", + hint=FieldHint.core, + ) + + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, + desc="Loss implementation.", + hint=FieldHint.performance, + ) + + @property + def loss_class(self) -> "type[LanguageModelLabelEntropyLoss]": + from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelLabelEntropyLoss + + return LanguageModelLabelEntropyLoss + + +@config_class(dynamic_type={LanguageModelLossConfig: "distillation"}) +class LanguageModelDistillationLossConfig(LanguageModelLossConfig): + _abstract: typing.ClassVar[bool] = False + + 
loss_type: EntropyLossType = Field(
+        default=EntropyLossType.cross_entropy,
+        desc="Type of loss to use.",
+        hint=FieldHint.core,
+    )
+    implementation: EntropyLossImplementation = Field(
+        default=EntropyLossImplementation.auto,
+        desc="Loss implementation.",
+        hint=FieldHint.performance,
+    )
+    reference_model: str = Field(
+        default="teacher",
+        desc="Name of the reference model for knowledge distillation.",
+        hint=FieldHint.feature,
+    )
+    temperature: float = Field(
+        default=1.0,
+        hint=FieldHint.optional,
+        desc="Temperature for teacher softmax.",
+        valid=check_field(Assert.gt, 0.0),
+    )
+
+    @property
+    def loss_class(self) -> "type[LanguageModelDistillationLoss]":
+        from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelDistillationLoss
+
+        return LanguageModelDistillationLoss
+
+    def get_reference_models(self) -> set[str]:
+        return {self.reference_model}
+
+
+@config_class(dynamic_type={LanguageModelLossConfig: "dpo"})
+class LanguageModelDPOLossConfig(LanguageModelLossConfig):
+    """Direct Preference Optimization (DPO) loss for alignment."""
+
+    _abstract: typing.ClassVar[bool] = False
+
+    beta: float = Field(
+        default=1.0,
+        hint=FieldHint.core,
+        desc="Beta parameter for DPO loss (controls strength of preference optimization).",
+        valid=check_field(Assert.gt, 0.0),
+    )
+
+    reference_model: str = Field(
+        desc="Name of the reference model to use for DPO.",
+        hint=FieldHint.feature,
+    )
+
+    @property
+    def loss_class(self) -> "type[LanguageModelDPOLoss]":
+        from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss
+
+        return LanguageModelDPOLoss
+
+    def get_reference_models(self) -> set[str]:
+        return {self.reference_model}
+
+
+@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"})
+class LanguageModelZLossConfig(LanguageModelLossConfig):
+    """Z-loss regularization to prevent overconfidence."""
+
+    _abstract: typing.ClassVar[bool] = False
+
+    @property
+    def loss_class(self) -> "type[LanguageModelZLoss]":
+        from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss
+
+        return LanguageModelZLoss
diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py
new file mode 100644
index 000000000..15c4c788c
--- /dev/null
+++ b/fast_llm/layers/language_model/loss/dpo.py
@@ -0,0 +1,79 @@
+import typing
+
+import torch
+
+from fast_llm.layers.language_model.loss.config import LanguageModelDPOLossConfig, LanguageModelLossKwargs
+from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward
+
+
+class LanguageModelDPOLoss[ConfigType: LanguageModelDPOLossConfig](LanguageModelLoss[ConfigType]):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self._prediction_distance > 0:
+            raise NotImplementedError()
+        if self._num_splits > 1:
+            raise NotImplementedError()
+        if self._vocab_parallel:
+            raise NotImplementedError()
+
+    def forward_backward(
+        self,
+        logits: "torch.Tensor",
+        kwargs: dict[str, typing.Any],
+        split_index: int = 0,
+    ) -> "tuple[torch.Tensor, torch.Tensor | None]":
+
+        if self._get_loss_mask(kwargs, split_index) is not None:
+            raise NotImplementedError()
+
+        return loss_forward_backward(
+            self._get_grad_output(kwargs),
+            dpo_loss,
+            logits,
+            self._get_labels(kwargs, split_index),
+            self._get_reference_model_logits(self._config.reference_model, kwargs, split_index),
+            kwargs[LanguageModelLossKwargs.chosen_spans],
+
kwargs[LanguageModelLossKwargs.rejected_spans], + self._config.beta, + ) + + +def dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: list[list[tuple[int, int]]], + beta: float = 1.0, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + + if logits_scale_factor != 1.0: + # TODO: Make more efficient. + logits = logits * logits_scale_factor + + policy_log_probabilities = _get_target_log_probabilities(logits, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) + + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets) + reference_log_ratios = _get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) + + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? ======= + return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean() + + +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + + +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py new file mode 100644 index 000000000..3ae87d2e9 --- /dev/null +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -0,0 +1,86 @@ +import typing + +import torch + +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.layers.language_model.loss.config import ( + LanguageModelDistillationLossConfig, + LanguageModelLabelEntropyLossConfig, +) +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + + +def _get_imlementation( + default: EntropyLossImplementation = EntropyLossImplementation.auto, + loss_type: EntropyLossType = EntropyLossType.cross_entropy, + vocab_parallel: bool = False, +) -> EntropyLossImplementation: + # Vocab parallel requires fused. + if vocab_parallel: + assert default in (EntropyLossImplementation.auto, EntropyLossImplementation.fused) + return EntropyLossImplementation.fused + + # Triton only available for cross_entropy + if TritonConfig.TRITON_ENABLED and torch.cuda.is_available() and loss_type == EntropyLossType.cross_entropy: + return EntropyLossImplementation.triton if default == EntropyLossImplementation.auto else default + else: + assert default != EntropyLossImplementation.triton + + # Otherwise, use fused. 
+ return EntropyLossImplementation.fused if default == EntropyLossImplementation.auto else default + + +class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._implementation = _get_imlementation( + self._config.implementation, self._config.loss_type, self._vocab_parallel + ) + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return entropy_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + None, # Labels are already masked + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + implementation=self._implementation, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels, + entropy_loss_type=self._config.loss_type, + ) + + +class LanguageModelDistillationLoss[ConfigType: LanguageModelDistillationLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self._prediction_distance > 0: + raise NotImplementedError() + + self._implementation = _get_imlementation( + self._config.implementation, self._config.loss_type, self._vocab_parallel + ) + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return entropy_loss_forward_backward( + logits, + self._get_reference_model_logits(self._config.reference_model, kwargs, split_index), + self._get_loss_mask(kwargs, split_index), + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + implementation=self._implementation, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.logits, + entropy_loss_type=self._config.loss_type, + ) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py new file mode 100644 index 000000000..711560a8f --- /dev/null +++ b/fast_llm/layers/language_model/loss/loss.py @@ -0,0 +1,121 @@ +import abc +import typing + +import torch + +from fast_llm.config import Configurable +from fast_llm.core.ops import split_op +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs +from fast_llm.utils import Assert + + +class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + name: str, + prediction_distance: int = 0, + prediction_heads: int = 1, + vocab_parallel: bool = False, + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + ): + super().__init__(config) + Assert.in_range(prediction_distance, 0, prediction_heads) + self._prediction_distance = prediction_distance + self._prediction_heads = prediction_heads + self._name = name + self._num_splits = num_splits + self._logits_scale_factor = logits_scale_factor + self._weight = weight * self._config.weight + self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel + self._sequence_parallel = distributed_config.sequence_tensor_parallel and not 
self._vocab_parallel + self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + @abc.abstractmethod + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + pass + + @property + def name(self) -> str: + return self._name + + @property + def weight(self) -> float: + return self._weight + + def _prepare_target( + self, + target: torch.Tensor | None, + kwargs: dict[str, typing.Any], + split_index: int = 0, + *, + multi_token_format: bool = False, + ) -> torch.Tensor | None: + # MTP shift + if multi_token_format and self._prediction_heads > 1: + sequence_first: bool = kwargs[LanguageModelLossKwargs.sequence_first] + sequence_q_length = target.size(1 - sequence_first) + 1 - self._prediction_heads + target_slice = slice(self._prediction_distance, self._prediction_distance + sequence_q_length) + target = target[target_slice] if sequence_first else target[:, target_slice] + + # Flatten the batch and sequence dimensions. + target = target.flatten(0, 1) + + # Get the local chunk. + if self._sequence_parallel: + target = split_op(target, self._parallel_dim.group, 0) + + # Get the chunk for the current split. + if self._num_splits > 1: + target = target.chunk(self._num_splits)[split_index] + + return target + + def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: + grad_output = kwargs.get(LanguageModelKwargs.grad_output) + if grad_output is not None: + grad_output = ( + grad_output + * self._weight + / (self._parallel_dim.size if self._sequence_parallel else 1) + / self._num_splits + ) + return grad_output + + def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): + return self._prepare_target( + kwargs[LanguageModelLossKwargs.labels], kwargs, split_index, multi_token_format=True + ) + + def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index) + + def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0): + return self._prepare_target(kwargs[f"{reference_model}_logits"], kwargs, split_index) + + +def loss_forward_backward( + grad_output: float | None, fn: typing.Callable, input_: torch.Tensor, *args, **kwargs +) -> tuple[torch.Tensor, torch.Tensor | None]: + with torch.set_grad_enabled(grad_output is not None): + input_ = input_.detach().requires_grad_(grad_output is not None) + loss = fn(input_, *args, **kwargs) + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = input_.grad.detach().to(input_.dtype) + + return loss, grad diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py new file mode 100644 index 000000000..c94851bf2 --- /dev/null +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -0,0 +1,43 @@ +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward + + +class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Support vocab_parallel + if self._vocab_parallel: + raise NotImplementedError() + + def 
forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return loss_forward_backward( + self._get_grad_output(kwargs), + z_loss, + logits, + self._get_loss_mask(kwargs, split_index), + self._logits_scale_factor, + ) + + +@torch.compile +def z_loss( + logits: torch.Tensor, + loss_mask: "torch.Tensor | None" = None, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + """ + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + """ + out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + if loss_mask is not None: + out = out * loss_mask + return torch.mean(out) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index dc7f63299..a315beecc 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -11,7 +11,7 @@ from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.conversion.config import ( Apriel2TextCheckpointFormat, AprielHybridSSMCheckpointFormat, @@ -159,30 +159,14 @@ def _validate(self) -> None: Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length) # TODO: Avoid digging inside the model. - head = self.model.base_model.head - if isinstance(head, MultiTokenPredictionConfig): - prediction_heads = head.prediction_heads - head = head.head - else: - prediction_heads = 1 - - expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None} - expected_names.update(self.model.base_model.decoder.get_distillation_models()) - Assert.eq(self.reference_models.keys(), expected_names) + Assert.eq(self.reference_models.keys(), self.model.base_model.get_reference_models()) for reference_model in self.reference_models.values(): - reference_head = reference_model.model.base_model.head - if isinstance(reference_head, MultiTokenPredictionConfig): - reference_prediction_heads = reference_head.prediction_heads - reference_head = reference_head.heads - else: - reference_prediction_heads = 1 - Assert.geq(reference_prediction_heads, prediction_heads) - - Assert.none(reference_head.distillation_model) - Assert.none(reference_head.dpo_reference_model) - # TODO: Support more LM head features. - Assert.none(reference_head.cross_entropy_splits) + Assert.geq( + reference_model.model.base_model.head.max_prediction_distance, + self.model.base_model.head.max_prediction_distance, + ) + Assert.empty(reference_model.model.base_model.get_reference_models()) Assert.eq( reference_model.model.base_model.embeddings.vocab_parallel, self.model.base_model.embeddings.vocab_parallel, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8de6822fd..bd2932984 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -167,7 +167,7 @@ def preprocess_batch( if preprocessed_meta is None: preprocessed_meta = self.preprocess_meta(batch, phase) - distillation_models = self._config.decoder.get_distillation_models() + distillation_models = self._config.decoder.get_reference_models() # TODO: Support multiple distillation models? 
assert len(distillation_models) <= 1 reference_logits = [{} for _ in preprocessed_meta] @@ -273,7 +273,7 @@ def preprocess_batch( loss_mask[sample_index, begin:end] = False labels = torch.where(loss_mask, labels, -100) - if self._config.head.get_distillation_models(): # loss masks only used for distillation currently + if self._config.head.get_reference_models(): # loss masks only used for distillation currently # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py deleted file mode 100644 index 420316ce3..000000000 --- a/tests/functional/test_cross_entropy.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import sys -import tempfile -import traceback -import typing - -import pytest -import torch - -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -from fast_llm.utils import Assert - - -def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = "cuda" if torch.cuda.is_available() else "cpu" - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) - logits = torch.nn.functional.one_hot(target, num_columns) + logits_var - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) - logits = target + logits_var - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - return logits, target, loss_mask - - -def _compare_cross_entropy_outputs( - loss: torch.Tensor, - ref_loss: torch.Tensor, - has_grad: bool, - grad: torch.Tensor | None, - ref_grad: torch.Tensor | None, - threshold=1e-5, -): - Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) - if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) - else: - assert grad is None - assert ref_grad is None - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), - ( - (8192, 1.0, 1.0, False), # Simple - (5000, 1.0, 1.0, False), # Not a power of 2 - (5000, None, 1.0, False), # No grad - (5000, 1.0, 4.0, False), # Loss scaling - (5000, 4.0, 1.0, False), # Grad scaling - (5000, 1.0, 1.0, True), # Loss masking - (65536, 1.0, 1.0, False), # Max block size - (65537, 1.0, 1.0, False), # Above max block size - ), -) -@pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) -def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): - # TODO: Test tensor-parallel implementation. 
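For context on the `EntropyLossType` values exercised by the replacement tests: forward and reverse KL differ only in which distribution plays the role of the target in `torch.nn.functional.kl_div`. A hedged, standalone illustration (not code from this repository):

import torch

student_logits = torch.randn(8, 100)
teacher_logits = torch.randn(8, 100)
student_log_probs = torch.log_softmax(student_logits, dim=-1)
teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)

# Forward KL(teacher || student): the teacher is the target. Up to a term that does
# not depend on the student, this equals cross-entropy against teacher probabilities.
forward_kl = torch.nn.functional.kl_div(
    student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
)

# Reverse KL(student || teacher): the roles are swapped, the same direction as the
# torch reference in the deleted test below.
reverse_kl = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="batchmean", log_target=True
)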
- logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) - kwargs = { - "logits": logits, - "target": target, - "loss_mask": loss_mask, - "grad_output": grad_output, - "logits_scale_factor": logits_scale_factor, - "target_format": target_format, - } - # Torch serves as the reference implementation. - out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) - out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) - - # TODO: Why is the error so high with logit scaling? - threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 - _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) - - if not torch.cuda.is_available(): - return - assert TritonConfig.TRITON_ENABLED - if num_columns > 65536: - with pytest.raises(AssertionError): - cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - else: - out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - _compare_cross_entropy_outputs( - out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold - ) - - -def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): - # Manual reference: sum over vocab then average over valid tokens. - logits = logits.detach().requires_grad_() - per_sample = torch.nn.functional.kl_div( - torch.log_softmax(target.float(), dim=-1), - torch.log_softmax(logits.float(), dim=-1), - reduction="none", - log_target=True, - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - output = per_sample.mean() - output.backward() - return output, logits.grad - - -@pytest.mark.slow -# TODO: Support the same parameterization as above in the reference implementation. -@pytest.mark.parametrize("loss_masking", [False, True]) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) -def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) - out, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - target_format=TargetFormat.logits, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) - - -def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): - try: - torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) - fn_args[0](rank, torch.distributed.group.WORLD, *fn_args[1:]) - finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -def _spawn_dist(world_size: int, *fn_args): - """ - Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. 
- """ - with tempfile.NamedTemporaryFile(delete=False) as tmp: - init_method = f"file://{tmp.name}" - - try: - torch.multiprocessing.spawn( - _mp_worker, - args=(world_size, init_method, fn_args), - nprocs=world_size, - join=True, - start_method="spawn", - ) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) - - -def _compare_parallel_cross_entropy( - rank: int, - group: torch.distributed.ProcessGroup, - target_format: TargetFormat, - function: typing.Callable, - loss_masking: bool, -): - # Ensure all workers have the same inputs. - torch.manual_seed(0) - world_size = torch.distributed.get_world_size(group) - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - - out, grad = function( - logits=logits.chunk(world_size, 1)[rank], - target=target.chunk(world_size, 1)[rank], - loss_mask=loss_mask, - grad_output=1, - group=group, - target_format=target_format, - ) - - out_ref, grad_ref = function( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1, - target_format=target_format, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) - - -def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): - success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): - for target_format in (TargetFormat.logits,): - for loss_masking in [False, True]: - try: - _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) - except Exception: - print( - f" >>>>>> Failed {function.__name__}, target_format, use_mask={loss_masking}", file=sys.stderr - ) - traceback.print_exc() - success = False - if not success: - raise RuntimeError("Test failed") - - -@pytest.mark.slow -def test_distillation_losses(): - _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py new file mode 100644 index 000000000..9c06c1919 --- /dev/null +++ b/tests/functional/test_entropy_loss.py @@ -0,0 +1,179 @@ +import pathlib + +import pytest +import torch + +from fast_llm.engine.distributed.config import DistributedBackend +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.utils import Assert +from tests.utils.subtest import DistributedTestContext + + +def _get_cross_entropy_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, dtype=torch.float32, device=device) / 3 + loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) + loss_mask = None + else: + target = torch.randn(256, num_columns, dtype=torch.float32, device=device) + logits = target + logits_var + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) + return logits, target, loss_mask + + +def _compare_entropy_loss_outputs( + loss: 
torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, + threshold=1e-5, + loss_min_threshold=1e-6, +): + Assert.rms_close_relative(loss, ref_loss, threshold, loss_min_threshold) + if has_grad: + Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) + else: + assert grad is None + assert ref_grad is None + + +@pytest.mark.slow +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), + ( + (8192, 1.0, 1.0, False), # Simple + (5000, 1.0, 1.0, False), # Not a power of 2 + (5000, None, 1.0, False), # No grad + (5000, 1.0, 4.0, False), # Loss scaling + (5000, 4.0, 1.0, False), # Grad scaling + (5000, 1.0, 1.0, True), # Loss masking + (65536, 1.0, 1.0, False), # Max block size + (65537, 1.0, 1.0, False), # Above max block size + ), +) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + # TODO: Test tensor-parallel implementation. + logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) + kwargs = { + "logits": logits, + "target": target, + "loss_mask": loss_mask, + "grad_output": grad_output, + "logits_scale_factor": logits_scale_factor, + "target_format": target_format, + "entropy_loss_type": entropy_loss_type, + } + # Torch serves as the reference implementation. + out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) + out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) + + # TODO: Why is the error so high with loss masking for reverse KL? + _compare_entropy_loss_outputs( + out_fused, + out_torch, + grad_output is not None, + grad_fused, + grad_torch, + loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6, + ) + + if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): + # Triton implementation only supports cross-entropy. + return + assert TritonConfig.TRITON_ENABLED + if num_columns > 65536: + with pytest.raises(AssertionError): + entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton) + else: + out_triton, grad_triton = entropy_loss_forward_backward( + **kwargs, implementation=EntropyLossImplementation.triton + ) + _compare_entropy_loss_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) + + +def _entropy_loss_distributed( + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + loss_masking: bool, + group: torch.distributed.ProcessGroup, +): + # Ensure all workers have the same inputs. 
+    torch.manual_seed(0)
+    rank = group.rank()
+    world_size = group.size()
+    logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format)
+
+    kwargs = {
+        "loss_mask": loss_mask,
+        "grad_output": 1.0,
+        "target_format": target_format,
+        "implementation": EntropyLossImplementation.fused,
+        "entropy_loss_type": entropy_loss_type,
+    }
+    out_ref, grad_ref = entropy_loss_forward_backward(logits, target, **kwargs)
+
+    out, grad = entropy_loss_forward_backward(
+        logits.chunk(world_size, 1)[rank],
+        target if target_format == TargetFormat.labels else target.chunk(world_size, 1)[rank],
+        group=group,
+        **kwargs,
+    )
+    _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4)
+
+
+def _run_entropy_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path):
+    for entropy_loss_type in EntropyLossType:
+        for target_format in TargetFormat:
+            if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl:
+                continue
+            for loss_masking in [False, True]:
+                name = f"{entropy_loss_type}_{target_format}_{loss_masking}"
+                with test_context.subtest(base_path, name, 2) as subtest:
+                    if subtest.do_run:
+                        _entropy_loss_distributed(target_format, entropy_loss_type, loss_masking, test_context.group)
+
+
+@pytest.mark.slow
+def test_entropy_loss_distributed_dependency():
+    # Mock test so the distributed subtests are placed in the same dependency group.
+    pass
+
+
+@pytest.mark.slow
+@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"])
+def test_run_entropy_loss_distributed(run_parallel_script, result_path):
+    run_parallel_script(
+        _run_entropy_loss_distributed,
+        (result_path / "test_entropy_loss",),
+        world_size=2,
+        backend=DistributedBackend.gloo,
+        use_cuda=False,  # Disable device count check.
+    )
+
+
+# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in case of failure.
+# This should still run after `test_run_entropy_loss_distributed` +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +@pytest.mark.parametrize("loss_masking", (False, True)) +def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 76c0841d9..840e3846d 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -3,9 +3,9 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans @@ -61,20 +61,14 @@ def reference_dpo_loss( def test_dpo_loss(): - random_state = np.random.RandomState(0) - logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32).requires_grad_() - reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) - targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) + logits = torch.normal(0, 1, (10, 50, 100)) + reference_model_logits = torch.normal(0, 1, (10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) + spans = get_random_spans(np.full(10, 50), 0, 10) - spans = get_random_spans(np.full(10, 50), 0, 10, random_state) - - fastllm_loss, fast_llm_grad = compute_dpo_loss( - logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 - ) + fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2]) reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) - reference_loss.backward() Assert.rms_close(fastllm_loss, reference_loss, 1e-5) - Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @pytest.mark.parametrize("gated", [True, False]) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 1a607b246..1d08986f8 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -1,347 +1,325 @@ +import collections +import dataclasses import typing import pytest import torch -from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage - -def 
_reverse_kl_loss( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - teacher_softmax_temperature: float = 1.0, -): - scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) - - with torch.enable_grad(): - # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) - student_log_probs = torch.log_softmax(logits, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="batchmean", - log_target=True, - ) - else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() - return loss - - -def _lm_head( - input_: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - *, - # config:LanguageModelBaseConfig, - rms_weight: torch.Tensor, - logit_weight: torch.Tensor, - grad_output: float = 1.0, - logit_scale_factor: float = 1.0, - logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, -): - hidden = torch.rms_norm( - input_.to(rms_weight.dtype), - input_.shape[-1:], - rms_weight, - 1e-5, - ) - logits = torch.nn.functional.linear(hidden, logit_weight).float() - - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None - - if logit_scale_factor != 1.0: - logits *= logit_scale_factor - z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - else: - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss - - SEQUENCE_LENGTH = 200 BATCH_SIZE = 4 HIDDEN_SIZE = 256 VOCAB_SIZE = 500 -@pytest.mark.slow -@pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) -@pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), - ( - ({}, {}, False, 1), - ({}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"sequence_first": True}, {}, False, 1), - ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), - ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), - ({"tied_embedding_weight": True}, {}, False, 1), - ({}, {}, False, 2), - ({}, {}, True, 1), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": 
DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - True, - 1, - ), - ), -) -def test_lm_head( - cross_entropy_impl: CrossEntropyImpl, - config_dict: dict[str, typing.Any], - distributed_config_dict: dict[str, typing.Any], - loss_masking: bool, - prediction_heads: int, -): - if cross_entropy_impl in (CrossEntropyImpl.auto, CrossEntropyImpl.triton) and not torch.cuda.is_available(): - pytest.skip("Cuda is not available") - head_config = { - "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm", "implementation": "auto" if torch.cuda.is_available() else "torch"}, - } - config = GPTBaseModelConfig.from_dict( - { - "decoder": {"num_blocks": 0}, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "head": ( - head_config - if prediction_heads == 1 - else { - "type": "multi_token_prediction", - "head": head_config, - "prediction_heads": prediction_heads, - } - ), - "hidden_size": HIDDEN_SIZE, - }, - config_dict, - update_type=UpdateType.update, - ) - head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head +@dataclasses.dataclass +class LMHeadTestConfig: + name: str + label_loss: bool | float = False + distillation_loss: bool | float = False + z_loss: bool | float = False + logits_scale_factor: float = 1.0 + compute_dtype: DataType = DataType.float32 + full_precision_residual: bool = False + sequence_first: bool = False + loss_masking: bool = False + prediction_heads: int = 1 + tied_embedding_weight: bool = False + num_splits: int = 1 + + @property + def actual_label_loss(self): + return ( + True + if self.label_loss is False and self.distillation_loss is False and self.z_loss is False + else self.label_loss + ) - model, distributed = get_base_model( - GPTModelConfig.from_dict( + def get_config(self) -> GPTModelConfig: + head_config = { + "normalization": {"type": "rms_norm"}, + "logits_scale_factor": self.logits_scale_factor, + "cross_entropy_splits": self.num_splits, + } + losses = {} + if self.label_loss is not False: + losses["label"] = {"type": "label"} + if isinstance(self.label_loss, float): + losses["label"]["weight"] = self.label_loss + if self.distillation_loss is not False: + losses["distillation"] = {"type": "distillation", "reference_model": "distillation"} + if isinstance(self.distillation_loss, float): + losses["distillation"]["weight"] = self.distillation_loss + if self.z_loss is not False: + losses["z_loss"] = {"type": "z_loss"} + if isinstance(self.z_loss, float): + losses["z_loss"]["weight"] = self.z_loss + if losses: + head_config["losses"] = losses + + return GPTModelConfig.from_dict( { - "base_model": config, - "distributed": {**distributed_config_dict, "use_cuda": torch.cuda.is_available()}, + "base_model": { + "decoder": {"num_blocks": 0}, + "embeddings": {"vocab_size": VOCAB_SIZE, "full_precision_residual": self.full_precision_residual}, + "head": ( + head_config + if self.prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": self.prediction_heads, + } + ), + "hidden_size": HIDDEN_SIZE, + "tied_embedding_weight": self.tied_embedding_weight, + }, + "distributed": {"compute_dtype": self.compute_dtype, "use_cuda": torch.cuda.is_available()}, }, ) - ) - sequence_first = config.sequence_first or ( - head_config.cross_entropy_splits is not None and 
head_config.cross_entropy_splits > 1 - ) - input_ = torch.randn( - (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), - dtype=( - distributed.config.optimization_dtype.torch - if config.embeddings.full_precision_residual - else distributed.config.compute_dtype.torch - ), - device=distributed.device, - requires_grad=True, - ) - label_shape = ( - (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) - ) - if loss_masking: - loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) - else: - loss_mask = None - kwargs = { - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.grad_output: 1.0, - } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, + def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: + device = "cuda" if torch.cuda.is_available() else "cpu" + input_ = torch.randn( + ( + (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE) + ), + dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch), + device=device, + requires_grad=True, ) - if loss_mask is not None: - target *= loss_mask - - kwargs[LanguageModelKwargs.labels] = target - else: - assert config.head.max_prediction_distance == 1 - target = torch.randn( - input_.shape[:-1] + (VOCAB_SIZE,), - dtype=input_.dtype, - device=distributed.device, + label_shape = ( + (SEQUENCE_LENGTH + self.prediction_heads - 1, BATCH_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + self.prediction_heads - 1) ) - kwargs[f"{head_config.distillation_model}_logits"] = target - if loss_mask is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask + kwargs: dict[str, typing.Any] = { + AttentionKwargs.sequence_first: self.sequence_first, + AttentionKwargs.grad_output: 1.0, + } + if self.loss_masking: + kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device) + if self.actual_label_loss is not False: + labels = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=device, + ) + if LanguageModelKwargs.loss_mask in kwargs: + labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) + kwargs[LanguageModelKwargs.labels] = labels - if config.tied_embedding_weight or config.head.max_prediction_distance > 1: - logit_weight = torch.nn.Parameter( + if self.distillation_loss is not False: + assert self.prediction_heads == 1 + kwargs[f"distillation_logits"] = torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=device, + ) + return input_, kwargs + + def get_reference_outputs( + self, + head: LanguageModelHead, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + tied_logit_weight: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: + # Get reference outputs and grads + logit_weight = ( + (head.output_weights if tied_logit_weight is None else tied_logit_weight).detach().requires_grad_() + ) + normalization_weight = head.final_norm.weight.detach().requires_grad_() + input_ = input_.detach().requires_grad_() + + hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5) + logits = 
+
+    def get_reference_outputs(
+        self,
+        head: LanguageModelHead,
+        input_: torch.Tensor,
+        kwargs: dict[str, typing.Any],
+        tied_logit_weight: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
+        # Get reference outputs and grads
+        logit_weight = (
+            (head.output_weights if tied_logit_weight is None else tied_logit_weight).detach().requires_grad_()
+        )
+        normalization_weight = head.final_norm.weight.detach().requires_grad_()
+        input_ = input_.detach().requires_grad_()
+
+        hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5)
+        logits = torch.nn.functional.linear(hidden, logit_weight).float()
+
+        if self.logits_scale_factor is not None:
+            logits = logits * self.logits_scale_factor
+
+        total_loss = 0
+        losses = {}
+
+        if self.actual_label_loss is not False:
+            if self.sequence_first:
+                labels = kwargs[LanguageModelKwargs.labels][
+                    head._prediction_distance : head._prediction_distance + logits.size(0)
+                ]
+            else:
+                labels = kwargs[LanguageModelKwargs.labels][
+                    :, head._prediction_distance : head._prediction_distance + logits.size(1)
+                ]
+            label_loss = torch.nn.functional.cross_entropy(
+                logits.flatten(0, -2), labels.flatten(), reduction="none"
+            ).mean()
+            losses["label"] = label_loss.detach()
+            total_loss = total_loss + float(self.actual_label_loss) * label_loss
+
+        if self.distillation_loss is not False:
+            distillation_loss = torch.nn.functional.cross_entropy(
+                logits.flatten(0, -2),
+                torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1),
+                reduction="none",
+            )
+            if LanguageModelKwargs.loss_mask in kwargs:
+                distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask].flatten()
+            distillation_loss = distillation_loss.mean()
+            losses["distillation"] = distillation_loss.detach()
+            total_loss = total_loss + float(self.distillation_loss) * distillation_loss
+
+        if self.z_loss is not False:
+            z_loss = torch.logsumexp(logits, dim=-1) ** 2
+            if LanguageModelKwargs.loss_mask in kwargs:
+                z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask]
+            z_loss = z_loss.mean()
+            losses["z_loss"] = z_loss.detach()
+            total_loss = total_loss + float(self.z_loss) * z_loss
+
+        total_loss.backward()
+
+        if len(losses) > 1:
+            losses[LM_HEAD_LOSS_NAME] = total_loss.detach()
+        else:
+            losses = {LM_HEAD_LOSS_NAME: total_loss.detach()}
+
+        if head._prediction_distance > 0:
+            losses = {f"{name}_{head._prediction_distance}": loss for name, loss in losses.items()}
+
+        return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses
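The reference distillation term above is soft-target cross-entropy against the teacher's softmax, and the z-loss term is the (optionally masked) mean of logsumexp(logits)**2. Soft-target cross-entropy differs from KL(teacher || student) only by the teacher entropy, which is constant with respect to the student, so both give the same gradients; a quick standalone check of that identity (not part of the test file):

import torch

student = torch.randn(2, 8)
teacher = torch.randn(2, 8)
p = torch.softmax(teacher, -1)
ce = torch.nn.functional.cross_entropy(student, p, reduction="none")  # soft-target cross-entropy
kl = torch.nn.functional.kl_div(torch.log_softmax(student, -1), p, reduction="none").sum(-1)
teacher_entropy = -(p * p.log()).sum(-1)
assert torch.allclose(ce, kl + teacher_entropy, atol=1e-5)  # CE = KL + H(teacher)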
+
+
+_lm_head_test_configs = []
+
+
+def _add_configs(base_name: str, **kwargs):
+    # Loss masking and splits are important and error-prone, so we test them for all scenarios.
+    for loss_masking in (False, True):
+        for num_splits in (1, 2):
+            _lm_head_test_configs.append(
+                LMHeadTestConfig(
+                    f"{base_name}{"_masked" if loss_masking else ""}{"" if num_splits == 1 else "_split"}",
+                    loss_masking=loss_masking,
+                    num_splits=num_splits,
+                    **kwargs,
+                )
+            )
+
+
+_add_configs("default")
+_add_configs("bfloat16", compute_dtype=DataType.bfloat16)
+_add_configs("full_precision_residual", full_precision_residual=True)
+_add_configs("sequence_first", sequence_first=True)
+_add_configs("logit_scaling", logits_scale_factor=5.0)
+_add_configs("tied_embedding_weight", tied_embedding_weight=True)
+_add_configs("multi_token_prediction", prediction_heads=2)
+_add_configs("label_loss", label_loss=True)
+_add_configs("distillation_loss", distillation_loss=True)
+_add_configs("z_loss", z_loss=True)
+_add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True)
+_add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5)
+_add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0)
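Each _add_configs call above therefore contributes four parametrized cases (loss masking off/on crossed with one or two cross-entropy splits). For example, the "default" scenario expands to the ids default, default_masked, default_split and default_masked_split, which this standalone loop reproduces:

for loss_masking in (False, True):
    for num_splits in (1, 2):
        print(f"default{'_masked' if loss_masking else ''}{'' if num_splits == 1 else '_split'}")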
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "test_config",
+    [
+        pytest.param(_lm_head_test_config, id=_lm_head_test_config.name)
+        for _lm_head_test_config in _lm_head_test_configs
+    ],
+)
+def test_lm_head(test_config):
+    model_config = test_config.get_config()
+    model, distributed = get_base_model(model_config)
+    input_, kwargs = test_config.get_inputs()
+
+    tied_logit_weight = (
+        torch.nn.Parameter(
             torch.empty(
                 VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device
-            ).normal_(config.hidden_size**-0.5)
+            ).normal_(HIDDEN_SIZE**-0.5)
         )
-    else:
-        logit_weight = None
+        if test_config.tied_embedding_weight or test_config.prediction_heads > 1
+        else None
+    )
 
     for prediction_distance, head in enumerate(model.head.heads):
         # Prepare the LM head
         Assert.custom(isinstance, head, LanguageModelHead)
         Assert.eq(head._prediction_distance, prediction_distance)
-        is_duplicate = config.tied_embedding_weight or prediction_distance > 0
+        is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0
         stage = get_stage(
             [head],
             distributed,
             tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [],
-            tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {},
+            tied_parameter_duplicate_buffers=(
+                {head.output_weights.tensor_name: tied_logit_weight} if is_duplicate else {}
+            ),
             # Names must be kept as-is for tied weights.
             set_names=False,
         )
 
-        # Get reference outputs and grads
-        if is_duplicate:
-            logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan"))
-            logit_weight.param_grad_is_zero = True
-        else:
-            logit_weight = head.output_weights
-
-        ref_input = input_.detach().requires_grad_()
-        ref_rms_weight = head.final_norm.weight.detach().requires_grad_()
-        ref_logit_weight = logit_weight.detach().requires_grad_()
-
-        ref_loss, ref_z_loss = _lm_head(
-            ref_input,
-            (
-                target[prediction_distance : prediction_distance + SEQUENCE_LENGTH]
-                if sequence_first
-                else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH]
-            ),
-            loss_mask,
-            rms_weight=ref_rms_weight,
-            logit_weight=ref_logit_weight,
-            logit_scale_factor=head_config.logits_scale_factor,
-            logit_z_loss=head_config.logit_z_loss,
-            distillation_loss_implementation=head_config.distillation_loss_implementation,
+        ref_total_loss, ref_input_grad, ref_logit_weight_grad, ref_normalization_weight_grad, ref_losses = (
+            test_config.get_reference_outputs(
+                head, input_, kwargs, tied_logit_weight if prediction_distance > 0 else None
+            )
         )
 
         # Prepare LM head inputs
         if head._is_last_head:
-            head_input = input_
-            output_grad = ref_input.new_full((), float("nan"))
+            head_input = input_.detach().requires_grad_()
+            output_grad = input_.new_full((), float("nan"))
         else:
             shared_hidden = torch.randn_like(input_)
             head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_()
             output_grad = torch.randn_like(shared_hidden)
 
-        loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss"
-        loss_keys = {loss_name}
-        if ref_z_loss is not None:
-            loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss")
-        if head_config.distillation_model is not None:
-            loss_keys.add("distillation_loss")
-            if head_config.language_model_loss_factor > 0:
-                loss_keys.add("distillation_language_model_loss")
-
-        Assert.eq(
-            {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()},
-            {loss_key: 1 for loss_key in loss_keys},
-        )
-        losses = {key: [] for key in loss_keys}
+        if is_duplicate:
+            logit_weight = tied_logit_weight
+            logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan"))
+            logit_weight.param_grad_is_zero = True
+        else:
+            logit_weight = head.output_weights
+
+        losses = collections.defaultdict(list)
         output, context = stage.forward(head_input, kwargs, losses)
         stage.backward(output_grad, context)
 
         threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3
         min_threshold = (
             1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4
-        ) * head_config.logits_scale_factor
+        ) * test_config.logits_scale_factor
 
-        Assert.eq(losses.keys(), loss_keys)
-        Assert.eq(len(losses[loss_name]), 1)
-        if ref_z_loss is not None:
-            Assert.eq(len(losses["z_loss"]), 1)
-            Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold)
+        Assert.eq(losses.keys(), ref_losses.keys())
+        for name, loss in losses.items():
+            assert len(loss) == 1, name
+        losses = {name: loss[0] for name, loss in losses.items()}
 
-        Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold)
+        for name, loss in losses.items():
+            Assert.rms_close_relative(loss, ref_losses[name], threshold, min_threshold, msg=name)
 
         if head._is_last_head:
-            Assert.all_equal(output, losses[loss_name][0])
+            # Assert.all_equal(output, losses[lm_head_loss_name][0])
             input_grad = head_input.grad
         else:
             Assert.all_equal(output, shared_hidden)
             shared_hidden_grad, input_grad = head_input.grad.unbind()
             Assert.all_equal(shared_hidden_grad, output_grad)
-        Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold)
-        Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold)
-        Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold)
+        Assert.rms_close_relative(input_grad, ref_input_grad, threshold, min_threshold)
+        Assert.rms_close_relative(
+            head.final_norm.weight.grad_buffer, ref_normalization_weight_grad, threshold, min_threshold
+        )
+        Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight_grad, threshold, min_threshold)
diff --git a/tests/test_config.py b/tests/test_config.py
index 4020b6fbc..2e900cb14 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -148,12 +148,15 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
                 },
                 "num_blocks": 12,
             },
+            "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}},
             "hidden_size": 512,
             "tied_embedding_weight": False,
             "peft": {"freeze_others": False},
         }
     else:
         expected_config["base_model"] = base_model_update
+        # Added by default.
+        expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}}
 
     check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config))
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 84466fe29..5e7526377 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -564,7 +564,9 @@ def update_and_add_testing_config(
     "mistral",
     "mistral_distill_logits",
     updates={
-        ("model", "base_model", "head", "distillation_model"): "teacher",
+        ("model", "base_model", "head", "losses"): {
+            "distillation": {"type": "distillation", "loss_type": "reverse_kl", "reference_model": "teacher"},
+        },
         ("batch", "use_loss_masking_spans"): True,
         ("reference_models"): {
             "teacher": {
@@ -587,34 +589,14 @@ def update_and_add_testing_config(
     skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"),
 )
 
-update_and_add_testing_config(
-    "mistral_distill_logits",
-    "mistral_reverse_kl",
-    updates={
-        ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl",
-    },
-    megatron_args=None,
-    checkpoint_format=MistralCheckpointFormat,
-    groups={
-        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
-        ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant,
-        ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
-        ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
-        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
-        ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: fp16, tp2, stp2, stp2_ce4
-    },
-    compare_factor=2,
-    # Modes not supported with reference models
-    skip_tests=("sdp", "ms", "pp"),
-)
 update_and_add_testing_config(
     "mistral_distill_logits",
     "mistral_distill_activations",
     updates={
-        ("model", "base_model", "head", "distillation_loss_factor"): 0.001,
+        ("model", "base_model", "head", "losses", "distillation", "weight"): 0.001,
         ("model", "base_model", "decoder", "block", "distillation_model"): "teacher",
-        ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1,
+        ("model", "base_model", "decoder", "block", "distillation_loss_weight"): 0.1,
         ("reference_models"): {
             "teacher": {
                 "model": {"base_model": copy.deepcopy(_mistral_base_model)},
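As a reading aid, the update paths above leave the activation-distillation variant with a head section roughly like the sketch below (the distillation entry is inherited from mistral_distill_logits and only its weight is overridden; treat this as an approximation, not the exact serialized config):

head_losses_update = {
    "losses": {
        "distillation": {
            "type": "distillation",
            "loss_type": "reverse_kl",
            "reference_model": "teacher",
            "weight": 0.001,
        },
    },
}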
diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py
index 4fea1fbba..b6764c0e2 100644
--- a/tests/utils/subtest.py
+++ b/tests/utils/subtest.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import math
+import os
 import pathlib
 import sys
 import time
@@ -27,11 +28,13 @@ def __init__(
         timeout: float = 20.0,
         init_method: str = "env://",
         backend: DistributedBackend = DistributedBackend.nccl,
+        use_cuda: bool = True,
     ) -> None:
         self._do_capture = do_capture
         self._timeout = timeout
         self._init_method = init_method
         self._backend = backend
+        self._use_cuda = use_cuda
 
     def __enter__(self):
         if self._do_capture:
@@ -40,7 +43,7 @@ def __enter__(self):
             )
 
         self._pool = ProcessGroupPool(
-            timeout=self._timeout, init_method=self._init_method, backend=self._backend
+            timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cuda=self._use_cuda
         ).__enter__()
         self._rank = self._pool.rank
         self._world_size = self._pool.world_size
@@ -48,12 +51,12 @@ def __enter__(self):
         self._configure_logging()
         self._group = self._pool.get_process_group(range(self._world_size), self._rank)
         # TODO: Barriers needed?
-        safe_barrier(self._group, "start")
+        safe_barrier(self._group, "start", device=self._pool.device)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         # Final barrier to ensure everything is done before torchrun potentially kills workers.
-        safe_barrier(self._group, "testing end")
+        safe_barrier(self._group, "testing end", device=self._pool.device)
         # Let pytest know how things went.
         # These should already be reported above, we repeat for convenience.
         if self._failures:
@@ -75,6 +78,10 @@ def rank(self) -> int:
     def world_size(self) -> int:
         return self._world_size
 
+    @property
+    def group(self) -> torch.distributed.ProcessGroup:
+        return self._group
+
 
 class DistributedSubtestContext:
     def __init__(
         self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int
@@ -83,7 +90,7 @@ def __init__(
        self._path = base_path / name
         self._name = name
         self._num_gpus = num_gpus
-        self._skip = self._test_context._world_size < self._num_gpus
+        self._skip = self._test_context._world_size < self._num_gpus and self._test_context._use_cuda
         self._do_run = self._test_context._rank < num_gpus and not self._skip
         self._do_capture = self._test_context._do_capture and self._do_run
         self._success = False
@@ -131,10 +138,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         if (group := self._test_context._group) is not None:
             # Barrier so `allreduce_scalar` doesn't go crazy in case of desync.
-            safe_barrier(group, self._name)
-            self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size()
+            safe_barrier(group, self._name, device=self._test_context._pool.device)
+            self._success = (
+                allreduce_scalar(
+                    self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device
+                )
+                == group.size()
+            )
 
-        if self._do_capture:
+        if self._do_capture and torch.cuda.is_available():
             # Free resources to limit memory usage.
             report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True)
             report["duration"] = time.perf_counter() - self._start
@@ -233,13 +245,14 @@ def parallel_worker(
     init_method: str,
     backend: DistributedBackend,
    do_capture: bool,
+    use_cuda: bool,
     fn: typing.Callable,
     fn_args: typing.Sequence[typing.Any],
 ):
     DistributedConfig.default_rank = rank
     DistributedConfig.default_world_size = world_size
     DistributedConfig.default_local_world_size = world_size
-    with DistributedTestContext(do_capture, 60, init_method, backend) as test_context:
+    with DistributedTestContext(do_capture, 60, init_method, backend, use_cuda) as test_context:
         fn(test_context, *fn_args)
 
 
@@ -251,14 +264,17 @@ def do_run_parallel_script(
     world_size: int,
     timeout: float = 240,
     backend: DistributedBackend = DistributedBackend.nccl,
+    use_cuda: bool = True,  # Use a CPU device in the process group pool; may be used to disable the device count check.
 ):
+    if "PYTHONHASHSEED" not in os.environ:
+        os.environ["PYTHONHASHSEED"] = "0"
     if do_capture:
         logger.warning(
             "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
         )
     torch.multiprocessing.spawn(
         parallel_worker,
-        args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args),
+        args=(world_size, f"tcp://localhost:{port}", backend, do_capture, use_cuda, fn, fn_args),
         nprocs=world_size,
         join=False,
     ).join(timeout, grace_period=5)
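Pinning PYTHONHASHSEED before torch.multiprocessing.spawn makes Python's string hashing deterministic and, more importantly, identical across the spawned workers, whereas by default each interpreter draws its own hash seed. A standalone illustration of the effect (hypothetical snippet, not part of the diff):

import os
import subprocess
import sys

code = 'print(hash("start"))'
env = {**os.environ, "PYTHONHASHSEED": "0"}
# With the seed pinned, every child interpreter reports the same hash for the same string.
first = subprocess.run([sys.executable, "-c", code], env=env, capture_output=True, text=True).stdout
second = subprocess.run([sys.executable, "-c", code], env=env, capture_output=True, text=True).stdout
assert first == second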