Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/lightning/pytorch/loops/optimization/automatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _clone_loss(self) -> None:
self.loss = self.closure_loss.detach().clone()

@classmethod
def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: int = 1) -> "ClosureResult":
def from_training_step_output(
cls, training_step_output: STEP_OUTPUT, normalize: int = 1, num_global_valid_tokens: Optional[int] = None
) -> "ClosureResult":
closure_loss, extra = None, {}

if isinstance(training_step_output, Mapping):
Expand All @@ -80,7 +82,10 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize:
if closure_loss is not None:
# accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
# note: avoid in-place operation `x /= y` here on purpose
closure_loss = closure_loss / normalize
if num_global_valid_tokens is not None:
closure_loss = closure_loss / num_global_valid_tokens
elif normalize > 1:
closure_loss = closure_loss / normalize

return cls(closure_loss, extra=extra)

Expand Down Expand Up @@ -315,6 +320,7 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:

"""
trainer = self.trainer
num_global_valid_tokens = kwargs.pop("num_global_valid_tokens", None)

training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
Expand All @@ -326,4 +332,6 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
" place."
)

return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
return self.output_result_cls.from_training_step_output(
training_step_output, trainer.accumulate_grad_batches, num_global_valid_tokens=num_global_valid_tokens
)
41 changes: 41 additions & 0 deletions src/lightning/pytorch/loops/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import itertools
import math
import time
from collections import OrderedDict
from dataclasses import dataclass
from itertools import islice
from typing import Any, Optional, Union

import torch
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
self._batches_that_stepped: int = 0
self._restart_stage = RestartStage.NONE
self._skip_next_val = False
self._num_global_valid_tokens: Optional[int] = None

@property
def total_batch_idx(self) -> int:
Expand Down Expand Up @@ -278,6 +281,12 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
StopIteration: When the epoch is canceled by the user returning -1

"""
# create a peekable iterator to look ahead without consuming the original data_fetcher
iterator = data_fetcher.iterator
assert iterator is not None
it1, self._peekable_iter = itertools.tee(iterator)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this definitely seems like the smart way to implement this kind of solution

data_fetcher.iterator = it1

if self.restarting and self._should_check_val_fx(data_fetcher):
if self.val_loop.restarted_mid_evaluation:
# Go back and finish running validation
Expand Down Expand Up @@ -346,6 +355,38 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
if not using_dataloader_iter
else OrderedDict(any=dataloader_iter)
)

# Count valid tokens across the global batch when using grad accumulation with a cross-entropy loss
# Only calculate at the first batch of accumulation window and then reuse
if (
trainer.lightning_module.automatic_optimization
and trainer.accumulate_grad_batches > 1
and batch_idx % trainer.accumulate_grad_batches == 0
):
# require all batches in accumulation window to be properly formatted
total_valid_tokens = 0
all_formatted_batches = True
# Take next N batches without consuming the original data_fetcher
peek_batches = list(islice(self._peekable_iter, trainer.accumulate_grad_batches))
for batch in peek_batches:
# unwrap Lightning's list/tuple wrapper
if isinstance(batch, (list, tuple)):
batch = batch[0]
# require batch to be instance of dict and has labels, otherwise break
if not isinstance(batch, dict):
all_formatted_batches = False
break
labels = batch.get("labels")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the core assumption to get this working is that the user has formatted their batches such that each batch has a labels tensor?

Copy link
Contributor Author

@Sohaib-Ahmed21 Sohaib-Ahmed21 Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this will be listed in the relevant docs (once the approach is confirmed), and it's the same as the `labels` key in transformers.

# break if labels missing or None
if labels is None:
all_formatted_batches = False
break
# safe to process
labels = torch.as_tensor(labels)
total_valid_tokens += int((labels != -100).sum().item())
self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
Comment on lines +386 to +387
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do wonder if this is such a special case that it would be better for the user to provide the information rather than Lightning trying to calculate it. For example, what if the masking token is not -100 but -1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this, I thought that if a user is passing labels for this special case, then we could also state in the same relevant docs that the masking token must be -100, so users don't miss it. Thoughts, please?


kwargs["num_global_valid_tokens"] = self._num_global_valid_tokens
with trainer.profiler.profile("run_training_batch"):
if trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
Expand Down
Loading