From 6f677b314f01b4ba17adfd8051db2846f55bee0a Mon Sep 17 00:00:00 2001
From: Zhaoxian Wu
Date: Sat, 21 Mar 2026 22:42:00 -0400
Subject: [PATCH] fix(cuda): retune PWU kernel when m_batch grows after initial m=1 update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When a tile's first update uses m_batch=1, tuneUpdate() benchmarks all
kernels valid for that batch size, which includes SingleFunctor — a CUDA
kernel with no inner batch loop that processes only batch item 0. Due to
GPU timing jitter, SingleFunctor can win the benchmark race (~1/5
cold-start runs). If a subsequent update uses m_batch=M>>1, the cached
kernel_pars_ is silently reused, producing a weight change of ~1/M
instead of the correct value (~99% relative error).

Fix: add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used
during the last tuneUpdate() call. When m_batch grows beyond this value,
invalidate kernel_pars_ and force-retune with the new batch size.
SingleFunctor is marked invalid for m_batch>1 (via SingleBase), so a
correct batch-aware kernel (BatchShared*, BatchSum, ...) is selected.

Add regression test in AnalogTileTest that primes a CUDA tile with
m_batch=1 then updates with m_batch=128, comparing the result against a
reference tile with no priming. The test fails with the old code (~99%
error) and passes with the fix (~0% error).

Co-Authored-By: Claude Sonnet 4.6
---
 src/rpucuda/cuda/pulsed_weight_updater.cu | 14 +++++
 src/rpucuda/cuda/pulsed_weight_updater.h  |  6 ++
 tests/test_bindings_tiles.py              | 68 +++++++++++++++++++++++
 3 files changed, 88 insertions(+)

diff --git a/src/rpucuda/cuda/pulsed_weight_updater.cu b/src/rpucuda/cuda/pulsed_weight_updater.cu
index d4a54012..e1eb2f65 100644
--- a/src/rpucuda/cuda/pulsed_weight_updater.cu
+++ b/src/rpucuda/cuda/pulsed_weight_updater.cu
@@ -369,6 +369,7 @@ void PulsedWeightUpdater<T>::update(
 
   update_type_ = update_type;
   update_count_ = 0;
+  tuned_m_batch_ = 0; // reset so the m_batch check below will also retune
 
   // init kernels
   valid_kernels_ = getValidUpdateKernels(rpucuda_device, m_batch, up);
@@ -402,6 +403,18 @@
     }
   }
 
+  // Retune if m_batch has grown beyond what was used during the last tuning.
+  // This prevents a kernel valid only for small m_batch (e.g. SingleFunctor,
+  // which has no batch loop) from being reused incorrectly for larger batches.
+  if (!force_tuning && m_batch > tuned_m_batch_) {
+    force_tuning = true;
+    valid_kernels_ = getValidUpdateKernels(rpucuda_device, m_batch, up);
+    if (valid_kernels_.size() == 0) {
+      RPU_FATAL("Cannot find valid update kernels");
+    }
+    kernel_pars_ = valid_kernels_[0];
+  }
+
   if (update_count_ < FORCE_TUNING_THRES) { // only once again
     update_count_ += 1;
     force_tuning = force_tuning || (update_count_ == FORCE_TUNING_THRES);
@@ -412,6 +425,7 @@
     this->tuneUpdate(
         kernel_pars_, valid_kernels_, x_in, d_in, dev_weights, rpucuda_device, up, lr, m_batch,
         x_trans, d_trans);
+    tuned_m_batch_ = m_batch;
   }
 
   // do update
diff --git a/src/rpucuda/cuda/pulsed_weight_updater.h b/src/rpucuda/cuda/pulsed_weight_updater.h
index cfd229ae..ce649850 100644
--- a/src/rpucuda/cuda/pulsed_weight_updater.h
+++ b/src/rpucuda/cuda/pulsed_weight_updater.h
@@ -107,6 +107,12 @@ template <typename T> class PulsedWeightUpdater {
   int x_size_ = 0;
   int d_size_ = 0;
   int update_count_ = 0;
+  // Tracks the m_batch value used during the last tuneUpdate() call.
+  // When m_batch grows beyond this value, kernel_pars_ is invalidated and
+  // tuneUpdate() is re-run with the new batch size. This prevents kernels
+  // that are only valid for small batches (e.g. SingleFunctor, which has no
+  // inner batch loop) from being silently reused for larger batches.
+  int tuned_m_batch_ = 0;
   bool is_async_update_ = false;
   int verbose_ = 0;
   DeviceUpdateType update_type_ = DeviceUpdateType::Undefined;
diff --git a/tests/test_bindings_tiles.py b/tests/test_bindings_tiles.py
index 770a840b..f6adfc36 100644
--- a/tests/test_bindings_tiles.py
+++ b/tests/test_bindings_tiles.py
@@ -303,3 +303,71 @@ def test_setters_weights(self):
         input_weights = Tensor([[6, 5, 4], [3, 2, 1]])
         cpp_tile.set_weights(input_weights)
         self.assertEqual(cpp_tile.get_weights().shape, (2, 3))
+
+    def test_update_mbatch_change(self):
+        """Regression test: weight update must be correct when m_batch grows after m=1.
+
+        Before the fix in pulsed_weight_updater.cu, tuneUpdate() could select
+        SingleFunctor (no inner batch loop) when the first update used m_batch=1.
+        If m_batch then grew to M>1, the same kernel was silently reused, processing
+        only batch item 0 and producing a weight change of ~1/M (~99% error for M=128).
+
+        Fix: track tuned_m_batch_ and force-retune whenever m_batch grows.
+        """
+        if not self.use_cuda or SKIP_CUDA_TESTS:
+            raise SkipTest("tuneUpdate() only runs on CUDA tiles")
+
+        in_size, out_size, m_batch = 32, 16, 128
+
+        # Zero-noise config with a large BL so that the update is fully deterministic
+        # for all-ones signals: both tiles must produce identical weight changes.
+        rpu_config = SingleRPUConfig(
+            device=ConstantStepDevice(
+                dw_min=2 / 12000,
+                w_max=1.0,
+                w_min=-1.0,
+                w_max_dtod=0.0,
+                w_min_dtod=0.0,
+                up_down_dtod=0.0,
+                dw_min_dtod=0.0,
+                dw_min_std=0.0,
+            )
+        )
+        rpu_config.update.desired_bl = 255
+        rpu_config.mapping.max_input_size = 2**30
+        rpu_config.mapping.max_output_size = 2**30
+        rpu_config.forward.is_perfect = True
+        rpu_config.backward.is_perfect = True
+
+        zeros = Tensor(out_size, in_size).fill_(0.0).cuda()
+        x_main = Tensor(m_batch, in_size).fill_(1.0).cuda()
+        d_main = Tensor(m_batch, out_size).fill_(1.0).cuda()
+        x_prime = Tensor(1, in_size).fill_(1.0).cuda()
+        d_prime = Tensor(1, out_size).fill_(1.0).cuda()
+
+        # Reference tile: no priming; tuneUpdate fires on the large-batch update.
+        t_ref = self.get_tile(out_size, in_size, rpu_config)
+        t_ref.set_learning_rate(1.0)
+        t_ref.set_weights(zeros)
+        t_ref.update(x_main, d_main)
+        w_ref = t_ref.get_weights()[0]
+
+        # Primed tile: tuneUpdate fires on m=1 first; m_batch then grows to 128.
+        # The fix must detect the growth and force-retune with the larger batch size.
+        t_test = self.get_tile(out_size, in_size, rpu_config)
+        t_test.set_learning_rate(1.0)
+        t_test.set_weights(zeros)
+        t_test.update(x_prime, d_prime)  # tuneUpdate(m=1) — may pick SingleFunctor
+        t_test.set_weights(zeros)
+        t_test.update(x_main, d_main)  # fix: detects m_batch growing 1→128 and retunes
+        w_test = t_test.get_weights()[0]
+
+        rel_err = (w_test - w_ref).norm() / w_ref.norm()
+        self.assertLess(
+            rel_err.item(),
+            0.5,
+            f"Weight change after m=1 priming should match reference (got {rel_err:.1%} "
+            f"error). SingleFunctor may have been reused without retuning for "
+            f"m_batch={m_batch} — check tuned_m_batch_ tracking in "
+            f"pulsed_weight_updater.cu.",
+        )
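
For reviewers: the committed test drives the low-level tile bindings so it
stays deterministic, but the same failure mode can be provoked through the
high-level Python layer. A minimal sketch, assuming the standard aihwkit
AnalogLinear/AnalogSGD wrappers on a CUDA build (the layer sizes and
learning rate below are illustrative and not part of this patch):

    import torch
    from aihwkit.nn import AnalogLinear
    from aihwkit.optim import AnalogSGD
    from aihwkit.simulator.configs import SingleRPUConfig
    from aihwkit.simulator.configs.devices import ConstantStepDevice

    layer = AnalogLinear(
        32, 16, bias=False,
        rpu_config=SingleRPUConfig(device=ConstantStepDevice()),
    ).cuda()
    opt = AnalogSGD(layer.parameters(), lr=0.1)
    opt.regroup_param_groups(layer)

    # Step 1: batch size 1 -- tuneUpdate(m_batch=1) runs on the first analog
    # update and may cache SingleFunctor as the winning kernel.
    layer(torch.ones(1, 32, device="cuda")).sum().backward()
    opt.step()
    opt.zero_grad()

    # Step 2: batch size 128 -- without this patch the cached m=1 kernel is
    # reused, so only batch item 0 contributes (~1/128 of the expected
    # update); with the patch, the updater retunes and picks a batch-aware
    # kernel (BatchShared*, BatchSum, ...).
    layer(torch.ones(128, 32, device="cuda")).sum().backward()
    opt.step()

Any warm-up step with batch size 1 followed by larger batches hits the same
path, which is why the fix keys off growth of m_batch (m_batch >
tuned_m_batch_) rather than off any particular kernel choice.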