Skip to content

Conversation

@Sohaib-Ahmed21
Copy link
Contributor

@Sohaib-Ahmed21 Sohaib-Ahmed21 commented Nov 27, 2025

What does this PR do?

This PR resolves the issue where loss normalization during gradient accumulation is incorrect in case of cross entropy loss.

Key Changes:

  • Implemented a peekable iterator that inspects all microbatches in the global batch before the first forward pass in accumulation window to determine the total number of valid tokens.

  • Added support for a labels key in each batch, similar to the labels field used in Hugging Face Transformers.

  • Each microbatch’s loss is now divided by the total valid-token count of the global batch, ensuring correct scaling during gradient accumulation.

  • Documentation updates will follow to guide users on required batch structure and loss settings.

Fixes #20350

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • yes
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
    yes will do so.
  • Did you write any new necessary tests? (not for typos and docs)
    I think so not needed yet.
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21386.org.readthedocs.build/en/21386/

…obal batch at start of accumulation window i.e. in first micro-batch.
… entropy loss and properly formatted batches provided by user while using gradient accumulation
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Nov 27, 2025
if not isinstance(batch, dict):
all_formatted_batches = False
break
labels = batch.get("labels")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the core assumption to get this working is that the user have formatted their batches such that each batch has a labels tensor?

Copy link
Contributor Author

@Sohaib-Ahmed21 Sohaib-Ahmed21 Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this will be listed in relevant docs (once the approach is confirmed) and its the same as labels key in transformers.

Comment on lines +386 to +387
total_valid_tokens += int((labels != -100).sum().item())
self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i do wonder if this is such a special case that it is better for the user to provide the information compared to lightning trying to calculate it. For example, what if the masking token is not -100 but -1 for example?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this, I thought that if a user is passing labels for this special case, then we also list in the same relevant docs that masking token should be -100 for this case so users don't miss it. Thoughts please?

# create a peekable iterator to look ahead without consuming the original data_fetcher
iterator = data_fetcher.iterator
assert iterator is not None
it1, self._peekable_iter = itertools.tee(iterator)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this definitely seems like the smart way to implement this kind of solution

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gradient accumulation calcluation may be incorrect

2 participants