14 changes: 14 additions & 0 deletions src/rpucuda/cuda/pulsed_weight_updater.cu
@@ -369,6 +369,7 @@ void PulsedWeightUpdater<T>::update(
    update_type_ = update_type;

    update_count_ = 0;
    tuned_m_batch_ = 0; // reset so the m_batch check below will also retune

    // init kernels
    valid_kernels_ = getValidUpdateKernels(rpucuda_device, m_batch, up);
@@ -402,6 +403,18 @@
    }
  }

  // Retune if m_batch has grown beyond what was used during the last tuning.
  // This prevents a kernel valid only for small m_batch (e.g. SingleFunctor,
  // which has no batch loop) from being reused incorrectly for larger batches.
  if (!force_tuning && m_batch > tuned_m_batch_) {
    force_tuning = true;
    valid_kernels_ = getValidUpdateKernels(rpucuda_device, m_batch, up);
    if (valid_kernels_.size() == 0) {
      RPU_FATAL("Cannot find valid update kernels");
    }
    kernel_pars_ = valid_kernels_[0];
  }

  if (update_count_ < FORCE_TUNING_THRES) { // only once again
    update_count_ += 1;
    force_tuning = force_tuning || (update_count_ == FORCE_TUNING_THRES);
@@ -412,6 +425,7 @@
    this->tuneUpdate(
        kernel_pars_, valid_kernels_, x_in, d_in, dev_weights, rpucuda_device, up, lr, m_batch,
        x_trans, d_trans);
    tuned_m_batch_ = m_batch;
  }

  // do update
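The guard added above follows a simple pattern: remember the batch size last tuned for, and invalidate the tuned kernel whenever a larger batch arrives. A minimal, self-contained C++ sketch of that pattern follows; the names (TunedUpdater, retune, run) are illustrative only and not the aihwkit API, which instead calls getValidUpdateKernels() and tuneUpdate() as shown in the diff.

    #include <iostream>

    // Minimal sketch of the retune-on-growth guard implemented in this PR.
    class TunedUpdater {
    public:
      void update(int m_batch) {
        // Retune whenever the batch grows past the size used for the last
        // tuning: a kernel picked for a small batch (e.g. one without an
        // inner batch loop) may be invalid for a larger one.
        if (m_batch > tuned_m_batch_) {
          retune(m_batch);
          tuned_m_batch_ = m_batch; // remember what we tuned for
        }
        run(m_batch);
      }

    private:
      void retune(int m_batch) { std::cout << "retune for m_batch=" << m_batch << "\n"; }
      void run(int m_batch) { std::cout << "update with m_batch=" << m_batch << "\n"; }
      int tuned_m_batch_ = 0; // 0 == never tuned, so the first update always tunes
    };

    int main() {
      TunedUpdater u;
      u.update(1);   // first update: tunes for m_batch=1
      u.update(128); // batch grew: retunes before updating
      u.update(64);  // smaller batch: kernel tuned for 128 stays valid, no retune
    }

Note that shrinking batches deliberately do not trigger a retune: per the comments in the diff, only kernels valid for small batches become invalid when the batch grows, not the reverse.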
6 changes: 6 additions & 0 deletions src/rpucuda/cuda/pulsed_weight_updater.h
@@ -107,6 +107,12 @@ template <typename T> class PulsedWeightUpdater {
  int x_size_ = 0;
  int d_size_ = 0;
  int update_count_ = 0;
  // Tracks the m_batch value used during the last tuneUpdate() call.
  // When m_batch grows beyond this value, kernel_pars_ is invalidated and
  // tuneUpdate() is re-run with the new batch size. This prevents kernels
  // that are only valid for small batches (e.g. SingleFunctor, which has no
  // inner batch loop) from being silently reused for larger batches.
  int tuned_m_batch_ = 0;
  bool is_async_update_ = false;
  int verbose_ = 0;
  DeviceUpdateType update_type_ = DeviceUpdateType::Undefined;
68 changes: 68 additions & 0 deletions tests/test_bindings_tiles.py
@@ -303,3 +303,71 @@ def test_setters_weights(self):
        input_weights = Tensor([[6, 5, 4], [3, 2, 1]])
        cpp_tile.set_weights(input_weights)
        self.assertEqual(cpp_tile.get_weights().shape, (2, 3))

    def test_update_mbatch_change(self):
        """Regression test: weight update must be correct when m_batch grows after m=1.

        Before the fix in pulsed_weight_updater.cu, tuneUpdate() could select
        SingleFunctor (no inner batch loop) when the first update uses m_batch=1.
        If m_batch then grew to M>1, the same kernel was silently reused,
        processing only batch item 0 and producing a weight change of ~1/M,
        i.e. ~99% error for M=128.

        Fix: track tuned_m_batch_; force-retune when m_batch grows.
        """
        if not self.use_cuda or SKIP_CUDA_TESTS:
            raise SkipTest("tuneUpdate() only runs on CUDA tiles")

        in_size, out_size, m_batch = 32, 16, 128

        # Zero-noise config with large BL so the update is fully deterministic
        # for all-ones signals: both tiles produce identical weight changes.
        rpu_config = SingleRPUConfig(
            device=ConstantStepDevice(
                dw_min=2 / 12000,
                w_max=1.0,
                w_min=-1.0,
                w_max_dtod=0.0,
                w_min_dtod=0.0,
                up_down_dtod=0.0,
                dw_min_dtod=0.0,
                dw_min_std=0.0,
            )
        )
        rpu_config.update.desired_bl = 255
        rpu_config.mapping.max_input_size = 2**30
        rpu_config.mapping.max_output_size = 2**30
        rpu_config.forward.is_perfect = True
        rpu_config.backward.is_perfect = True

        zeros = Tensor(out_size, in_size).fill_(0.0).cuda()
        x_main = Tensor(m_batch, in_size).fill_(1.0).cuda()
        d_main = Tensor(m_batch, out_size).fill_(1.0).cuda()
        x_prime = Tensor(1, in_size).fill_(1.0).cuda()
        d_prime = Tensor(1, out_size).fill_(1.0).cuda()

        # Reference tile: no priming; tuneUpdate fires on the large-batch update.
        t_ref = self.get_tile(out_size, in_size, rpu_config)
        t_ref.set_learning_rate(1.0)
        t_ref.set_weights(zeros)
        t_ref.update(x_main, d_main)
        w_ref = t_ref.get_weights()[0]

        # Primed tile: tuneUpdate fires on m=1 first, then the batch grows from
        # 1 to m_batch. The fix must detect the growth and force-retune with
        # the larger batch size.
        t_test = self.get_tile(out_size, in_size, rpu_config)
        t_test.set_learning_rate(1.0)
        t_test.set_weights(zeros)
        t_test.update(x_prime, d_prime)  # tuneUpdate(m=1); may pick SingleFunctor
        t_test.set_weights(zeros)
        t_test.update(x_main, d_main)  # fix detects growth 1 → m_batch and retunes
        w_test = t_test.get_weights()[0]

        rel_err = (w_test - w_ref).norm() / w_ref.norm()
        self.assertLess(
            rel_err.item(),
            0.5,
            f"Weight change after m=1 priming should match reference (got {rel_err:.1%} "
            f"error). SingleFunctor may have been reused without retuning for "
            f"m_batch={m_batch}; check tuned_m_batch_ tracking in "
            f"pulsed_weight_updater.cu.",
        )
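A note on the 0.5 threshold: with the bug present, only batch item 0 of the 128 all-ones items is applied, so the primed tile's weight change is ~1/128 of the reference and the relative error is ~1 - 1/128 ≈ 0.992; with the fix, both tiles see identical deterministic updates and the error is ~0. Any threshold between those two regimes distinguishes them, and 0.5 leaves a wide margin on both sides.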