Gradient accumulation fix in cross entropy loss #21386
base: master
Conversation
…obal batch at start of accumulation window i.e. in first micro-batch.
… entropy loss and properly formatted batches provided by user while using gradient accumulation
```python
if not isinstance(batch, dict):
    all_formatted_batches = False
    break
labels = batch.get("labels")
```
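To make the assumption concrete, here is a hypothetical example of a batch that would pass this check, following the Hugging Face Transformers convention where `-100` marks ignored positions (the tensor values are illustrative only):

```python
import torch

# Hypothetical batch that passes the dict-with-"labels" check.
# Padding / prompt positions are masked with -100, as in Transformers.
batch = {
    "input_ids": torch.tensor([[5, 7, 9, 0]]),
    "labels": torch.tensor([[7, 9, -100, -100]]),
}

# Two positions carry real labels, so they count as valid tokens.
valid_tokens = int((batch["labels"] != -100).sum().item())  # -> 2
```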
So the core assumption to get this working is that the user has formatted their batches such that each batch has a `labels` tensor?
Yes, this will be listed in the relevant docs (once the approach is confirmed), and it's the same as the `labels` key in transformers.
```python
total_valid_tokens += int((labels != -100).sum().item())
self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
```
I do wonder if this is such a special case that it would be better for the user to provide the information rather than Lightning trying to calculate it. For example, what if the masking token is -1 instead of -100?
For this, I thought that if a user is passing labels for this special case, we would also note in the same docs that the masking token must be -100, so users don't miss it. Thoughts, please?
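The -100 convention lines up with PyTorch's own default: `torch.nn.functional.cross_entropy` skips positions whose label equals `ignore_index` (default -100) and, with `reduction="mean"`, divides by the number of non-ignored targets only. A minimal sketch (random logits, illustrative labels) showing why counting `labels != -100` matches the loss's own normalization:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                 # 4 tokens, 10-class vocab
labels = torch.tensor([1, -100, 3, -100])   # two masked positions

valid = int((labels != -100).sum().item())  # -> 2 valid tokens
loss_sum = F.cross_entropy(logits, labels, reduction="sum", ignore_index=-100)
loss_mean = F.cross_entropy(logits, labels, reduction="mean", ignore_index=-100)

# "mean" divides by the valid-token count only, so these agree:
assert torch.isclose(loss_mean, loss_sum / valid)
```

If the masking token were -1 instead, the same check would need `ignore_index=-1`, which is why pinning the convention down in the docs (or letting the user supply it) matters.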
```python
# create a peekable iterator to look ahead without consuming the original data_fetcher
iterator = data_fetcher.iterator
assert iterator is not None
it1, self._peekable_iter = itertools.tee(iterator)
```
This definitely seems like the smart way to implement this kind of solution.
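The peek-ahead trick relies on a standard `itertools.tee` property: it splits one iterator into two independent copies, so consuming one copy to inspect upcoming items leaves the other untouched. A minimal self-contained sketch (toy data, not the actual data_fetcher):

```python
import itertools

# Split a stream into a "peek" copy and a "main" copy.
data = iter(range(5))
peek, main = itertools.tee(data)

# Look ahead at the first 3 items via the peek copy...
upcoming = [next(peek) for _ in range(3)]   # [0, 1, 2]

# ...while the main copy still yields the full stream.
remaining = list(main)                      # [0, 1, 2, 3, 4]
```

One caveat worth noting: `tee` buffers every item the copies diverge by, so peeking across a whole accumulation window holds those microbatches in memory until the main iterator catches up.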
What does this PR do?
This PR resolves the issue where loss normalization during gradient accumulation is incorrect in case of cross entropy loss.
Key Changes:
Implemented a peekable iterator that inspects all microbatches in the global batch before the first forward pass in accumulation window to determine the total number of valid tokens.
Added support for a `labels` key in each batch, similar to the `labels` field used in Hugging Face Transformers.
Each microbatch's loss is now divided by the total valid-token count of the global batch, ensuring correct scaling during gradient accumulation.
Documentation updates will follow to guide users on required batch structure and loss settings.
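A toy numeric illustration (hypothetical numbers) of the bug being fixed: averaging per-microbatch mean losses differs from the true global mean whenever microbatches have uneven valid-token counts, while dividing each microbatch's summed loss by the global token count recovers the correct value.

```python
# Two microbatches in one accumulation window, with uneven valid-token counts.
sums = [6.0, 1.0]   # summed per-token losses of each microbatch
tokens = [3, 1]     # valid (non -100) tokens per microbatch

# Naive approach: mean per microbatch, then average across microbatches.
per_micro = sum(s / t for s, t in zip(sums, tokens)) / len(sums)  # (2.0 + 1.0) / 2 = 1.5

# Correct global mean over all valid tokens in the window.
global_mean = sum(sums) / sum(tokens)                             # 7.0 / 4 = 1.75

# Dividing each microbatch's summed loss by the GLOBAL token count
# reproduces the correct value: 6.0/4 + 1.0/4 == 1.75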
Fixes #20350
Before submitting
Yes, will do so.
I think it's not needed yet.
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21386.org.readthedocs.build/en/21386/