diff --git a/src/aihwkit/simulator/tiles/periphery.py b/src/aihwkit/simulator/tiles/periphery.py
index 420ff013..4b17628c 100644
--- a/src/aihwkit/simulator/tiles/periphery.py
+++ b/src/aihwkit/simulator/tiles/periphery.py
@@ -979,6 +979,10 @@ def add_quant_periphery_bias(
         Tensor
             The output of the tile added with the bias
         """
 
+        # Ensure an iterable tuple for .view(*tensor_view)
+        if tensor_view is None:
+            tensor_view = self.get_tensor_view(output.dim())
+
         if self.bias_quantizer is None:
             return output + self.bias.view(*tensor_view)
diff --git a/src/rpucuda/rpu_pulsed_device.h b/src/rpucuda/rpu_pulsed_device.h
index 47a1cc6a..dfcc2707 100644
--- a/src/rpucuda/rpu_pulsed_device.h
+++ b/src/rpucuda/rpu_pulsed_device.h
@@ -110,7 +110,6 @@ template <typename T> struct PulsedRPUDeviceMetaParameter : PulsedRPUDeviceMetaP
 
     }
     reset_dtod = MAX(reset_dtod, (T)0.0);
     this->reset_std = MAX(this->reset_std, (T)0.0);
-    reset = MAX(reset, (T)0.0);
   };
 };
diff --git a/tests/test_specific_tiles.py b/tests/test_specific_tiles.py
index 34eac448..bf18d7f1 100644
--- a/tests/test_specific_tiles.py
+++ b/tests/test_specific_tiles.py
@@ -150,3 +150,67 @@ def test_decay(self):
         self.assertAlmostEqual(bias[0].item(), gamma * (a - b) + c - d, 5)
 
         self.assertAlmostEqual(weight[0][0].item(), gamma * (a - b) + c - d, 5)
+
+    def test_decay_with_negative_reset_bias(self):
+        """Test that decay keeps a negative reset bias."""
+        # pylint: disable=invalid-name, too-many-locals
+
+        lifetime = 100.0
+        gamma = 0.1
+        reset_bias = -0.3
+        rpu_config = self.get_transfer_compound(
+            gamma=gamma, lifetime=lifetime, lifetime_dtod=0.0, reset=reset_bias, reset_std=0.0
+        )
+        model = self.get_layer(in_features=2, out_features=1, rpu_config=rpu_config)
+
+        weight, bias = model.get_weights()
+        model.set_weights(weight * 0.0, bias * 0.0 if bias is not None else None)
+
+        analog_tile = next(model.analog_tiles())
+        params = analog_tile.get_hidden_parameters()
+        shape = params["hidden_weights_0_0"].shape
+
+        a, b, c, d = 0.47, 0.21, 0.64, 0.12
+        params["hidden_weights_0_0"] = a * ones(*shape)
+        params["hidden_weights_1_0"] = b * ones(*shape)
+        params["hidden_weights_0_1"] = c * ones(*shape)
+        params["hidden_weights_1_1"] = d * ones(*shape)
+
+        a_dcy, b_dcy, c_dcy, d_dcy = 0.95, 0.28, 0.33, 0.12
+        params["decay_scales_0_0"] = a_dcy * ones(*shape)
+        params["decay_scales_1_0"] = b_dcy * ones(*shape)
+        params["decay_scales_0_1"] = c_dcy * ones(*shape)
+        params["decay_scales_1_1"] = d_dcy * ones(*shape)
+
+        analog_tile.set_hidden_parameters(params)
+        x_b = Tensor([[0.1, 0.2], [0.2, 0.4]])
+        y_b = Tensor([[0.3], [0.6]])
+
+        if self.use_cuda:
+            x_b = x_b.cuda()
+            y_b = y_b.cuda()
+
+        opt = AnalogSGD(model.parameters(), lr=0.0)
+
+        epochs = 2
+        for _ in range(epochs):
+            opt.zero_grad()
+            pred = model(x_b)
+            loss = mse_loss(pred, y_b)
+
+            loss.backward()
+            opt.step()
+
+        weight, bias = model.get_weights()
+
+        a = (a - reset_bias) * pow(a_dcy, epochs) + reset_bias
+        b = (b - reset_bias) * pow(b_dcy, epochs) + reset_bias
+        c = (c - reset_bias) * pow(c_dcy, epochs) + reset_bias
+        d = (d - reset_bias) * pow(d_dcy, epochs) + reset_bias
+
+        if self.digital_bias:
+            self.assertAlmostEqual(bias[0].item(), 0.0)
+        if self.bias and not self.digital_bias:
+            self.assertAlmostEqual(bias[0].item(), gamma * (a - b) + c - d, 5)
+
+        self.assertAlmostEqual(weight[0][0].item(), gamma * (a - b) + c - d, 5)