diff --git a/docs/source/analog_update.rst b/docs/source/analog_update.rst
index 939c8a8d..75b2ff07 100644
--- a/docs/source/analog_update.rst
+++ b/docs/source/analog_update.rst
@@ -37,21 +37,25 @@ direction. Each of applied voltage pulses has the same strength in theory but th
 These three example traces show the implemented ReRAM model in the simulator, and it shows that it captures the measured conductance response curve quite well. One can also see the device-to-device variability in this case as illustrated by the three different colored plots. Here we show 3 different device updates.

-We have implemented 3 different ways to perform the update in Analog and hope to extend the number of available optimizers in the future:
+We have implemented several different ways to perform the update in Analog and hope to extend the number of available optimizers in the future:

 * Plain SGD: Fully parallel update using stochastic pulse trains by Gokmen & Vlasov::ref:`[9] `.
 * Mixed precision: Digital rank update and transfer by Nandakumar et al.::ref:`[4] `.
-* Tiki-taka: Momentum-like SGD update by Gokmen & Haensch::ref:`[10] `.
+* Tiki-taka (TTv1): Momentum-like SGD update by Gokmen & Haensch::ref:`[10] `.
+* TTv2: Buffered transfer with a floating-point H buffer by Gokmen::ref:`[16] `.
+* TTv3 (c-TTv2): Chopped-TTv2, buffered transfer with input/output choppers by Rasch et al.::ref:`[17] `.
+* TTv4 (AGAD): Analog Gradient Accumulation with Dynamic reference by Rasch et al.::ref:`[17] `.

 These algorithmic improvements and the adaptation of existing algorithms to the characteristics of Analog hardware is one of the key focus areas of this toolkit.

-Plain SGD optimizer implements a fast way to do the gradient update fully in Analog using coincidences of stochastic pulse trains to compute
-the outer product as was suggested by the paper of Gokmen & Valsov::ref:`[9] `. The Mixed precision optimizer was proposed by Nandakumar
-et al in 2020::ref:`[4] `. In this optimzer, the outer product to form the weight gradients is computed in digital. Compared to the first optimizer, we have more digital
-compute units on this chip than the first one which has the update fully in parallel. This would be a good choice for much more non-ideal devices. The third
-optimizer called Tiki Taka implements an algorithm that is similar to momentum stochastics gradient decent and assumes that both the momentum term and
-the weight matrix are on analog cross bar arrays as discussed in:ref:`[10] `. The gradient update computation onto the momentum matrix uses the same fast update it it was
-explained in the plain SGD case.
+The plain SGD optimizer implements a fast way to do the gradient update fully in Analog, using coincidences of stochastic pulse trains to compute
+the outer product, as suggested by Gokmen & Vlasov::ref:`[9] `. The Mixed precision optimizer was proposed by Nandakumar
+et al. in 2020::ref:`[4] `. In this optimizer, the outer product that forms the weight gradients is computed in digital. Compared to plain SGD, which performs
+the update fully in parallel, this approach requires more digital compute units on the chip, but it is a good choice for devices with stronger non-idealities. The Tiki-taka
+optimizer (TTv1) implements an algorithm that is similar to momentum SGD and assumes that both the momentum term and
+the weight matrix are on analog crossbar arrays as discussed in :ref:`[10] `.
+TTv2 adds a floating-point H buffer between the fast and slow
+arrays :ref:`[16] `, enabling lossless accumulation of fractional gradient steps. TTv3 (c-TTv2) further introduces input/output choppers that suppress
+systematic bias :ref:`[17] `, and TTv4 (AGAD) extends TTv3 with a statistical approach for computing the gradient update :ref:`[17] `.

 Plain SGD: Fully Parallel Update
 ---------------------------------
@@ -99,8 +103,8 @@ See `example 12 `.

-Tiki-taka: Momentum-like SGD Update
------------------------------------
+Tiki-taka (TTv1): Momentum-like SGD Update
+------------------------------------------

 The Tiki-Taka optimizer is also algorithmically similar to momentum SGD. The difference here is that the momentum matrix is also in Analog. This implies
 that the outer product update onto the momentum matrix is done in analog in fully parallel mode using the stochastic pulse trains we described earlier. Therefore, this optimizer does not have the potential bottleneck to compute the outer product in digital as done in the
@@ -111,3 +115,145 @@ This is explained in more detail in this paper.

 .. image:: ../img/tikitaka.png
     :alt: Tiki-taka: Momentum-like SGD Update
+
+**TTv1 Formulation**
+
+The core update equations for Tiki-taka (TTv1) are:
+
+.. math::
+
+    A \mathrel{-}= \beta \cdot \text{Gradient}
+
+.. math::
+
+    C \mathrel{+}= \alpha \cdot A
+
+Where:
+
+* :math:`A` is the fast (momentum) array, updated at every gradient step with learning rate :math:`\beta`
+* :math:`C` is the slow (weight) array, updated periodically via transfer events with coefficient :math:`\alpha`
+* The gradient is computed on :math:`\gamma \cdot A + C`, where :math:`\gamma` controls the contribution of :math:`A` to the effective weight.
+
+The key distinguishing feature is that momentum decay is achieved implicitly through device asymmetry (random up/down pulses on :math:`A`) rather than explicit multiplicative decay, which is difficult to implement in analog hardware.
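+
+As a schematic illustration only (plain NumPy, not the simulator's pulsed
+implementation; ``beta``, ``alpha`` and ``transfer_every`` are illustrative
+values, and hardware transfers one column of :math:`A` per transfer event
+rather than the whole array)::
+
+    import numpy as np
+
+    beta, alpha, gamma = 0.1, 0.05, 0.0
+    transfer_every = 10                 # gradient steps between transfer events
+    rng = np.random.default_rng(0)
+
+    A = np.zeros((4, 4))                # fast (momentum) array
+    C = np.zeros((4, 4))                # slow (weight) array
+
+    for step in range(100):
+        x = rng.standard_normal(4)      # layer input
+        d = rng.standard_normal(4)      # back-propagated error
+        grad = np.outer(d, x)           # evaluated at W_eff = gamma * A + C
+        A -= beta * grad                # fast analog update at every step
+        if (step + 1) % transfer_every == 0:
+            C += alpha * A              # transfer event: push A into C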
+
+**Residual Learning and Bit-Slicing with Non-Zero** :math:`\gamma`
+
+The ``gamma`` parameter enables two complementary mechanisms in TTv1-TTv3:
+
+.. math::
+
+    W_{\text{eff}} = \gamma \cdot A + C
+
+The gradient is evaluated at the effective weight :math:`\gamma A + C` rather than at :math:`C` alone,
+so :math:`A` can directly influence the gradient direction and magnitude.
+The relative contribution of :math:`A` is controlled by ``gamma``:
+
+**When** ``gamma = 0`` **(default):** :math:`A` is fully hidden — gradients are
+evaluated only at :math:`C`. :math:`A` acts as a hidden momentum buffer whose content is
+periodically transferred to :math:`C`. Because transfers are discrete and
+infrequent, :math:`C` may lag the true gradient direction, introducing gradient
+staleness.
+
+**When** ``gamma > 0``: :math:`A` becomes an active *residual branch* on top of :math:`C`,
+enabling two complementary mechanisms:
+
+1. **Residual learning:** :math:`A` can now track the residual of :math:`C`: after each
+   transfer, any remaining deviation of :math:`C` from the ideal weight (due to
+   device non-linearity, write noise, saturation, or drift) is visible in the
+   gradient evaluated at :math:`\gamma A + C`. This gradient drives :math:`A` in the
+   direction that corrects :math:`C`'s error, so :math:`A` continuously compensates for
+   whatever :math:`C` fails to represent. When the next transfer event fires, the
+   correction accumulated in :math:`A` is pushed into :math:`C`, pulling it closer to the
+   ideal weight. The mechanism is analyzed in detail by Wu et al.::ref:`[18] `.
+
+2. **Bit-slicing (precision enhancement):** The two-layer decomposition
+   :math:`W = \gamma A + C` acts as a *bit-slicing* mechanism: the fast array :math:`A`
+   can represent finer-grained weight updates (higher effective precision) while
+   the slow array :math:`C` provides stable storage of the coarse weight values. By
+   tuning ``gamma`` and the transfer frequency, the effective weight granularity
+   can be reduced below the device's native conductance step, enabling higher
+   training accuracy without modifying the underlying analog device. This
+   approach is particularly valuable when :math:`C`'s device granularity is coarse or
+   non-uniform. See Li et al.::ref:`[19] ` for its extension to the multi-array
+   setting as well as a detailed analysis.
+
+TTv2: Buffered Transfer
+-----------------------
+The buffered transfer algorithm (TTv2), proposed by Gokmen::ref:`[16] `, extends
+Tiki-taka by introducing a floating-point H buffer between the fast analog array A and the
+slow weight array C. Instead of sending stochastic update pulses to C at every gradient step,
+each transfer event first reads a column of A and accumulates the result in the digital buffer H:
+
+.. math::
+
+    H \mathrel{+}= \alpha \cdot A
+
+where :math:`\alpha` is a learning-rate scale factor. An integer number of pulses is sent to C
+only when the accumulated value reaches the threshold (:math:`|H| \geq 1`), after which H is
+reduced by the number of steps taken (or decayed by a momentum factor when ``forget_buffer=True``).
+
+This buffered scheme provides two key advantages over plain Tiki-taka (TTv1):
+
+* **Reduced write noise on C** — pulses are sent to the slow device only when the buffer
+  is large enough to justify a full integer step, so C is updated less frequently and with
+  steps that match its conductance granularity.
+* **Lossless accumulation** — fractional gradient contributions that would otherwise be
+  rounded away by the finite granularity of C are preserved in the floating-point buffer
+  until they can be committed.
+
+The algorithm is configured via
+:class:`~aihwkit.simulator.configs.compounds.BufferedTransferCompound`.
+
+TTv3 (c-TTv2): Chopped Buffered Transfer
+-----------------------------------------
+TTv3, originally named **Chopped-TTv2 (c-TTv2)** by Rasch et al.::ref:`[17] `, extends TTv2 by adding *choppers* —
+random binary sign-flip patterns applied to the input (and optionally output) of each transfer
+read. After each column read of A, the input chopper sign is randomly toggled with probability
+``in_chop_prob``; the current chopper state multiplies both the update written to A and the
+value accumulated in H:
+
+.. math::
+
+    H \mathrel{+}= \text{chopper} \cdot \alpha \cdot A
+
+Because the chopper sign is applied consistently to both the write and the read, the effective
+gradient in H is unbiased. Systematic device offsets and long-range correlations on the fast
+array A average out over successive chopper flips, enabling more aggressive transfer rates
+without accumulating systematic errors on C.
+
+The standard TTv3 transfer logic — accumulation, threshold crossing, and pulse dispatch to C —
+is identical to TTv2. The sole difference is that all reads of A are chopper-modulated. Both
+input choppers (``in_chop_prob``) and output choppers (``out_chop_prob``) can be configured
+independently.
+
+The algorithm is configured via
+:class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`.
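+
+The sign cancellation behind this unbiasedness can be sketched with a toy
+scalar model in NumPy (not the simulator's pulse-level implementation; the
+constant gradient and the read offset are illustrative)::
+
+    import numpy as np
+
+    rng = np.random.default_rng(0)
+    alpha, true_grad, offset = 0.1, 0.5, 0.3   # offset: systematic read error on A
+    H, chopper, n_steps = 0.0, 1.0, 1000
+
+    for _ in range(n_steps):
+        if rng.random() < 0.1:                 # in_chop_prob: random sign toggle
+            chopper = -chopper
+        A = chopper * true_grad                # A is written in "chopped" form
+        H += chopper * alpha * (A + offset)    # the read is chopper-modulated too
+        # chopper**2 == 1, so the gradient accumulates coherently, while
+        # chopper * offset changes sign between chopper periods
+
+    print(H / (n_steps * alpha))               # close to true_grad; offset suppressed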
+
+.. _using_simulator: using_simulator.html
+
+TTv4 (AGAD): Dynamic Chopped Transfer
+---------------------------------------
+TTv4, originally named **Analog Gradient Accumulation with Dynamic reference (AGAD)** by
+Rasch et al.::ref:`[17] `, extends TTv3 by introducing a dynamic *symmetric point
+tracking* mechanism for establishing reference values on-the-fly, using a modest amount of
+additional digital compute, rather than relying on a separate reference conductance array or
+differential read circuitry.
+
+Concretely, TTv4 establishes dynamic symmetric points by comparing the running means of reads
+taken during the two most recent chopper half-periods. The transfer onto C is proportional to
+the *difference* between these two half-period means:
+
+.. math::
+
+    \Delta C \propto \bar{A}_{\text{new}} - \bar{A}_{\text{old}}
+
+No update is dispatched to C when this difference is not statistically distinguishable from
+noise, as judged by the running standard-deviation estimate (i.e., a standard-error-of-the-mean
+noise gate is applied). Because the reference values are derived from the device reads
+themselves rather than from a separately measured baseline, AGAD greatly simplifies the
+hardware design.
+
+The algorithm is configured via
+:class:`~aihwkit.simulator.configs.compounds.DynamicTransferCompound`.
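+
+A toy sketch of the half-period rule in NumPy (illustrative constants; the
+simulator applies this per weight element with running estimates)::
+
+    import numpy as np
+
+    rng = np.random.default_rng(0)
+    half_period = 20                  # reads per chopper half-period
+    a_true, offset, noise = 0.4, 0.3, 0.05
+
+    def half_period_reads(sign):
+        """Chopper-modulated reads of A over one half-period."""
+        return sign * a_true + offset + noise * rng.standard_normal(half_period)
+
+    old = half_period_reads(+1.0)
+    new = half_period_reads(-1.0)
+
+    diff = old.mean() - new.mean()    # the offset cancels: close to 2 * a_true
+    sem = np.hypot(old.std(ddof=1), new.std(ddof=1)) / np.sqrt(half_period)
+    if abs(diff) > sem:               # noise gate: skip insignificant transfers
+        print("transfer to C proportional to", diff)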
diff --git a/docs/source/paper_references.rst b/docs/source/paper_references.rst
index 7823b988..c21a3d86 100644
--- a/docs/source/paper_references.rst
+++ b/docs/source/paper_references.rst
@@ -45,6 +45,17 @@ Paper References
 * [14] 2023 Nature,
   `An analog-AI chip for energy-efficient speech recognition and transcription`_

+* [15] 2025 NeurIPS,
+  `Analog In-memory Training on General Non-ideal Resistive Elements: The Impact of Response Functions`_
+
+* [16] 2021 Frontiers in Artificial Intelligence,
+  `Enabling Training of Neural Networks on Noisy Hardware`_
+
+* [17] 2024 Nature Communications,
+  `Fast and robust analog in-memory deep neural network training`_
+
+* [18] 2026 AISTATS,
+  `In-memory Training on Analog Devices with Limited Conductance States via Multi-tile Residual Learning`_

 .. _`Memory devices and applications for in-memory computing`: https://www.nature.com/articles/s41565-020-0655-z
 .. _`Accurate deep neural network inference using computational phase-change memory`: https://www.nature.com/articles/s41467-020-16108-9
@@ -60,7 +71,7 @@ Paper References
 .. _`Hardware-aware training for large-scale and diverse deep learning inference workloads using in-memory computing-based accelerators`: https://www.nature.com/articles/s41467-023-40770-4
 .. _`A 64-core mixed-signal in-memory compute chip based on phase-change memory for deep neural network inference`: https://www.nature.com/articles/s41928-023-01010-1
 .. _`An analog-AI chip for energy-efficient speech recognition and transcription`: https://www.nature.com/articles/s41586-023-06337-5
-
-
-
-
+.. _`Analog In-memory Training on General Non-ideal Resistive Elements: The Impact of Response Functions`: https://openreview.net/forum?id=WhEPg4mUs6
+.. _`Enabling Training of Neural Networks on Noisy Hardware`: https://www.frontiersin.org/articles/10.3389/frai.2021.699148/full
+.. _`Fast and robust analog in-memory deep neural network training`: https://www.nature.com/articles/s41467-024-51221-z
+.. _`In-memory Training on Analog Devices with Limited Conductance States via Multi-tile Residual Learning`: https://arxiv.org/abs/2510.02516
diff --git a/docs/source/using_simulator.rst b/docs/source/using_simulator.rst
index 6b257470..eb17acf4 100644
--- a/docs/source/using_simulator.rst
+++ b/docs/source/using_simulator.rst
@@ -128,12 +128,15 @@ Resistive device class Descriptio
 Compound Devices
 """"""""""""""""

-==================================================================== ========
-Resistive device class                                               Description
-==================================================================== ========
-:class:`~aihwkit.simulator.configs.devices.TransferCompound`         abstract device model that takes 2 or more devices per crosspoint and implements a 'transfer' based learning rule such as Tiki-Taka (see `Gokmen & Haensch 2020`_).
-:class:`~aihwkit.simulator.configs.devices.MixedPrecisionCompound`   abstract device model that takes one devices per crosspoint and implements a 'mixed-precision' based learning rule where the rank-update is done in digital instead of using a fully analog parallel write (see `Nandakumar et al. 2020`_).
-==================================================================== ========
+========================================================================= ========
+Resistive device class                                                    Description
+========================================================================= ========
+:class:`~aihwkit.simulator.configs.devices.TransferCompound`              abstract device model that takes 2 or more devices per crosspoint and implements a transfer-based learning rule (TTv1 / Tiki-Taka, see `Gokmen & Haensch 2020`_).
+:class:`~aihwkit.simulator.configs.compounds.BufferedTransferCompound`    extends TransferCompound with a floating-point H buffer between the fast and slow arrays, implementing the TTv2 algorithm (see `Gokmen 2021`_).
+:class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`     extends BufferedTransferCompound with input/output choppers, implementing TTv3 (c-TTv2) (see `Rasch et al. 2024`_).
+:class:`~aihwkit.simulator.configs.compounds.DynamicTransferCompound`     extends ChoppedTransferCompound with dynamic on-the-fly reference estimation, implementing the TTv4 (AGAD) algorithm (see `Rasch et al. 2024`_).
+:class:`~aihwkit.simulator.configs.devices.MixedPrecisionCompound`        abstract device model that takes one device per crosspoint and implements a 'mixed-precision' based learning rule where the rank-update is done in digital instead of using a fully analog parallel write (see `Nandakumar et al. 2020`_).
+========================================================================= ========

 RPU Configurations
 ------------------
@@ -270,8 +273,8 @@
 contribution simply adds up to form a joint effective weight. During
 forward/backward this joint effective weight will be used. Update, however,
 will be done on each of the "hidden" weights independently.

-Transfer Compound Device
-""""""""""""""""""""""""
+Transfer Compound Device (TTv1 / Tiki-taka)
+""""""""""""""""""""""""""""""""""""""""""""

 Compound devices are more complex than unit cell devices, which have a
 number of devices per crosspoint, however, they share the underlying
 implementation. For instance, the "Transfer Compound Device" does
@@ -337,6 +340,118 @@ rule instead of plain SGD.

 Once the configuration is done, the usage of this complex analog tile for
 testing or training from the user point of view is however the same as for
 other tiles.
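+
+For symmetry with the TTv2-TTv4 sections below, a minimal TTv1 configuration
+can be sketched as follows (the device choice and the ``transfer_every`` value
+are only illustrative; see `Example 8`_ for a complete script)::
+
+    from aihwkit.nn import AnalogLinear
+    from aihwkit.simulator.configs import UnitCellRPUConfig
+    from aihwkit.simulator.configs.devices import SoftBoundsDevice, TransferCompound
+
+    rpu_config = UnitCellRPUConfig(
+        device=TransferCompound(
+            # fast array A first, slow array C second
+            unit_cell_devices=[SoftBoundsDevice(), SoftBoundsDevice()],
+            transfer_every=1,  # transfer after every batch
+        )
+    )
+    model = AnalogLinear(4, 2, bias=True, rpu_config=rpu_config)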
+Buffered Transfer Compound Device (TTv2)
+"""""""""""""""""""""""""""""""""""""""""
+The :class:`~aihwkit.simulator.configs.compounds.BufferedTransferCompound`
+extends the basic transfer compound with a floating-point H buffer that sits
+between the fast analog array A and the slow weight array C (see `Gokmen 2021`_).
+
+At each transfer event, a column of A is read and the result is accumulated
+into H. An integer pulse is sent to C only when the accumulated value exceeds
+the device granularity threshold (:math:`|H| \geq 1`), after which H is
+reduced by the number of steps taken. This design has two practical advantages
+over plain Tiki-taka (TTv1):
+
+* **Reduced write noise on C** — the slow device is updated infrequently and
+  only with integer-aligned steps that match its conductance granularity.
+* **Lossless accumulation** — fractional gradient contributions are preserved
+  in the floating-point buffer until they can be committed as full pulses.
+
+A minimal TTv2 configuration::
+
+    from aihwkit.nn import AnalogLinear
+    from aihwkit.simulator.configs import UnitCellRPUConfig
+    from aihwkit.simulator.configs.compounds import BufferedTransferCompound
+    from aihwkit.simulator.configs.devices import SoftBoundsDevice
+
+    rpu_config = UnitCellRPUConfig(
+        device=BufferedTransferCompound(
+            unit_cell_devices=[SoftBoundsDevice(), SoftBoundsDevice()],
+            transfer_every=2,  # transfer every 2 batches
+            momentum=0.1,      # fraction of buffer kept after transfer
+        )
+    )
+    model = AnalogLinear(4, 2, bias=True, rpu_config=rpu_config)
+
+Chopped Transfer Compound Device (TTv3 / c-TTv2)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+The :class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`
+extends the basic transfer compound with two additional features:
+
+* **Choppers** — input and output sign-flip patterns that are toggled
+  stochastically after each transfer read, suppressing systematic offset
+  errors and enabling faster transfer rates (TTv3/v4 behaviour).
+* **Floating-point H buffer** — a per-weight accumulator that collects
+  fractional transfer increments; an integer pulse is sent to the slow
+  device only when the buffer reaches ±1.
+
+*Buffer update rule*
+
+The H buffer is updated via standard accumulation: each transfer event reads a column
+of A, scales it by a learning-rate factor α, and adds the result directly to the buffer:
+
+.. math::
+
+    H \mathrel{+}= \alpha \cdot \text{chopper} \cdot A
+
+When ``|H| ≥ 1`` an integer number of pulses is sent to C. After
+stepping, H is reduced either by subtracting the steps taken or by an
+exponential decay, controlled by ``forget_buffer`` and ``momentum``.
+Because H keeps accumulating across transfer events until the threshold is
+crossed, its magnitude at step time reflects the recent gradient history;
+the sketch below illustrates this buffer logic.
+
+For a detailed description of all governing parameters see
+`analog_update`_. For the ``gamma`` residual-learning parameter see the
+*Residual learning* discussion in the Transfer Compound Device (TTv1) section above.
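+
+A schematic sketch of this per-element buffer logic in plain Python (the
+clamping bound ``max_steps`` and the parameter values are illustrative; the
+simulator implements this at the pulse level)::
+
+    import math
+
+    def buffer_step(H, a_read, alpha, chopper, momentum=0.1,
+                    forget_buffer=False, max_steps=31):
+        """One transfer event for a single weight element."""
+        H += alpha * chopper * a_read          # chopper-modulated accumulation
+        n_steps = 0
+        if abs(H) >= 1.0:                      # threshold crossed
+            n_steps = int(math.trunc(H))       # integer pulse count
+            n_steps = max(-max_steps, min(max_steps, n_steps))
+            if forget_buffer:
+                H *= momentum                  # decay the full buffer
+            else:
+                H -= (1 - momentum) * n_steps  # subtract the steps taken
+        return H, n_steps                      # n_steps pulses are sent to C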
+
+A minimal TTv3 configuration (no residual, standard buffer)::
+
+    from aihwkit.nn import AnalogLinear
+    from aihwkit.simulator.configs import UnitCellRPUConfig
+    from aihwkit.simulator.configs.compounds import ChoppedTransferCompound
+    from aihwkit.simulator.configs.devices import SoftBoundsDevice
+
+    rpu_config = UnitCellRPUConfig(
+        device=ChoppedTransferCompound(
+            unit_cell_devices=[SoftBoundsDevice(), SoftBoundsDevice()],
+            transfer_every=10,  # transfer every 10 batches
+            in_chop_prob=0.1,   # chopper switching probability
+        )
+    )
+    model = AnalogLinear(4, 2, bias=True, rpu_config=rpu_config)
+
+Dynamic Transfer Compound Device (TTv4 / AGAD)
+"""""""""""""""""""""""""""""""""""""""""""""""
+The :class:`~aihwkit.simulator.configs.compounds.DynamicTransferCompound`,
+originally named **Analog Gradient Accumulation with Dynamic reference (AGAD)**
+(see `Rasch et al. 2024`_), extends
+:class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound` with a
+*dynamic on-the-fly reference* for computing the transfer update from A to C.
+
+Rather than relying on a separate reference conductance array or differential
+read circuitry, TTv4 establishes reference values from the device reads
+themselves: the running means of reads from the two most recent chopper
+half-periods are compared, and the transfer onto C is proportional to their
+*difference*. No update is dispatched when the difference is smaller than the
+estimated noise floor (standard error of the mean), acting as a noise gate.
+This greatly simplifies hardware design while maintaining robust training.
+
+A minimal TTv4 configuration::
+
+    from aihwkit.nn import AnalogLinear
+    from aihwkit.simulator.configs import UnitCellRPUConfig
+    from aihwkit.simulator.configs.compounds import DynamicTransferCompound
+    from aihwkit.simulator.configs.devices import SoftBoundsDevice
+
+    rpu_config = UnitCellRPUConfig(
+        device=DynamicTransferCompound(
+            unit_cell_devices=[SoftBoundsDevice(), SoftBoundsDevice()],
+            transfer_every=10,  # number of batches per chopper period
+            in_chop_prob=0.1,   # chopper switching rate (regular, not random)
+        )
+    )
+    model = AnalogLinear(4, 2, bias=True, rpu_config=rpu_config)
+
 Mixed Precision Compound
 """"""""""""""""""""""""

@@ -415,6 +530,9 @@ For more info look into
 :py:mod:`aihwkit.simulator.parameters.enums.RPUDataType`

 .. _Gokmen & Haensch 2020: https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
+.. _Gokmen 2021: https://www.frontiersin.org/articles/10.3389/frai.2021.699148/full
 .. _Example 7: https://github.com/IBM/aihwkit/blob/master/examples/07_simple_layer_with_other_devices.py
 .. _Example 8: https://github.com/IBM/aihwkit/blob/master/examples/08_simple_layer_with_tiki_taka.py
 .. _Nandakumar et al. 2020: https://www.frontiersin.org/articles/10.3389/fnins.2020.00406/full
+.. _Rasch et al. 2024: https://www.nature.com/articles/s41467-024-51221-z
+.. _analog_update: analog_update.html
diff --git a/src/aihwkit/simulator/configs/compounds.py b/src/aihwkit/simulator/configs/compounds.py
index bcf6ea9a..26eab3b8 100644
--- a/src/aihwkit/simulator/configs/compounds.py
+++ b/src/aihwkit/simulator/configs/compounds.py
@@ -483,6 +483,14 @@ class TransferCompound(UnitCell):
     ie. whether to scale the transfer LR with the current LR of the SGD.
     """

+    scale_fast_lr: bool = True
+    """Whether to scale the fast LR with the current LR of the SGD.
+
+    When ``True`` (default), the effective fast LR is ``fast_lr * current_lr``.
+    When ``False``, ``fast_lr`` is used as an absolute value regardless of
+    the optimizer LR schedule.
+    """
+
     transfer_forward: IOParameters = field(default_factory=IOParameters)
     """Input-output parameters that define the read of a transfer event.

@@ -608,6 +616,66 @@ class ChoppedTransferCompound(TransferCompound):
     ``random_selection=False``, ``with_reset_prob=0.0``,
     ``n_reads_per_transfer=1``).

+    At each iteration, the learning rate is first scaled as follows.
+    Let ``current_lr`` be the current learning rate of the optimizer, e.g. given by
+    ``AnalogSGD(model.parameters(), lr=0.01)``,
+    which varies during training if a PyTorch learning rate scheduler is used, e.g. controlled by
+    ``StepLR(optimizer, step_size=10, gamma=0.1)``.
+
+    "base_buffer_granularity":
+        dw_min_A = granularity of fast array (A)
+        threshold = thres_scale * buffer_granularity * dw_min_A
+        if auto_granularity > 0:
+            period = in_size * transfer_every
+            base_buffer_granularity = threshold * auto_granularity / period
+        else:
+            base_buffer_granularity = threshold
+
+    "final_fast_lr" (a.k.a. pulse-count lr, used to update A):
+        if fast_lr > 0:
+            if scale_fast_lr:
+                base_fast_lr = fast_lr * current_lr
+            else:
+                base_fast_lr = fast_lr
+        else:
+            base_fast_lr = current_lr
+        if auto_scale:
+            dw_min_A = granularity of fast array (A)
+            x_max = EMA of max(|x_input|)  [see ``auto_scale``]
+            d_max = EMA of max(|d_input|)  [see ``auto_scale``]
+            final_fast_lr = base_fast_lr * desired_BL * dw_min_A / (x_max * d_max)
+        else:
+            final_fast_lr = base_fast_lr
+
+    "final_transfer_lr":
+        if scale_transfer_lr:
+            base_transfer_lr = transfer_lr * current_lr
+        else:
+            base_transfer_lr = transfer_lr
+        if correct_gradient_magnitudes:
+            final_buffer_granularity = base_buffer_granularity * (dw_min_C / dw_min_A)
+            final_transfer_lr = base_transfer_lr / final_buffer_granularity / final_fast_lr
+        else:
+            final_buffer_granularity = base_buffer_granularity
+            final_transfer_lr = base_transfer_lr / final_buffer_granularity
+
+    Recursion:
+        1. Gradient computation:
+            Gradient = outer product of input/output vectors, computed at W = gamma * A + C
+        2. Fast weight update (A):
+            A += chopper * final_fast_lr * Gradient
+        3. Buffer update (H):
+            H = H + chopper * final_transfer_lr * A
+        4. Transfer from buffer to slow weight (C):
+            if abs(H) >= 1:
+                n_steps = trunc(H)          # integer pulse count, clamped to desired_BL
+                send n_steps pulses to C    # i.e., C += n_steps * dw_min_C
+                if forget_buffer:
+                    H = momentum * H                # decay the full buffer
+                else:
+                    H -= (1 - momentum) * n_steps   # subtract steps taken, keep remainder
+
+    Here H is a dimensionless pulse-count accumulator (threshold = 1.0).
+    The granularity factors are absorbed into final_transfer_lr via lr_scale.
+
     Note:
         This device is identical to
         :class:`BufferedTransferCompound` if the chopper
         probabilities are set to 0 (with the above

@@ -700,10 +768,13 @@ class ChoppedTransferCompound(TransferCompound):
     This will dynamically compute a reasonable update strength
     onto the fast matrix. ``fast_lr`` can be used to scale the
     gradient update further.
+
+    When ``auto_scale=True``, this keeps the number of pulses roughly
+    constant across training regardless of the gradient magnitude.
     """

     auto_momentum: float = 0.99
-    """Momentum of the gradient when using auto scale """
+    """Momentum of the gradient magnitude EMA when ``auto_scale`` is enabled."""

     correct_gradient_magnitudes: bool = False
     """Scale the transfer LR with the fast LR to yield the

@@ -750,6 +821,13 @@ class ChoppedTransferCompound(TransferCompound):
     ie. whether to scale the transfer LR with the current LR of the SGD.
     """
+    scale_fast_lr: bool = False
+    """Whether to scale the fast device LR with the current optimizer LR.
+
+    When ``True``, the effective fast LR is ``fast_lr * current_lr``.
+    When ``False`` (default), ``fast_lr`` is used as an absolute value.
+    """
+
     transfer_forward: IOParameters = field(default_factory=IOParameters)
     """Input-output parameters that define the read of a transfer event.

diff --git a/src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp b/src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp
index 3f60a0a5..b79584b0 100644
--- a/src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp
+++ b/src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp
@@ -756,6 +756,7 @@ template <typename T> void declare_rpu_devices(py::module &m, std::string type_n
         .def_readwrite("fast_lr", &TransferParam::fast_lr)
         .def_readwrite("transfer_lr_vec", &TransferParam::transfer_lr_vec)
         .def_readwrite("scale_transfer_lr", &TransferParam::scale_transfer_lr)
+        .def_readwrite("scale_fast_lr", &TransferParam::scale_fast_lr)
         .def_readwrite("transfer_forward", &TransferParam::transfer_io)
         .def_readwrite("transfer_update", &TransferParam::transfer_up)
         .def(
diff --git a/src/rpucuda/cuda/rpucuda_chopped_transfer_device.cu b/src/rpucuda/cuda/rpucuda_chopped_transfer_device.cu
index bfb62500..b170d022 100644
--- a/src/rpucuda/cuda/rpucuda_chopped_transfer_device.cu
+++ b/src/rpucuda/cuda/rpucuda_chopped_transfer_device.cu
@@ -190,6 +190,8 @@ void ChoppedTransferRPUDeviceCuda<T>::readMatrix(
   DEBUG_CALL(cwo_->printWeightOutputInChopper());
   T *output_weights = cwo_->getWeightOutputData();
   chop_t *wo_chopper_data = cwo_->getWeightOutputInChopperData();
+  // Matrix consisting of transfer vectors: each column is a one-hot vector.
+  // The in-chopper is applied as well.

   size_t max_n_vec_per_chunk = (par.transfer_max_vec_chunk_size + n_vec - 1) / n_vec;
   size_t n_chunks = (n_vec + max_n_vec_per_chunk - 1) / max_n_vec_per_chunk;
@@ -350,6 +352,11 @@ __global__ void kernelChoppedTransfer(
     const bool forget_buffer_in,
     const bool no_buffer) {

+  // W_buffer += transfer_in * lr_scale_in;
+  // transfer_out: the number of steps to apply, derived from `W_buffer`.
+
+  UNUSED(in_chopper);
+
   const T max_steps = (T)max_steps_in;
   const int w_size = out_size * in_size;
   const int t_size = out_size * n_vec;
@@ -402,17 +409,75 @@ __global__ void kernelChoppedTransfer(
     }
   }
 }

+// Applies the chopper correction when gamma > 0.
+// The fast weight A is stored in "chopped" form: A_stored[i,j] ≈ c_d[i] * c_x[j] * A_true[i,j].
+// After the base GEMV produces W = gamma * A_stored + C, this kernel corrects it to
+//   W = gamma * c_d[i] * c_x[j] * A_stored[i,j] + C[i,j]
+// by adding gamma * (c_d[i] * c_x[j] - 1) * A_stored[i,j] to each element.
+// Layout: W[x_idx * d_size + d_idx] (column-major, d as the inner dimension).
+template <typename T>
+__global__ void kernelApplyChopperCorrectionToWeights(
+    T *dev_weights,
+    const T *A,
+    const chop_t *c_x, // size: x_size
+    const chop_t *c_d, // size: d_size
+    const T gamma,
+    const int d_size,
+    const int total_size) {
+
+  RPU_CUDA_1D_KERNEL_LOOP(idx, total_size) {
+    int d_idx = idx % d_size;
+    int x_idx = idx / d_size;
+    T chop = (T)(c_x[x_idx] * c_d[d_idx]);
+    dev_weights[idx] += gamma * (chop - (T)1.0) * A[idx];
+  }
+}
+
+template <typename T>
+void ChoppedTransferRPUDeviceCuda<T>::reduceToWeights(CudaContextPtr context, T *dev_weights) {
+  const auto &par = getPar();
+
+  if (par.fullyHidden()) {
+    // fully_hidden_: dev_weights_ptrs_[last] == dev_weights already, no-op.
+    return;
+  }
+
+  // Standard GEMV: W = gamma * A + C (+ any additional slow devices)
+  VectorRPUDeviceCuda<T>::reduceToWeights(context, dev_weights);
+
+  T gamma = par.gamma_vec[0]; // fast weight (device[0]) contribution scale
+  if (gamma == (T)0.0) {
+    return;
+  }
+
+  // Apply the chopper correction: replace gamma*A with gamma*(c_d⊗c_x)*A element-wise.
+  chop_t *c_x = cwo_->getXChopperInData();
+  chop_t *c_d = cwo_->getDChopperInData();
+  if (c_x == nullptr || c_d == nullptr) {
+    // No chopper active — base GEMV result is already correct.
+    return;
+  }
+
+  int n = this->size_;
+  int nthreads = context->getNThreads(n);
+  int nblocks = context->getNBlocks(n, nthreads);
+
+  kernelApplyChopperCorrectionToWeights<T><<<nblocks, nthreads, 0, context->getStream()>>>(
+      dev_weights, this->dev_weights_ptrs_[0], c_x, c_d, gamma, this->d_size_, n);
+}
+
 template <typename T>
 void ChoppedTransferRPUDeviceCuda<T>::readAndUpdate(
     int to_device_idx,
     int from_device_idx,
     int i_slice_start,
-    const T lr,
+    const T transfer_lr,
     const T count_lr,
     const T *vec,
     const int n_vec,
     const PulsedUpdateMetaParameter<T> &up) {

-  if (lr == (T)0.0) {
+  if (transfer_lr == (T)0.0) {
     return;
   }
   if (!this->transfer_buffer_vec_.size()) {
@@ -430,7 +495,8 @@ void ChoppedTransferRPUDeviceCuda<T>::readAndUpdate(
   T *transfer_tmp = this->context_->template getSharedBuffer<T>(RPU_BUFFER_DEVICE_0, t_size);
   T *transfer_out = this->context_->template getSharedBuffer<T>(RPU_BUFFER_DEVICE_1, t_size);
   T lr_scale = par.getTransferLRScale(
-      from_weight_granularity, to_weight_granularity, lr, count_lr, cwo_->getCurrentMBatch());
+      from_weight_granularity, to_weight_granularity, transfer_lr, count_lr,
+      cwo_->getCurrentMBatch());

   // forward/backward with transfer vectors into tmp
   this->readMatrix(from_device_idx, nullptr, transfer_tmp, n_vec, (T)1.0);
@@ -443,6 +509,9 @@
   int nblocks = this->context_->getNBlocks(n, nthreads);
   T sub_momentum = (T)1.0 - MAX(MIN(par.momentum, (T)1.0), (T)0.0);
+  // `transfer_tmp` holds a column/row read from `from_device_idx`; it is
+  // accumulated into the digital buffer `B`.
+  // `transfer_out` receives the number of steps to apply to `to_device_idx`, derived from `B`.
   kernelChoppedTransfer<T><<<nblocks, nthreads, 0, this->context_->getStream()>>>(
       transfer_out, B, transfer_tmp, cwo_->getWeightOutputInChopperData(),
       cwo_->getWeightOutputOutChopperData(), out_size, in_size, n_vec, i_slice_start, lr_scale,
@@ -464,9 +533,11 @@
       CudaArray<T> dev_a(this->context_, out_size * in_size);
       math::copyWithIterator(this->context_, dev_a.getData(), B, out_size * in_size);
       this->context_->synchronize(); dev_a.printMatrixValues(out_size););

-  // update according to device
   T write_lr = par.getWriteLR(to_weight_granularity);
+  // The second and third arguments of `writeMatrix` are the vectors of the rank-update:
+  // (1) `transfer_out`, a column of the digital buffer B;
+  // (2) a one-hot vector, which is assigned inside `writeMatrix`,
+  //     so `nullptr` is passed here.
   this->writeMatrix(to_device_idx, nullptr, transfer_out, n_vec, write_lr, up);

   this->context_->template releaseSharedBuffer<T>(RPU_BUFFER_DEVICE_0);
@@ -486,6 +557,9 @@ T ChoppedTransferRPUDeviceCuda<T>::getPulseCountLearningRate(
     out_count_lr = par.getPulseCountAutoLR(
         m_x_, m_d_, d_sparsity_, this->rpucuda_device_vec_[0]->getWeightGranularity(),
         transfer_every, up);
+    if (par.scale_fast_lr) {
+      out_count_lr *= lr;
+    }
   } else {
     out_count_lr =
         BufferedTransferRPUDeviceCuda<T>::getPulseCountLearningRate(lr, current_m_batch, up);
@@ -508,12 +582,17 @@ void ChoppedTransferRPUDeviceCuda<T>::transfer(
   int in_size = par.getInSize();
   int out_size = par.getOutSize();

+  // index of the transferred column/row
   int i_slice = cwo_->getValStart();
+
   int n_transfers = cwo_->getNumWeightOutputs(); // could be more than in_size!

-  T lr = par.getTransferLR(to_device_idx, from_device_idx, current_lr);
+  T transfer_lr = par.getTransferLR(to_device_idx, from_device_idx, current_lr);

+  // Unlike `TransferRPUDeviceCuda`, which uses a transfer vector (`transfer_vecs_`) to mark
+  // the index of the transferred column/row, here `i_slice` marks the index,
+  // so the argument `vec` is nullptr.
   readAndUpdate(
       to_device_idx, from_device_idx, i_slice, transfer_lr, current_count_lr, nullptr, n_transfers,
       par.transfer_up);
   this->current_slice_indices_[from_device_idx] = (i_slice + n_transfers) % in_size;
 }
@@ -563,17 +642,22 @@ void ChoppedTransferRPUDeviceCuda<T>::runUpdateKernel(
     RPU_FATAL("Explicit CWO as input not allowed here.");
   }

-  if (!this->fully_hidden_) {
-    RPU_FATAL("Expects fully hidden fast matrix.");
-  }
+  // A fully hidden fast matrix is no longer required: gamma > 0 is supported.

   if (blm->getCurrentLR() == (T)0.0) {
     return;
   }
+  const auto &par = getPar();

   // set full hidden
-  this->dev_weights_ptrs_[this->n_devices_ - 1] = dev_weights;
-
+  // When gamma > 0, A contributes to the visible weight W = gamma*(c_d⊗c_x)*A + C.
+  // The chopper correction is applied in reduceToWeights() via
+  // kernelApplyChopperCorrectionToWeights, so no special handling is needed here.
+  if (par.fullyHidden()) {
+    this->dev_weights_ptrs_[this->n_devices_ - 1] = dev_weights;
+  }
   // TODO: enable asynchronous update on a separate update stream.
  // for that one needs to make sure that wait events are
  // inserted in the main stream and that all the context
@@ -584,6 +668,9 @@ void ChoppedTransferRPUDeviceCuda<T>::runUpdateKernel(
   // generate the choppers, advance counter, etc
   cwo_->makeWeightOutputChoppers(blm);

+  // Here the m_batch data samples (gradients) update the fast weight (device[0], the fast array A).
+  // No transfer happens during this update process;
+  // if m_batch is greater than `transfer_every`, all transfers are deferred to the end.
   this->rpucuda_device_vec_[0]->runUpdateKernel(
       kpars, c, this->dev_weights_ptrs_[0], m_batch, blm, up, lr, dev_states, one_sided, nullptr,
       nullptr, &*cwo_);
diff --git a/src/rpucuda/cuda/rpucuda_chopped_transfer_device.h b/src/rpucuda/cuda/rpucuda_chopped_transfer_device.h
index c0cd63e1..23708156 100644
--- a/src/rpucuda/cuda/rpucuda_chopped_transfer_device.h
+++ b/src/rpucuda/cuda/rpucuda_chopped_transfer_device.h
@@ -98,6 +98,7 @@ template <typename T> class ChoppedTransferRPUDeviceCuda : public BufferedTransf
   void loadExtra(const RPU::state_t &extra, const std::string prefix, bool strict) override;

 protected:
+  void reduceToWeights(CudaContextPtr c, T *dev_weights) override;
   int getTransferEvery(
       int didx, int m_batch, const PulsedUpdateMetaParameter<T> &up) const override;
   T getPulseCountLearningRate(
diff --git a/src/rpucuda/cuda/rpucuda_transfer_device.cu b/src/rpucuda/cuda/rpucuda_transfer_device.cu
index 374eb824..e84b2397 100644
--- a/src/rpucuda/cuda/rpucuda_transfer_device.cu
+++ b/src/rpucuda/cuda/rpucuda_transfer_device.cu
@@ -186,7 +186,12 @@ T TransferRPUDeviceCuda<T>::getPulseCountLearningRate(
   const auto &par = getPar();

   if (par.fast_lr > (T)0.0) {
-    return par.fast_lr;
+    if (par.scale_fast_lr) {
+      // scale the fast LR with the SGD learning rate
+      return par.fast_lr * learning_rate;
+    } else {
+      return par.fast_lr;
+    }
   } else {
     return PulsedRPUDeviceCudaBase<T>::getPulseCountLearningRate(
         learning_rate, current_m_batch, up);
diff --git a/src/rpucuda/rpu_chopped_transfer_device.cpp b/src/rpucuda/rpu_chopped_transfer_device.cpp
index 612f189b..114a3f1f 100644
--- a/src/rpucuda/rpu_chopped_transfer_device.cpp
+++ b/src/rpucuda/rpu_chopped_transfer_device.cpp
@@ -75,9 +75,9 @@ template <typename T> void ChoppedTransferRPUDeviceMetaParameter<T>::checkSuppor
     RPU_FATAL("Only same context supported");
   }

-  if (!this->fullyHidden()) {
-    RPU_FATAL("Expects a fully hidden fast device.");
-  }
+  // A fully hidden fast device is no longer required: gamma > 0 is supported.

   if ((this->n_reads_per_transfer != 1) || (this->random_selection != false) ||
       (this->with_reset_prob > (T)0.0)) {
@@ -341,6 +341,9 @@ T ChoppedTransferRPUDevice<T>::getPulseCountLearningRate(
     count_lr = par.getPulseCountAutoLR(
         m_x_, m_d_, d_sparsity_, this->rpu_device_vec_[0]->getWeightGranularity(),
         transfer_every, up);
+    if (par.scale_fast_lr) {
+      count_lr *= lr;
+    }
   } else {
     count_lr = BufferedTransferRPUDevice<T>::getPulseCountLearningRate(lr, current_m_batch, up);

@@ -409,6 +412,9 @@ void ChoppedTransferRPUDevice<T>::readAndUpdate(
   int non_zero_count = 0;
   bool in_chop = in_chopper_[(size_t)(i_slice_start)];

+  // Standard TTv2/TTv3 buffer accumulation:
+  //   H += lr_scale * A
+  // After taking steps, H is reduced (momentum decay or forget_buffer).
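+  // Numeric example (values illustrative only): with lr_scale = 0.3 and a
+  // constant read A = 1, successive transfer events give H = 0.3, 0.6, 0.9
+  // (no pulses dispatched), then H = 1.2 -> one pulse to C; with
+  // forget_buffer = false, H -= (1 - momentum) * 1 keeps the remainder.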
   PRAGMA_SIMD
   for (int j = 0; j < out_size; j++) {

@@ -510,6 +516,40 @@ void ChoppedTransferRPUDevice<T>::doDenseUpdate(T **weights, int *coincidences,
   TransferRPUDevice<T>::doDenseUpdate(weights, coincidences, rng);
 }

+template <typename T>
+void ChoppedTransferRPUDevice<T>::reduceToWeights(T **weights) const {
+  const auto &par = getPar();
+
+  if (par.fullyHidden()) {
+    // fully_hidden_: weights == dev_weights_ptrs_[last] already, no-op.
+    return;
+  }
+
+  // Standard GEMV: W = gamma * A + C (+ any additional slow devices)
+  TransferRPUDevice<T>::reduceToWeights(weights);
+
+  T gamma = par.gamma_vec[0]; // fast weight (device[0]) contribution scale
+  if (gamma == (T)0.0) {
+    return;
+  }
+
+  // Apply chopper correction: W = gamma * (c_d ⊗ c_x) * A + C.
+  // The fast weight A is stored in "chopped" form and needs de-chopping
+  // when it contributes to the final weight W.
+  // Correction: add gamma * (c_d[i] * c_x[j] - 1) * A[i,j] to each element.
+
+  T *A = *this->weights_vec_[0];
+  int size = this->x_size_ * this->d_size_;
+
+  // Layout: W[x_idx * d_size + d_idx] (column-major, d as the inner dimension)
+  for (int idx = 0; idx < size; idx++) {
+    int d_idx = idx % this->d_size_;
+    int x_idx = idx / this->d_size_;
+    T chop = (T)(in_chopper_[x_idx] * out_chopper_[d_idx]);
+    weights[0][idx] += gamma * (chop - (T)1.0) * A[idx];
+  }
+}
+
 template class ChoppedTransferRPUDevice<float>;
 #ifdef RPU_USE_DOUBLE
 template class ChoppedTransferRPUDevice<double>;
diff --git a/src/rpucuda/rpu_chopped_transfer_device.h b/src/rpucuda/rpu_chopped_transfer_device.h
index fe03dd19..fbcb1b07 100644
--- a/src/rpucuda/rpu_chopped_transfer_device.h
+++ b/src/rpucuda/rpu_chopped_transfer_device.h
@@ -154,6 +154,8 @@ template <typename T> class ChoppedTransferRPUDevice : public BufferedTransferRP

   void doDenseUpdate(T **weights, int *coincidences, RNG<T> *rng) override;

+  void reduceToWeights(T **weights) const override;
+
   void initUpdateCycle(
       T **weights,
       const PulsedUpdateMetaParameter<T> &up,
diff --git a/src/rpucuda/rpu_transfer_device.cpp b/src/rpucuda/rpu_transfer_device.cpp
index 1a5fe30d..db8c7fdd 100644
--- a/src/rpucuda/rpu_transfer_device.cpp
+++ b/src/rpucuda/rpu_transfer_device.cpp
@@ -372,7 +372,11 @@ T TransferRPUDevice<T>::getPulseCountLearningRate(
   const auto &par = getPar();

   if (par.fast_lr > (T)0.0) {
-    return par.fast_lr;
+    if (par.scale_fast_lr) {
+      return par.fast_lr * learning_rate;
+    } else {
+      return par.fast_lr;
+    }
   } else {
     return learning_rate;
   }
diff --git a/src/rpucuda/rpu_transfer_device.h b/src/rpucuda/rpu_transfer_device.h
index 2232659c..0cc15da0 100644
--- a/src/rpucuda/rpu_transfer_device.h
+++ b/src/rpucuda/rpu_transfer_device.h
@@ -64,6 +64,7 @@ template <typename T> struct TransferRPUDeviceMetaParameter : VectorRPUDeviceMet
   T transfer_lr = (T)1.0;
   std::vector<T> transfer_lr_vec;
   bool scale_transfer_lr = true;
+  bool scale_fast_lr = true;
   bool transfer_columns = true; // or rows
   int _in_size = 0;
   int _out_size = 0;