fix(cuda): retune PWU kernel when m_batch grows after initial m=1 update #763
Open
Zhaoxian-Wu wants to merge 1 commit into IBM:master from
Conversation
When a tile's first update uses m_batch=1, tuneUpdate() benchmarks all kernels valid for that batch size, which includes SingleFunctor, a CUDA kernel with no inner batch loop that processes only batch item 0. Due to GPU timing jitter, SingleFunctor can win the benchmark race (~1/5 cold-start runs). If a subsequent update uses m_batch=M>>1, the cached kernel_pars_ is silently reused, producing a weight change of ~1/M instead of the correct value (~99% relative error).

Fix: add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used during the last tuneUpdate() call. When m_batch grows beyond this value, invalidate kernel_pars_ and force-retune with the new batch size. SingleFunctor is marked invalid for m_batch>1 (via SingleBase), so a correct batch-aware kernel (BatchShared*, BatchSum, ...) is selected.

Add a regression test in AnalogTileTest that primes a CUDA tile with m_batch=1, then updates with m_batch=128, comparing the result against a reference tile with no priming. The test fails with the old code (~99% error) and passes with the fix (~0% error).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
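The "~1/M" figure in the description can be checked with a toy NumPy calculation (illustrative only; this uses no aihwkit or CUDA code, and the kernel names in the comments just echo the PR):

```python
import numpy as np

# Toy check of the "~1/M" claim: if every batch item contributes the same
# rank-1 outer product, a kernel that processes only batch item 0 yields
# exactly 1/M of the correct weight change.
M = 128
x = np.ones((M, 4))   # inputs, one row per batch item
d = np.ones((M, 3))   # error signals

correct = d.T @ x               # batch-aware kernel: sums all M items
single = np.outer(d[0], x[0])   # SingleFunctor-like kernel: item 0 only

print(single[0, 0] / correct[0, 0])      # 0.0078125, i.e. exactly 1/128
print(1 - single[0, 0] / correct[0, 0])  # 0.9921875, i.e. ~99% relative error
```

With random (rather than identical) batch items the ratio is not exactly 1/M, but the relative error stays near 99% for M=128, matching the magnitude reported in the regression test.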
Force-pushed 6f677b3 to aeb25d3 (Compare)
Collaborator
Hello @maljoras @maljoras-sony, can you help us by taking a look to check that everything is OK with this?
Force-pushed aeb25d3 to 6f677b3 (Compare)
Collaborator
Hello @Zhaoxian-Wu! Please update this branch with the latest commits on master so we can check that everything runs OK on the CI/CD side, since I fixed the problem with the linting. Thanks!
Background
PulsedWeightUpdater::tuneUpdate() benchmarks all valid CUDA kernels on the first update() call and permanently caches the winner in kernel_pars_. Among the candidates is SingleFunctor (kernel class SingleBase), which has no inner batch loop and is only correct when m_batch=1.

Due to GPU cold-start timing jitter, SingleFunctor (~0.025 ms) and batch-aware kernels like BatchSharedBase (~0.026 ms) have nearly identical benchmark times. In approximately 1 in 5 cold-start runs, SingleFunctor wins the race and is permanently selected. This becomes a silent correctness bug when:

- An update() call uses m_batch=1 (e.g. a priming update, a warm-up step, or a gradient accumulation flush), causing tuneUpdate() to potentially select SingleFunctor.
- A subsequent update() call uses m_batch=M >> 1. SingleFunctor is reused without re-tuning and silently processes only batch item 0, producing a weight change of ~1/M instead of the correct value (roughly 99% relative error).

The bug affects all pulsed leaf devices (ConstantStepDevice, LinearStepDevice, SoftBoundsDevice, ExpStepDevice, PowStepDevice, PiecewiseStepDevice, etc.), as all of them include SingleFunctor in their valid kernel list.

Fix
Add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used during the last tuneUpdate() call. When a subsequent update() arrives with a larger m_batch, invalidate kernel_pars_ and force a retune with the new batch size. SingleFunctor is then correctly excluded (its SingleBase validity check rejects m_batch > 1), and a batch-aware kernel is selected instead.

Minimal Working Example
The following self-contained script reproduces the bug (pre-fix) and verifies the fix.
Pre-fix output (when SingleFunctor wins the cold-start benchmark, ~1/5 runs):

Post-fix output (guaranteed, all runs):
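The retune guard at the core of the fix can be illustrated with a small Python stand-in for the C++ logic (the class and its methods are hypothetical; only the member names kernel_pars_ and tuned_m_batch_ mirror the PR):

```python
# Toy stand-in for the C++ PulsedWeightUpdater guard; kernel selection is
# faked with strings instead of a real CUDA benchmark.
class PulsedWeightUpdaterSketch:
    def __init__(self):
        self.kernel_pars_ = None   # cached benchmark winner
        self.tuned_m_batch_ = 0    # m_batch used at the last tuneUpdate()

    def tune_update(self, m_batch):
        # benchmark stand-in: the single-item kernel is only valid for m_batch=1
        self.kernel_pars_ = "single" if m_batch == 1 else "batch_shared"
        self.tuned_m_batch_ = m_batch

    def update(self, m_batch):
        # retune guard: invalidate the cached kernel when m_batch grows
        if self.kernel_pars_ is None or m_batch > self.tuned_m_batch_:
            self.tune_update(m_batch)
        return self.kernel_pars_

u = PulsedWeightUpdaterSketch()
print(u.update(1))    # single
print(u.update(128))  # batch_shared (retuned; the old code would reuse "single")
print(u.update(64))   # batch_shared (no retune needed: 64 <= tuned_m_batch_)
```

Note that the guard only fires when m_batch grows; a smaller later batch keeps the cached batch-aware kernel, which stays correct.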
Changes

- src/rpucuda/cuda/pulsed_weight_updater.h: add tuned_m_batch_ field with explanatory comment
- src/rpucuda/cuda/pulsed_weight_updater.cu: reset tuned_m_batch_ on device-type change; add retune guard when m_batch grows; record tuned_m_batch_ after tuneUpdate()
- tests/test_bindings_tiles.py: add AnalogTileTest::test_update_mbatch_change regression test for CUDA ConstantStepDevice
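The structure of the new regression test can be sketched with a NumPy-only toy (hypothetical code, not the actual AnalogTileTest; function names and shapes are invented for illustration): prime one "tile" with an m_batch=1 update, run a large batch, and compare against an unprimed reference.

```python
import numpy as np

# Hypothetical mirror of the regression test's structure; the real test in
# tests/test_bindings_tiles.py uses CUDA tiles with ConstantStepDevice.
def run(prime, retune_guard, x, d):
    w, kernel, tuned_m = np.zeros((3, 4)), None, 0
    # optional zero-valued m_batch=1 priming update: it only affects tuning
    batches = ([(np.zeros((1, 4)), np.zeros((1, 3)))] if prime else []) + [(x, d)]
    for xb, db in batches:
        m = xb.shape[0]
        if kernel is None or (retune_guard and m > tuned_m):
            kernel, tuned_m = ("single" if m == 1 else "batch"), m  # re-tune
        # "single" processes batch item 0 only; "batch" sums the whole batch
        w += np.outer(db[0], xb[0]) if kernel == "single" else db.T @ xb
    return w

rng = np.random.default_rng(7)
x, d = rng.standard_normal((128, 4)), rng.standard_normal((128, 3))
ref = run(prime=False, retune_guard=False, x=x, d=d)   # reference, no priming
buggy = run(prime=True, retune_guard=False, x=x, d=d)  # old code path
fixed = run(prime=True, retune_guard=True, x=x, d=d)   # with the guard

print(np.linalg.norm(buggy - ref) / np.linalg.norm(ref))  # large (order 1)
print(np.allclose(fixed, ref))  # True
```

Without the guard, the primed run keeps the single-item kernel and diverges from the reference by roughly the full update magnitude; with the guard, the m_batch=128 update triggers a retune and the two runs agree.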