
fix(cuda): retune PWU kernel when m_batch grows after initial m=1 update#763

Open
Zhaoxian-Wu wants to merge 1 commit into IBM:master from Zhaoxian-Wu:fix/retune-on-mbatch-change

Conversation

@Zhaoxian-Wu

Background

PulsedWeightUpdater::tuneUpdate() benchmarks all valid CUDA kernels on the first update() call and permanently caches the winner in kernel_pars_. Among the candidates is SingleFunctor (kernel class SingleBase), which has no inner batch loop and is only correct when m_batch=1.

Due to GPU cold-start timing jitter, SingleFunctor (~0.025 ms) and batch-aware kernels like BatchSharedBase (~0.026 ms) have nearly identical benchmark times. In approximately 1 in 5 cold-start runs, SingleFunctor wins the race and is permanently selected.
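To make the race concrete, here is a toy Monte-Carlo sketch (not aihwkit code). The mean kernel times come from the numbers above (~0.025 ms vs ~0.026 ms); the Gaussian jitter model and its sigma are invented purely for illustration, so the simulated win rate will not match the observed ~1/5 exactly — the point is only that with means this close, the winner flips from run to run:

```python
import random

random.seed(0)

def benchmark_race(trials=10_000, t_single=0.025, t_batch=0.026, sigma=0.002):
    """Fraction of simulated cold-start races won by the SingleFunctor-like kernel.

    The jitter model (independent Gaussian noise on each timing) is an
    assumption for demonstration, not a model of real GPU cold-start behavior.
    """
    wins = 0
    for _ in range(trials):
        single = random.gauss(t_single, sigma)  # timed SingleFunctor run
        batch = random.gauss(t_batch, sigma)    # timed batch-aware kernel run
        if single < batch:
            wins += 1
    return wins / trials

print(f"SingleFunctor win rate under this toy jitter model: {benchmark_race():.0%}")
```

With zero jitter the faster mean always wins; with jitter comparable to the 0.001 ms gap, either kernel can win, and whichever wins on the one cold-start benchmark is cached permanently.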

This becomes a silent correctness bug when:

  1. A tile's first update() call uses m_batch=1 (e.g. a priming update, a warm-up step, or a gradient accumulation flush), causing tuneUpdate() to potentially select SingleFunctor.
  2. A subsequent update() call uses m_batch=M >> 1. SingleFunctor is reused without re-tuning and silently processes only batch item 0, producing a weight change of ~1/M instead of the correct value — roughly 99% relative error.

The bug affects all pulsed leaf devices (ConstantStepDevice, LinearStepDevice, SoftBoundsDevice, ExpStepDevice, PowStepDevice, PiecewiseStepDevice, etc.), as all of them include SingleFunctor in their valid kernel list.
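The ~1/M error magnitude follows directly from the kernel behaviors described above. The following pure-Python toy model (not aihwkit code; the update is simplified to an exact outer-product accumulation with all-ones inputs) shows why processing only batch item 0 yields a relative error of (M-1)/M:

```python
M = 128
x = [[1.0] * 4 for _ in range(M)]   # M batch items of input activations
d = [[1.0] * 3 for _ in range(M)]   # M batch items of error signals

def batch_kernel_update(x, d):
    """Correct behavior: accumulate the outer-product update over all items."""
    out = [[0.0] * len(x[0]) for _ in range(len(d[0]))]
    for xi, di in zip(x, d):
        for r, dr in enumerate(di):
            for c, xc in enumerate(xi):
                out[r][c] += dr * xc
    return out

def single_kernel_update(x, d):
    """Buggy behavior: a SingleFunctor-like kernel sees only batch item 0."""
    return batch_kernel_update(x[:1], d[:1])

w_ref = batch_kernel_update(x, d)   # every entry is 128.0
w_bug = single_kernel_update(x, d)  # every entry is 1.0

rel_err = abs(w_ref[0][0] - w_bug[0][0]) / w_ref[0][0]
print(f"relative error: {rel_err:.1%}")  # prints: relative error: 99.2%
```

For M=128 this gives 127/128 ≈ 99.2%, matching the error reported in the MWE below only because both use the same batch size; any large M produces a comparably large error.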

Fix

Add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used during the last tuneUpdate() call. When a subsequent update() arrives with a larger m_batch, invalidate kernel_pars_ and force-retune with the new batch size. SingleFunctor is then correctly excluded (its SingleBase validity check rejects m_batch > 1), and a batch-aware kernel is selected instead.

// pulsed_weight_updater.cu — new retune guard
if (!force_tuning && m_batch > tuned_m_batch_) {
  force_tuning = true;
  valid_kernels_ = getValidUpdateKernels(rpucuda_device, m_batch, up);
  kernel_pars_ = valid_kernels_[0];
}
// after tuneUpdate():
tuned_m_batch_ = m_batch;
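The control flow of the guard can be sketched as a toy Python model (hypothetical class and kernel names; the real benchmark race and kernel selection are simplified to picking the last valid kernel):

```python
class ToyUpdater:
    """Toy model of the cache-and-retune logic; not the aihwkit implementation."""

    def __init__(self):
        self.kernel = None        # stands in for cached kernel_pars_
        self.tuned_m_batch = 0    # stands in for the new tuned_m_batch_ field

    def valid_kernels(self, m_batch):
        # Mirrors the SingleBase validity check: the single-item kernel is
        # only a candidate when m_batch == 1.
        kernels = ["BatchShared", "BatchSum"]
        if m_batch == 1:
            kernels.append("Single")
        return kernels

    def tune(self, m_batch):
        # Stand-in for tuneUpdate(): here the last valid kernel always "wins".
        self.kernel = self.valid_kernels(m_batch)[-1]
        self.tuned_m_batch = m_batch

    def update(self, m_batch):
        # The retune guard: invalidate the cache when m_batch grows.
        if self.kernel is None or m_batch > self.tuned_m_batch:
            self.tune(m_batch)
        return self.kernel

u = ToyUpdater()
assert u.update(1) == "Single"     # m=1: the single-item kernel can win
assert u.update(128) != "Single"   # guard retunes; Single is excluded
```

Note the guard only fires when m_batch grows; a later call with a smaller m_batch keeps the batch-aware kernel, which remains correct for any smaller batch.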

Minimal Working Example

The following self-contained script reproduces the bug (pre-fix) and verifies the fix.

import torch
from aihwkit.simulator.configs.configs import SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice

IN, OUT, M = 32, 16, 128
device = torch.device("cuda")

rpu_config = SingleRPUConfig(
    device=ConstantStepDevice(
        dw_min=2/12000, w_max=1.0, w_min=-1.0,
        w_max_dtod=0., w_min_dtod=0., up_down_dtod=0.,
        dw_min_dtod=0., dw_min_std=0.,
    )
)
rpu_config.update.desired_bl       = 255
rpu_config.mapping.max_input_size  = 2**30
rpu_config.mapping.max_output_size = 2**30
rpu_config.forward.is_perfect      = True
rpu_config.backward.is_perfect     = True
tile_cls = rpu_config.get_default_tile_module_class(OUT, IN)

def make_tile():
    t = tile_cls(OUT, IN, rpu_config, False).to(device)
    t.set_learning_rate(1.0)
    return t

zeros   = torch.zeros(OUT, IN, device=device)
x_main  = torch.ones(M, IN,  device=device)
d_main  = torch.ones(M, OUT, device=device)
x_prime = torch.ones(1, IN,  device=device)
d_prime = torch.ones(1, OUT, device=device)

# Reference: no priming
t_ref = make_tile();  t_ref.set_weights(zeros)
t_ref.update(x_main, d_main)
w_ref = t_ref.get_weights()[0]

# Test: prime with m=1, then update with m=128
t_test = make_tile();  t_test.set_weights(zeros)
t_test.update(x_prime, d_prime)   # tuneUpdate fires here (m=1)
t_test.set_weights(zeros)
t_test.update(x_main, d_main)     # bug: SingleFunctor reused; fix: retuning here
w_test = t_test.get_weights()[0]

err = (w_test - w_ref).norm() / w_ref.norm()
print(f"Relative error: {err:.1%}")
print("PASS" if err < 0.5 else "FAIL — SingleFunctor reused without retuning")

Pre-fix output (when SingleFunctor wins the cold-start benchmark, ~1/5 runs):

Relative error: 99.2%
FAIL — SingleFunctor reused without retuning

Post-fix output (guaranteed, all runs):

Relative error: 0.0%
PASS

Changes

| File | Description |
| --- | --- |
| src/rpucuda/cuda/pulsed_weight_updater.h | Add `tuned_m_batch_` field with explanatory comment |
| src/rpucuda/cuda/pulsed_weight_updater.cu | Reset `tuned_m_batch_` on device-type change; add retune guard when `m_batch` grows; record `tuned_m_batch_` after `tuneUpdate()` |
| tests/test_bindings_tiles.py | Add `AnalogTileTest::test_update_mbatch_change` regression test for CUDA `ConstantStepDevice` |

When a tile's first update uses m_batch=1, tuneUpdate() benchmarks all
kernels valid for that batch size, which includes SingleFunctor — a CUDA
kernel with no inner batch loop that processes only batch item 0. Due to
GPU timing jitter, SingleFunctor can win the benchmark race (~1/5 cold-
start runs). If a subsequent update uses m_batch=M>>1, the cached
kernel_pars_ is silently reused, producing a weight change of ~1/M
instead of the correct value (~99% relative error).

Fix: add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used
during the last tuneUpdate() call. When m_batch grows beyond this value,
invalidate kernel_pars_ and force-retune with the new batch size.
SingleFunctor is marked invalid for m_batch>1 (via SingleBase), so a
correct batch-aware kernel (BatchShared*, BatchSum, ...) is selected.

Add regression test in AnalogTileTest that primes a CUDA tile with
m_batch=1 then updates with m_batch=128, comparing the result against a
reference tile with no priming. The test fails with the old code (~99%
error) and passes with the fix (~0% error).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Zhaoxian-Wu Zhaoxian-Wu force-pushed the fix/retune-on-mbatch-change branch from 6f677b3 to aeb25d3 Compare March 22, 2026 03:18
@PabloCarmona PabloCarmona requested review from maljoras April 1, 2026 15:14
@PabloCarmona
Collaborator

Hello @maljoras @maljoras-sony, can you help us and take a look if everything is ok with this?

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the fix/retune-on-mbatch-change branch from aeb25d3 to 6f677b3 Compare April 16, 2026 05:09
@PabloCarmona
Collaborator

Hello @Zhaoxian-Wu! Please update this branch with the latest commits on master so we can check everything runs ok on the CICD side since I fixed the problem with the linting. Thanks!
