diff --git a/src/aihwkit/nn/conversion.py b/src/aihwkit/nn/conversion.py index b751fdb5..e8e319c9 100644 --- a/src/aihwkit/nn/conversion.py +++ b/src/aihwkit/nn/conversion.py @@ -16,6 +16,7 @@ from torch.nn import Module, Linear, Conv1d, Conv2d, Conv3d, Sequential from aihwkit.exceptions import ArgumentError +from aihwkit.optim.context import AnalogContext from aihwkit.simulator.tiles.module import TileModule from aihwkit.nn.modules.container import AnalogWrapper from aihwkit.nn.modules.base import AnalogLayerBase @@ -88,6 +89,7 @@ def convert_to_analog( exclude_modules: Optional[List[str]] = None, inplace: bool = False, verbose: bool = False, + readonly: Optional[bool] = None, ) -> Module: """Convert a given digital model to its analog counterpart. @@ -138,6 +140,13 @@ def convert_to_analog( verbose: Increase verbosity. Will print converted layers. + readonly: If not ``None``, override the ``readonly`` flag on + every :class:`~aihwkit.optim.context.AnalogContext` after + conversion. When ``True``, in-place weight modifications + are blocked; when ``False``, they are allowed. If ``None`` + (default), each tile uses the value from + ``rpu_config.mapping.readonly_weights``. + Returns: Module where all the digital layers are replaced with analog mapped layers. 
@@ -193,6 +202,7 @@ def convert_to_analog( exclude_modules, True, verbose, + readonly, ) continue @@ -211,6 +221,12 @@ def convert_to_analog( if ensure_analog_root and not module_name and not isinstance(module, AnalogLayerBase): module = AnalogWrapper(module) + # Apply global readonly override if specified + if readonly is not None and not module_name: + for param in module.parameters(): + if isinstance(param, AnalogContext): + param.readonly = readonly + return module diff --git a/src/aihwkit/nn/modules/base.py b/src/aihwkit/nn/modules/base.py index 1424061d..3f346b34 100644 --- a/src/aihwkit/nn/modules/base.py +++ b/src/aihwkit/nn/modules/base.py @@ -7,10 +7,12 @@ """Base class for adding functionality to analog layers.""" from typing import Any, List, Optional, Tuple, NamedTuple, Union, Generator, Callable, TYPE_CHECKING from collections import OrderedDict +import warnings from torch import Tensor from torch.nn import Parameter from torch import device as torch_device +from torch import Size as torch_size from aihwkit.exceptions import ModuleError from aihwkit.simulator.tiles.module import TileModule @@ -244,6 +246,30 @@ class type when setting ``load_rpu_config`` to ``load_rpu_config=False``. """ + keys_to_delete = [] + for name, param in list(state_dict.items()): + if name.endswith("analog_ctx") and param.size() == torch_size([]): + keys_to_delete.append(name) + + # For checkpoints saved by aihwkit before version 0.9.2, the parameters with + # class `AnalogContext` are saved as empty size tensors since `AnalogContext`, + # which is a derived class of `torch.nn.Parameter`, + # uses an empty Parameter tensor to store the context. + if len(keys_to_delete) > 0: + strict = False + for key in keys_to_delete: + del state_dict[key] + warnings.warn( + "Some parameters in the loaded checkpoint have empty size " + "(param.size() == torch.Size([])). " + "This can happen because the loaded checkpoint " + "was generated by an older version of aihwkit. " 
+ "The parameter is skipped for compatibility reasons." + "The loading mode is set to non-strict." + "It is recommended to re-save the checkpoint with the latest version of aihwkit." + "Related parameters are: {}".format(keys_to_delete) + ) + for analog_tile in self.analog_tiles(): analog_tile.set_load_rpu_config_state(load_rpu_config, strict_rpu_config_check) return super().load_state_dict(state_dict, strict) # type: ignore diff --git a/src/aihwkit/nn/modules/linear.py b/src/aihwkit/nn/modules/linear.py index 9db3bee2..8a5debfe 100644 --- a/src/aihwkit/nn/modules/linear.py +++ b/src/aihwkit/nn/modules/linear.py @@ -82,6 +82,10 @@ def reset_parameters(self) -> None: self.weight, self.bias = self.get_weights() # type: ignore super().reset_parameters() self.set_weights(self.weight, self.bias) # type: ignore + # AnalogLinear doesn't support access weight and bias directly, so delete them + del self.weight, self.bias + # delete them manually is necessary since asigning `bias` (a bool) is forbidden + # by torch if self.bias is already a tensor self.weight, self.bias = None, bias def forward(self, x_input: Tensor) -> Tensor: diff --git a/src/aihwkit/optim/context.py b/src/aihwkit/optim/context.py index d9228d03..a2ed07cc 100644 --- a/src/aihwkit/optim/context.py +++ b/src/aihwkit/optim/context.py @@ -8,18 +8,49 @@ # pylint: disable=attribute-defined-outside-init +from contextlib import contextmanager from typing import Optional, Type, Union, Any, TYPE_CHECKING -from torch import ones, dtype, Tensor, no_grad +from torch import dtype, Tensor, no_grad from torch.nn import Parameter from torch import device as torch_device +from aihwkit.optim.weight_view import ReadOnlyWeightView + if TYPE_CHECKING: from aihwkit.simulator.tiles.base import SimulatorTileWrapper class AnalogContext(Parameter): - """Context for analog optimizer.""" + """Context for analog optimizer. 
+ + + Note: the `data` attribute, inherited from `torch.nn.Parameter`, is the tensor of trainable parameters. + If `analog_bias` (which is provided by `analog_tile`) is False, + `data` has the same meaning as in `torch.nn.Parameter`. + If `analog_bias` (which is provided by `analog_tile`) is True, + the last column of `data` is the `bias` term. + + Even though it allows us to access the weights directly, always keep in mind that it is used + only for study purposes. To simulate realistic reading, call the `read_weights` method + instead, i.e. given `analog_ctx: AnalogContext`, + estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights() + + Similarly, even though this feature allows us to update the weights directly, + always keep in mind that the real RPU devices change their weights only + by the "pulse update" method. + + Therefore, use the following update methods instead of + writing `data` directly in the analog optimizer: + --- + analog_ctx.analog_tile.update(...) + analog_ctx.analog_tile.update_indexed(...) + --- + + The ``readonly`` flag (default ``True``) causes ``.data`` reads to + return a :class:`~aihwkit.optim.weight_view.ReadOnlyWeightView` + that blocks in-place mutations. Toggle it via the property or the + :meth:`writable` context manager. 
+ """ def __new__( cls: Type["AnalogContext"], @@ -28,11 +59,18 @@ def __new__( ) -> "AnalogContext": # pylint: disable=signature-differs if parameter is None: + weights_ref = analog_tile._get_tile_weights_ref() return Parameter.__new__( cls, - data=ones((), device=analog_tile.device, dtype=analog_tile.get_dtype()), + data=weights_ref, requires_grad=True, ) + # analog_tile.tile can comes from different classes: + # aihwkit.silulator.rpu_base.devices.AnalogTile (C++) + # TorchInferenceTile (Python) + # It stores the "weight" matrix; + # If analog_tile.analog_bias is True, it also stores the "bias" matrix + parameter.__class__ = cls return parameter @@ -40,6 +78,7 @@ def __init__( self, analog_tile: "SimulatorTileWrapper", parameter: Optional[Parameter] = None ): # pylint: disable=unused-argument super().__init__() + self._readonly = self._default_readonly(analog_tile) self.analog_tile = analog_tile self.use_torch_update = False self.use_indexed = False @@ -47,6 +86,63 @@ def __init__( self.analog_grad_output = [] # type: list self.reset(analog_tile) + # -- readonly flag -------------------------------------------------------- + + @staticmethod + def _default_readonly(analog_tile: "SimulatorTileWrapper") -> bool: + """Read the default ``readonly`` setting from ``rpu_config.mapping``.""" + rpu_config = getattr(analog_tile, "rpu_config", None) + if rpu_config is not None: + mapping = getattr(rpu_config, "mapping", None) + if mapping is not None: + return getattr(mapping, "readonly_weights", True) + return True + + @property + def readonly(self) -> bool: + """Whether in-place modifications on ``data`` are blocked.""" + try: + return object.__getattribute__(self, "_readonly") + except AttributeError: + return True + + @readonly.setter + def readonly(self, value: bool) -> None: + self._readonly = value + + def __getattribute__(self, name: str) -> Any: + """Intercept ``.data`` reads: return a :class:`ReadOnlyWeightView` + when ``readonly`` is ``True``, otherwise the raw 
tensor.""" + if name == "data": + raw = super().__getattribute__(name) + try: + readonly = object.__getattribute__(self, "_readonly") + except AttributeError: + return raw + if readonly: + return ReadOnlyWeightView(raw) + return raw + return super().__getattribute__(name) + + @contextmanager + def writable(self): + """Context manager that temporarily allows direct weight modification. + + Example:: + + with analog_ctx.writable(): + analog_ctx.data.add_(delta) + # readonly is restored automatically + """ + old = self.readonly + self.readonly = False + try: + yield self + finally: + self.readonly = old + + # -- existing API --------------------------------------------------------- + def set_indexed(self, value: bool = True) -> None: """Set the context to forward_indexed.""" self.use_indexed = value @@ -54,7 +150,13 @@ def set_indexed(self, value: bool = True) -> None: def set_data(self, data: Tensor) -> None: """Set the data value of the Tensor.""" with no_grad(): - self.data.copy_(data) + # Unwrap source if it is a ReadOnlyWeightView so that + # copy_() does not trigger the in-place guard. + if isinstance(data, ReadOnlyWeightView): + data = data.as_writable() + # Access raw data directly (bypassing readonly wrap) to + # preserve storage sharing with the tile weight tensor. + super().__getattribute__("data").copy_(data) def get_data(self) -> Tensor: """Get the data value of the underlying Tensor.""" @@ -77,11 +179,11 @@ def has_gradient(self) -> bool: def __copy__(self) -> Parameter: """Turn off copying of the pointers. Context will be re-created when tile is created""" - return Parameter(self.data) + return Parameter(super().__getattribute__("data")) def __deepcopy__(self, memo: Any) -> Parameter: """Turn off deep copying. 
Context will be re-created when tile is created""" - return Parameter(self.data) + return Parameter(super().__getattribute__("data")) def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "AnalogContext": """Move the context to a cuda device. @@ -92,8 +194,8 @@ def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "Analo Returns: This context in the specified device. """ - self.data = self.data.cuda(device) # type: Tensor if not self.analog_tile.is_cuda: + self.data = self.analog_tile._get_tile_weights_ref() # type: Tensor self.analog_tile = self.analog_tile.cuda(device) self.reset(self.analog_tile) return self diff --git a/src/aihwkit/optim/weight_view.py b/src/aihwkit/optim/weight_view.py new file mode 100644 index 00000000..9682f644 --- /dev/null +++ b/src/aihwkit/optim/weight_view.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. +# +# Licensed under the MIT license. See LICENSE file in the project root for details. + +"""Read-only tensor view for analog tile weights.""" + +from torch import Tensor +from torch.utils._pytree import tree_map + + +class ReadOnlyWeightView(Tensor): + """A tensor that shares storage with tile weights but blocks in-place mutations. + + All read operations (``size``, ``norm``, ``sum``, indexing, comparisons, etc.) + work transparently because this IS a real tensor sharing the same memory. + In-place write operations raise ``RuntimeError`` with guidance to use the + correct analog tile API. + + This class is stateless — it always blocks in-place operations. The policy + of whether to wrap or unwrap is managed by :class:`AnalogContext` via its + ``readonly`` flag. 
+ """ + + @staticmethod + def __new__(cls, data: Tensor) -> "ReadOnlyWeightView": + """Create a ReadOnlyWeightView sharing storage with ``data``.""" + if isinstance(data, ReadOnlyWeightView): + return data + return Tensor._make_subclass(cls, data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + func_name = getattr(func, "__name__", "") + + # PyTorch convention: in-place ops end with single '_' (add_, mul_, ...) + # Dunder methods (__repr__, __eq__, ...) end with '__' and must pass through + if func_name.endswith("_") and not func_name.endswith("__"): + raise RuntimeError( + f"In-place operation '{func_name}' is not allowed on analog weights. " + f"Analog weights cannot be modified directly — this would bypass " + f"the physical constraints of the analog device.\n" + f" - For programmatic writes: analog_tile.set_weights(new_weight)\n" + f" - For gradient updates: analog_tile.update(x_input, d_input)\n" + f" - To unlock direct access: analog_ctx.readonly = False" + ) + + # Unwrap to plain Tensor so downstream ops don't propagate our subclass + def unwrap(t): + return t.as_subclass(Tensor) if isinstance(t, ReadOnlyWeightView) else t + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + return func(*args, **kwargs) + + def __setitem__(self, key, value): + """Block item assignment (e.g., ``ctx.data[0, 0] = 999``).""" + raise RuntimeError( + "Direct item assignment on analog weights is not allowed. " + "Use analog_tile.set_weights() instead, " + "or set analog_ctx.readonly = False to unlock direct access." + ) + + def as_writable(self) -> Tensor: + """Return the underlying plain Tensor (for internal tile use only). + + This removes the read-only guard. Only tile internals should call this. 
+ """ + return self.as_subclass(Tensor) diff --git a/src/aihwkit/simulator/parameters/mapping.py b/src/aihwkit/simulator/parameters/mapping.py index b0791e36..d8786bbc 100644 --- a/src/aihwkit/simulator/parameters/mapping.py +++ b/src/aihwkit/simulator/parameters/mapping.py @@ -100,6 +100,20 @@ class MappingParameter(_PrintableMixin): :class:`aihwkit.nn.modules.linear_mapped.AnalogLinearMapped`. """ + readonly_weights: bool = True + """Whether the analog context data is read-only. + + When ``True`` (default), ``analog_ctx.data`` is wrapped in a + :class:`~aihwkit.optim.weight_view.ReadOnlyWeightView` that blocks + in-place mutations (e.g. ``add_``, ``copy_``, ``__setitem__``). + This prevents users from accidentally modifying analog weights in + a way that bypasses the physical constraints of the analog device. + + Set to ``False`` to allow direct weight access for research or + hardware-aware training scenarios. The flag can also be toggled at + runtime via ``analog_ctx.readonly``. + """ + def compatible_with(self, mapping: "MappingParameter") -> bool: """Checks compatiblity @@ -117,6 +131,7 @@ def compatible_with(self, mapping: "MappingParameter") -> bool: "weight_scaling_omega", "weight_scaling_columnwise", "weight_scaling_lr_compensation", + "readonly_weights", ]: continue diff --git a/src/aihwkit/simulator/tiles/array.py b/src/aihwkit/simulator/tiles/array.py index 7d1ddb9c..8469312c 100644 --- a/src/aihwkit/simulator/tiles/array.py +++ b/src/aihwkit/simulator/tiles/array.py @@ -124,7 +124,7 @@ def set_weights(self, weight: Tensor, bias: Optional[Tensor] = None, **kwargs: A in_start = in_end if self.bias is not None and bias is not None: - self.bias.data = bias.detach().to(self.bias.device) + self.bias.data.copy_(bias) @no_grad() def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]: @@ -149,7 +149,7 @@ def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]: weight = cat(weight_lst, 1) if self.bias is not None: - return 
weight, self.bias.clone().cpu() + return weight, self.bias.data return weight, None def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor: diff --git a/src/aihwkit/simulator/tiles/base.py b/src/aihwkit/simulator/tiles/base.py index dac1a185..ab0d01b8 100644 --- a/src/aihwkit/simulator/tiles/base.py +++ b/src/aihwkit/simulator/tiles/base.py @@ -14,7 +14,7 @@ from numpy import array from numpy.typing import ArrayLike -from torch import Tensor, from_numpy, float32, unsqueeze, cat, empty, stack, dtype +from torch import Tensor, from_numpy, float32, unsqueeze, cat, empty, stack, dtype, zeros from torch import device as torch_device from torch.cuda import device as cuda_device from torch.autograd import no_grad @@ -171,8 +171,15 @@ def get_brief_info(self) -> str: """Returns a brief info""" raise NotImplementedError - def get_weights(self) -> Tensor: - """Returns the analog weights.""" + def get_weights(self, as_ref: bool = False) -> Tensor: + """Returns the analog weights. + + Args: + as_ref: if True, return a reference to the internal weight tensor + (not detached, stays on the current device). If False (default), + return a detached CPU copy. Not all tile types support true + references; C++ tiles always return a copy regardless. + """ raise NotImplementedError def set_weights(self, weight: Tensor) -> None: @@ -264,6 +271,7 @@ def get_meta_parameters(self) -> Any: raise NotImplementedError +# pylint: disable=too-many-public-methods class SimulatorTileWrapper: """Wrapper base class for defining the necessary tile functionality. @@ -281,6 +289,18 @@ class SimulatorTileWrapper: should be used. handle_output_bound: whether the bound clamp gradient should be inserted ignore_analog_state: whether to ignore the analog state when __getstate__ is called + + Attributes: + tile: A simulator tile object that handles the computations + for the given input/output sizes. 
+ It is created by `self._create_simulator_tile` method, + which is provided by the derived class. + E.g., `aihwkit.simulator.tiles.analog.AnalogTile` and + `aihwkit.simulator.tiles.inference_torch.TorchInferenceTile` + implement this method. + The weight data is stored in the tile object. + analog_ctx: `AnalogContext`, which wraps the weight in tile + into a `torch.nn.Parameter` object. """ def __init__( @@ -295,8 +315,6 @@ def __init__( handle_output_bound: bool = False, ignore_analog_state: bool = False, ): - self.is_cuda = False - self.device = torch_device("cpu") self.out_size = out_size self.in_size = in_size self.rpu_config = deepcopy(rpu_config) @@ -321,9 +339,119 @@ def __init__( self.tile = self._create_simulator_tile(x_size, d_size, rpu_config) + # Set up zero-copy shared weight tensor for C++ tiles. + self._shared_weight_tensor = None # type: Optional[Tensor] + self._bind_shared_weights() + self.analog_ctx = AnalogContext(self) self.analog_ctx.use_torch_update = torch_update + def _bind_shared_weights(self) -> None: + """Bind a PyTorch tensor as the C++ tile's weight storage. + + For C++ tiles that expose ``set_shared_weights``, this allocates a + contiguous tensor and passes its ``data_ptr`` to the C++ side so that + both Python and C++ operate on the same memory. After this call + ``tile.update()`` / ``tile.set_weights()`` modify the tensor + in-place — no explicit sync is needed. + + For pure-Python tiles (which already store weights as + ``torch.Tensor``), this is a no-op. + """ + if not hasattr(self.tile, "set_shared_weights"): + return + + # Probe whether the tile is a pure-Python tile by trying to call + # get_weights(as_ref=True). + # + # Tiles that accept ``as_ref``: + # TorchSimulatorTile — returns self.weight.data + # CustomSimulatorTile — returns self._analog_weight.data + # TransferSimulatorTile — accepts but delegates to C++ tile + # + # C++ tiles (pybind11 bindings) do NOT accept keyword arguments + # and raise TypeError. 
These are the ones that need binding: + # tiles.AnalogTile / CudaAnalogTile + # tiles.FloatingPointTile / CudaFloatingPointTile + # (and their half/double/bfloat16 variants) + try: + self.tile.get_weights(as_ref=True) + return # Pure-Python tile — already backed by torch.Tensor. + except TypeError: + pass # C++ tile — proceed with shared weight binding below. + + d_size = self.tile.get_d_size() + x_size = self.tile.get_x_size() + # C++ get_weights() always returns CPU; infer device from tile type name. + # CUDA C++ tiles expect transposed layout (x_size, d_size) for + # set_shared_weights, while CPU tiles expect (d_size, x_size). + is_cuda = "Cuda" in type(self.tile).__name__ + tile_device = torch_device("cuda") if is_cuda else torch_device("cpu") + if is_cuda: + shared = zeros(x_size, d_size, dtype=self.get_dtype(), device=tile_device) + else: + shared = zeros(d_size, x_size, dtype=self.get_dtype(), device=tile_device) + self.tile.set_shared_weights(shared) + # CUDA set_shared_weights does not auto-populate the buffer (unlike CPU). + # Force-sync from the tile's internal device weights into the shared tensor. + if is_cuda: + w_cpu = self.tile.get_weights() # (d_size, x_size) on CPU + shared.copy_(w_cpu.t().to(tile_device)) + self._shared_weight_tensor = shared + + def _get_tile_weights_ref(self) -> Tensor: + """Get tile weights, preferring a reference if the tile supports it.""" + if self._shared_weight_tensor is not None: + # CUDA C++ tiles store shared weights in transposed layout + # (x_size, d_size). Return .t() so callers always see the + # standard (d_size, x_size) shape — still zero-copy because + # .t() on a 2-D tensor is a stride-only view. + if "Cuda" in type(self.tile).__name__: + return self._shared_weight_tensor.t() + return self._shared_weight_tensor + # Fast path: CUDA C++ tiles with native GPU weight access. + # get_weights_cuda() returns [x_size, d_size] on device; .t() gives + # the standard [d_size, x_size] view without any CPU roundtrip. 
+ if hasattr(self.tile, "get_weights_cuda"): + return self.tile.get_weights_cuda().t() + try: + return self.tile.get_weights(as_ref=True) + except TypeError: + # C++ tile bindings don't accept as_ref + return self.tile.get_weights() + + def _sync_analog_ctx_weights(self) -> None: + """Sync analog_ctx.data with the tile weights. + + With shared weight tensors, ``analog_ctx.data`` and the tile's + internal weights already share the same memory, so during normal + training (forward → update) this is a **no-op** (same + ``data_ptr``). + + This method is still necessary for **device moves** (cpu ↔ cuda): + moving the tile to a different device replaces its backing store, + which invalidates the old ``data_ptr``. The callers in + ``cpu()`` / ``cuda()`` (base.py, module.py) and + ``set_weights()`` (periphery.py) rely on this to re-bind + ``analog_ctx.data`` after such transitions. + """ + if not hasattr(self, "analog_ctx"): + return + target_device = self.analog_ctx.data.device + ref = self._get_tile_weights_ref() + if self.analog_ctx.data.data_ptr() != ref.data_ptr(): + self.analog_ctx.data = ref.to(target_device) + + @property + def device(self) -> torch_device: + """Return the device of the tile.""" + return self.analog_ctx.device + + @property + def is_cuda(self) -> bool: + """Return the is_cuda state of the tile.""" + return self.analog_ctx.is_cuda + def get_runtime(self) -> RuntimeParameter: """Returns the runtime parameter.""" if not hasattr(self.rpu_config, "runtime"): @@ -456,11 +584,11 @@ def __getstate__(self) -> Dict: # don't save device. Will be determined by loading object current_dict.pop("stream", None) - current_dict.pop("is_cuda", None) - current_dict.pop("device", None) # this is should not be saved. current_dict.pop("image_sizes", None) + # Shared weight tensor is rebuilt by _bind_shared_weights(). 
+ current_dict.pop("_shared_weight_tensor", None) return current_dict @@ -527,9 +655,6 @@ def __setstate__(self, state: Dict) -> None: self.rpu_config = rpu_config self.__dict__.update(current_dict) - self.device = torch_device("cpu") - self.is_cuda = False - # recreate attributes not saved # always first create on CPU x_size = self.in_size + 1 if self.analog_bias else self.in_size @@ -537,6 +662,8 @@ def __setstate__(self, state: Dict) -> None: # Recreate the tile. self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config) + self._shared_weight_tensor = None + self._bind_shared_weights() names = self.tile.get_hidden_parameter_names() if len(hidden_parameters_names) > 0 and names != hidden_parameters_names: @@ -552,6 +679,23 @@ def __setstate__(self, state: Dict) -> None: weights = from_numpy(array(weights)) self.tile.set_weights(weights) + # After loading weights, re-bind shared_weights to the newly recreated + # C++ tile's backing store (see shared_weights docstring in + # RPUCudaSimulatorTileWrapper for its role). + # + # The tile is always recreated on CPU at this point, so we only sync + # here for the CPU case. If the tile will be moved to CUDA, + # .to(device) triggers cuda() which recreates the CUDA tile and + # re-syncs shared_weights there. + # + # ensure_shared_weights is defined on RPUCudaSimulatorTileWrapper, + # not on this base class — guard with hasattr for both runtime + # safety and pylint. 
+ if hasattr(self, "shared_weights") and self.shared_weights is not None: + if not self.shared_weights.is_cuda and hasattr(self, "ensure_shared_weights"): + self.ensure_shared_weights() # pylint: disable=no-member + self.analog_ctx.data = self._get_tile_weights_ref() + if analog_lr is not None: self.tile.set_learning_rate(analog_lr) @@ -568,8 +712,15 @@ def __setstate__(self, state: Dict) -> None: if analog_ctx is not None: # Keep the object ID and device to_device = analog_ctx.device - if self.device != to_device: - self.analog_ctx = self.analog_ctx.to(to_device) + if self.analog_ctx.device != to_device: + # aihwkit implements analog tiles in both CPU and CUDA versions, + # e.g. FloatingPointTile(RPUSimple(4, 3)) + # v.s. FloatingPointTile(RPUCudaSimple(4, 3)) + # Here we need to manually convert the tile to the corresponding version + self.to(to_device) + # Note: `self.to(to_device)` will call `self.analog_ctx.data.to(to_device)` + # so no need to recall + # self.analog_ctx = self.analog_ctx.to(to_device) self.analog_ctx.set_data(analog_ctx.data) @no_grad() @@ -623,6 +774,33 @@ def _combine_weights( # Use only the ``[out_size, in_size]`` matrix. return weight + def _combine_weights_cuda( + self, weight: Union[Tensor, "ArrayLike"], bias: Optional[Union[Tensor, "ArrayLike"]] = None + ) -> Tensor: + """Like _combine_weights but keeps tensors on the tile's CUDA device. + + Returns a **contiguous** ``[x_size, d_size]`` CUDA tensor in the + internal transposed layout expected by ``set_weights_cuda``. 
+ """ + d_type = self.get_dtype() + device = self.device # the tile's CUDA device + if not isinstance(weight, Tensor): + weight = from_numpy(array(weight)) + weight = weight.detach().to(dtype=d_type, device=device).reshape(self.out_size, self.in_size) + + if self.analog_bias: + if bias is None: + raise ValueError("Analog tile has a bias, but no bias given") + if not isinstance(bias, Tensor): + bias = from_numpy(array(bias)) + bias = unsqueeze(bias.detach().to(dtype=d_type, device=device), 1) # type: ignore + combined = cat((weight, bias), dim=1) # [out_size, in_size+1] + else: + combined = weight # [out_size, in_size] + + # Transpose to [x_size, d_size] (the internal CUDA storage layout). + return combined.t().contiguous() + def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[Tensor]]: """Helper to separate the combined weights and biases""" # Split the internal weights (and potentially biases) matrix. @@ -632,6 +810,16 @@ def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[ return combined_weights, None + # pylint: disable=invalid-name + def to(self, device: torch_device) -> "SimulatorTileWrapper": + """Move the tile to a device. + """ + if device.type == "cuda": + self.cuda(device) + else: + self.cpu() + return self + @no_grad() def cpu(self) -> "SimulatorTileWrapper": """Return a copy of this tile in CPU memory. @@ -642,10 +830,11 @@ def cpu(self) -> "SimulatorTileWrapper": if not self.is_cuda: return self - self.is_cuda = False - self.device = torch_device("cpu") self.analog_ctx.data = self.analog_ctx.data.cpu() self.analog_ctx.reset(self) + self._shared_weight_tensor = None + self._bind_shared_weights() + self._sync_analog_ctx_weights() return self @@ -665,10 +854,11 @@ def cuda( CudaError: if the library has not been compiled with CUDA. 
""" device = torch_device("cuda", cuda_device(device).idx) - self.is_cuda = True - self.device = device self.analog_ctx.data = self.analog_ctx.data.cuda(device) self.analog_ctx.reset(self) + self._shared_weight_tensor = None + self._bind_shared_weights() + self._sync_analog_ctx_weights() return self def get_hidden_parameters(self) -> "OrderedDict": diff --git a/src/aihwkit/simulator/tiles/custom.py b/src/aihwkit/simulator/tiles/custom.py index de1d4425..22c19b48 100644 --- a/src/aihwkit/simulator/tiles/custom.py +++ b/src/aihwkit/simulator/tiles/custom.py @@ -57,6 +57,9 @@ def __init__(self, x_size: int, d_size: int, rpu_config: "CustomRPUConfig", bias AnalogMVM.check_support(rpu_config.backward) self.set_config(rpu_config) + # Type declaration for static analysis (pylint/mypy): register_buffer sets this + # attribute dynamically and static analyzers cannot infer it from the string-based API. + self._analog_weight: Tensor # just buffer to handle device, since do not use auto grad self.register_buffer("_analog_weight", zeros(self.d_size, self.x_size, dtype=float32)) @@ -147,18 +150,21 @@ def set_weights(self, weight: Tensor) -> None: Args: weight: ``[out_size, in_size]`` weight matrix. """ - device = self._analog_weight.device - self._analog_weight = weight.clone().to(device) + self._analog_weight.copy_(weight) - def get_weights(self) -> Tensor: + def get_weights(self, as_ref: bool = False) -> Tensor: """Get the tile weights. + Args: + as_ref: if True, return a reference to the internal weight tensor. + If False (default), return a detached CPU copy. + Returns: - a tuple where the first item is the ``[out_size, in_size]`` weight - matrix; and the second item is either the ``[out_size]`` bias vector - or ``None`` if the tile is set not to use bias. + the ``[out_size, in_size]`` weight matrix. 
""" - return self._analog_weight.data.detach().cpu() + if as_ref: + return self._analog_weight.data + return self._analog_weight.data.detach().cpu().clone() def get_x_size(self) -> int: """Returns input size of tile""" diff --git a/src/aihwkit/simulator/tiles/functions.py b/src/aihwkit/simulator/tiles/functions.py index e7ead29f..84938a83 100644 --- a/src/aihwkit/simulator/tiles/functions.py +++ b/src/aihwkit/simulator/tiles/functions.py @@ -32,6 +32,8 @@ def forward( Note: Indexed versions can used when analog_ctx.use_indexed is set to True. """ + # `ctx` is the parameter required by PyTorch to store the context + # no need to pass it through ```AnalogFunction.apply(...)````. # Store in context for using during `backward()`. ctx.analog_ctx = analog_ctx ctx.analog_tile = analog_tile diff --git a/src/aihwkit/simulator/tiles/module.py b/src/aihwkit/simulator/tiles/module.py index 4f242e6e..1b4daf20 100644 --- a/src/aihwkit/simulator/tiles/module.py +++ b/src/aihwkit/simulator/tiles/module.py @@ -177,6 +177,7 @@ def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "TileM # at the end. 
shared weight might be updated above which might # yeild issues if the tile is not first updated self._apply_without_context(lambda t: t.cuda(device)) + self._sync_analog_ctx_weights() return self def cpu(self) -> "TileModule": @@ -194,6 +195,7 @@ def cpu(self) -> "TileModule": super(BaseTile, self).cpu() # type: ignore self._apply_without_context(lambda t: t.cpu()) + self._sync_analog_ctx_weights() return self def to(self, *args: Any, **kwargs: Any) -> "TileModule": diff --git a/src/aihwkit/simulator/tiles/periphery.py b/src/aihwkit/simulator/tiles/periphery.py index 420ff013..ba29ffad 100644 --- a/src/aihwkit/simulator/tiles/periphery.py +++ b/src/aihwkit/simulator/tiles/periphery.py @@ -158,14 +158,23 @@ def set_weights( if not isinstance(bias, Tensor): bias = from_numpy(array(bias)) - self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device) + self.bias.data.copy_(bias) bias = None - combined_weights = self._combine_weights(weight, bias) - - if apply_weight_scaling: - combined_weights = self.apply_weight_scaling(combined_weights, weight_scaling_omega) - self.tile.set_weights(combined_weights) + if hasattr(self.tile, "set_weights_cuda"): + combined_weights = self._combine_weights_cuda(weight, bias) + if apply_weight_scaling: + # apply_weight_scaling works on any device; re-transpose after. + combined_weights = self.apply_weight_scaling( + combined_weights.t().contiguous(), weight_scaling_omega + ).t().contiguous() + self.tile.set_weights_cuda(combined_weights) + else: + combined_weights = self._combine_weights(weight, bias) + if apply_weight_scaling: + combined_weights = self.apply_weight_scaling(combined_weights, weight_scaling_omega) + self.tile.set_weights(combined_weights) + self._sync_analog_ctx_weights() if realistic: self.program_weights() @@ -211,7 +220,13 @@ def get_weights( return self.read_weights(apply_weight_scaling=apply_weight_scaling) # Retrieve the internal weights (and potentially biases) matrix. 
- combined_weights = self.tile.get_weights() + # Use the CUDA-native path when available to avoid a full device sync + # and the CPU-side transpose loop inside the C++ binding. + if hasattr(self.tile, "get_weights_cuda"): + # get_weights_cuda() → [x_size, d_size] on CUDA; .t() → [d_size, x_size] + combined_weights = self.tile.get_weights_cuda().t().detach().cpu() + else: + combined_weights = self.tile.get_weights() weight, bias = self._separate_weights(combined_weights) if self.digital_bias: diff --git a/src/aihwkit/simulator/tiles/rpucuda.py b/src/aihwkit/simulator/tiles/rpucuda.py index e549d438..80a526ec 100644 --- a/src/aihwkit/simulator/tiles/rpucuda.py +++ b/src/aihwkit/simulator/tiles/rpucuda.py @@ -95,6 +95,20 @@ def __init__( handle_output_bound=True, ) + # shared_weights: an nn.Parameter that holds a copy of the tile weights + # in PyTorch memory. Its role differs by tile type: + # + # InferenceTile (supports_ddp=True): + # Exposes weights to PyTorch autograd so that DDP can synchronise + # gradients across GPUs. Must be kept in sync with the C++ tile via + # ensure_shared_weights(), which is called each forward/backward pass. + # + # RPUCuda training tiles (supports_ddp=False): + # Not used for DDP (raises ModuleError if attempted). Here + # shared_weights serves only as the zero-copy backing store that + # bridges the Python-side Parameter and the C++ tile's dev_weights_. + # _shared_weight_tensor (base class) points to the same storage and + # is used for direct weight access without going through the C++ API. self.shared_weights = None # type: Parameter if shared_weights: self.shared_weights = Parameter( @@ -153,9 +167,18 @@ def cuda( if self.tile.__class__ in MAP_TILE_CLASS_TO_CUDA: with cuda_device(device): self.tile = MAP_TILE_CLASS_TO_CUDA[self.tile.__class__](self.tile) - self.is_cuda = True - self.device = device - self.analog_ctx.data = self.analog_ctx.data.cuda(device) + # CPU shared tensor is no longer valid for the new CUDA tile. 
+ self._shared_weight_tensor = None + # Re-establish shared weight binding for the new CUDA tile, + # but only when not using the shared_weights DDP path. When + # shared_weights is set, ensure_shared_weights() (called on + # the first forward) handles populating dev_weights_ from + # self.shared_weights.data — calling _bind_shared_weights() + # here would set shared_weights_if_=True prematurely and + # prevent that population step from running. + if self.shared_weights is None: + self._bind_shared_weights() + self.analog_ctx.data = self.tile.get_weights().cuda(device) self.analog_ctx.reset(self) # type: ignore if self.shared_weights is not None: @@ -165,7 +188,22 @@ def cuda( dtype=self.get_dtype(), requires_grad=True, ).cuda(device) - # ensure shared weights will be called later (needs copying still) + # Eagerly populate shared_weights and bind _shared_weight_tensor, + # unless the tile uses is_perfect forward. When is_perfect=True, + # calling ensure_shared_weights() here would invoke C++ + # setSharedWeights() which replaces dev_weights_ with the + # still-zero shared buffer *before* the first forward has a + # chance to populate it — resulting in all-zero output. + # Example: InferenceRPUConfig(forward=IOParameters(is_perfect=True)) + # We defer to ensure_shared_weights() on the first forward(), + # where the C++ copyTo() properly fills the buffer first. + is_perfect_fwd = getattr( + getattr(getattr(self, "rpu_config", None), "forward", None), + "is_perfect", + False, + ) + if not is_perfect_fwd: + self.ensure_shared_weights() return self @@ -183,6 +221,9 @@ def ensure_shared_weights(self, shared_weights: Optional[Tensor] = None) -> None if self.shared_weights is not None: self.tile.set_shared_weights(self.shared_weights.data) # type: ignore + # Keep _shared_weight_tensor in sync: the RPUCuda shared_weights + # path replaces the C++ tile's backing store. 
+ self._shared_weight_tensor = self.shared_weights.data @no_grad() def set_delta_weights(self, delta_weights: Optional[Tensor] = None) -> None: diff --git a/src/aihwkit/simulator/tiles/torch_tile.py b/src/aihwkit/simulator/tiles/torch_tile.py index 6c8dbd66..6bae2a2f 100644 --- a/src/aihwkit/simulator/tiles/torch_tile.py +++ b/src/aihwkit/simulator/tiles/torch_tile.py @@ -77,17 +77,21 @@ def set_weights(self, weight: Tensor) -> None: Args: weight: ``[out_size, in_size]`` weight matrix. """ - self.weight.data = weight.clone().to(self.weight.device) + self.weight.data.copy_(weight) - def get_weights(self) -> Tensor: + def get_weights(self, as_ref: bool = False) -> Tensor: """Get the tile weights. + Args: + as_ref: if True, return a reference to the internal weight tensor. + If False (default), return a detached CPU copy. + Returns: - a tuple where the first item is the ``[out_size, in_size]`` weight - matrix; and the second item is either the ``[out_size]`` bias vector - or ``None`` if the tile is set not to use bias. + the ``[out_size, in_size]`` weight matrix. """ - return self.weight.data.detach().cpu() + if as_ref: + return self.weight.data + return self.weight.data.detach().cpu().clone() def get_x_size(self) -> int: """Returns input size of tile""" diff --git a/src/aihwkit/simulator/tiles/transfer.py b/src/aihwkit/simulator/tiles/transfer.py index 12939aff..b110393a 100644 --- a/src/aihwkit/simulator/tiles/transfer.py +++ b/src/aihwkit/simulator/tiles/transfer.py @@ -319,9 +319,14 @@ def get_brief_info(self) -> str: """Returns a brief info""" return self.__class__.__name__ + "({})".format(self.extra_repr()) - def get_weights(self) -> Tensor: - """Returns the analog weights.""" + def get_weights(self, as_ref: bool = False) -> Tensor: + """Returns the analog weights. + Args: + as_ref: if True, return a reference to the internal weight tensor. + If False (default), return a detached CPU copy. 
+ """ + # weight_tile.tile is a C++ tile that doesn't accept kwargs return self.weight_tile.tile.get_weights() def set_weights(self, weight: Tensor) -> None: diff --git a/src/rpucuda/cuda/rpucuda.cu b/src/rpucuda/cuda/rpucuda.cu index d0246dec..7be19460 100644 --- a/src/rpucuda/cuda/rpucuda.cu +++ b/src/rpucuda/cuda/rpucuda.cu @@ -333,6 +333,38 @@ template void RPUCudaSimple::setSharedWeights(T *device_source) } } +template void RPUCudaSimple::getWeightsCuda(T *device_dest) const { + // D2D read of dev_weights_ in its native transposed [x_size, d_size] layout. + // + // cf. getWeights(T*): copies dev_weights_ to host and transposes to + // row-major [d_size, x_size] — involves GPU sync + D2H + O(n^2) transpose. + // This method stays entirely on device: one async D2D memcpy, no transpose. + // Use when the caller needs a GPU tensor (e.g. Python get_weights_cuda()). + dev_weights_->synchronize(); + CUDA_CALL(cudaMemcpyAsync( + device_dest, dev_weights_->getData(), + (size_t)this->x_size_ * this->d_size_ * sizeof(T), cudaMemcpyDeviceToDevice, + context_->getStream())); +} + +template void RPUCudaSimple::setWeightsCuda(const T *device_source) { + // D2D write into dev_weights_ from a device buffer already in transposed + // [x_size, d_size] layout, then sync the host copy for serialisation. + // + // cf. setWeights(const T* host_source): takes a host pointer in row-major + // [d_size, x_size], copies H2H into weights_, then assignTranspose H2D + // into dev_weights_ — involves a full transpose upload. + // This method avoids the GPU->CPU->GPU round-trip when the source tensor + // is already on the same GPU (e.g. Python set_weights_cuda()). + dev_weights_->synchronize(); + CUDA_CALL(cudaMemcpyAsync( + dev_weights_->getData(), device_source, + (size_t)this->x_size_ * this->d_size_ * sizeof(T), cudaMemcpyDeviceToDevice, + context_->getStream())); + // Keep host copy in sync for serialisation (__getstate__). 
+ this->copyWeightsToHost(RPUSimple::getWeightsPtr()[0]); +} + template void RPUCudaSimple::getAndResetWeightUpdate(T *prev_weight_and_dw_out, T scale) { RPU::math::elemsubcopy( diff --git a/src/rpucuda/cuda/rpucuda.h b/src/rpucuda/cuda/rpucuda.h index bca8f419..8f18ae7e 100644 --- a/src/rpucuda/cuda/rpucuda.h +++ b/src/rpucuda/cuda/rpucuda.h @@ -198,6 +198,11 @@ template class RPUCudaSimple : public RPUSimple { void setWeightsUniformRandom(T min_value, T max_value) override; void setSharedWeights(T *weightsptr) override; + // Direct device-to-device weight access (no CPU roundtrip). + // Both buffers use the internal transposed [x_size, d_size] layout. + virtual void getWeightsCuda(T *device_dest) const; + virtual void setWeightsCuda(const T *device_source); + void getAndResetWeightUpdate(T *prev_weights_and_dw_out, T scale = 1.0) override; void applyWeightUpdate(T *dw_and_current_weights_out) override; diff --git a/src/rpucuda/cuda/rpucuda_pulsed.cu b/src/rpucuda/cuda/rpucuda_pulsed.cu index eb0d2130..0e276a0b 100644 --- a/src/rpucuda/cuda/rpucuda_pulsed.cu +++ b/src/rpucuda/cuda/rpucuda_pulsed.cu @@ -529,6 +529,28 @@ template void RPUCudaPulsed::setWeights(const T *host_source) { RPUCudaSimple::setWeights(this->getWeightsPtr()[0]); // set device weights } +template void RPUCudaPulsed::setWeightsCuda(const T *device_source) { + CHECK_RPU_DEVICE_INIT; + + // 1. D2D copy device_source -> dev_weights_, then sync to host (weights_). + RPUCudaSimple::setWeightsCuda(device_source); + + // 2. If a pulsed device is attached (e.g. ConstantStep, LinearStep), + // let it inspect / clamp the new host weights via onSetWeights(). + if (rpu_device_) { + // onSetWeights() returns true when it modified the host weights + // (e.g. applied weight bounds or symmetry constraints). + if (rpu_device_->onSetWeights(this->getWeightsPtr())) { + // 3. Device parameters derived from weights may have changed + // (e.g. slope / reference levels). Sync host device -> CUDA device. 
+ rpucuda_device_->populateFrom(*rpu_device_); + // 4. Host weights were modified by onSetWeights(); push them back + // to dev_weights_ so the GPU copy reflects the clamped values. + RPUCudaSimple::setWeights(this->getWeightsPtr()[0]); + } + } +} + template void RPUCudaPulsed::applyWeightUpdate(T *dw_and_current_weight_out) { CHECK_RPU_DEVICE_INIT; diff --git a/src/rpucuda/cuda/rpucuda_pulsed.h b/src/rpucuda/cuda/rpucuda_pulsed.h index 31bced3c..4e68c007 100644 --- a/src/rpucuda/cuda/rpucuda_pulsed.h +++ b/src/rpucuda/cuda/rpucuda_pulsed.h @@ -165,6 +165,7 @@ template class RPUCudaPulsed : public RPUCudaSimple { void getWeightsReal(T *weightsptr) override; void setWeightsReal(const T *weightsptr, int n_loops = 25) override; void setWeights(const T *weightsptr) override; + void setWeightsCuda(const T *device_source) override; void applyWeightUpdate(T *dw_and_current_weights_out) override; diff --git a/tests/test_analog_ctx.py b/tests/test_analog_ctx.py new file mode 100644 index 00000000..480b86a9 --- /dev/null +++ b/tests/test_analog_ctx.py @@ -0,0 +1,1108 @@ +# -*- coding: utf-8 -*- + +# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved. +# +# Licensed under the MIT license. See LICENSE file in the project root for details. + +# pylint: disable=too-many-locals, no-member +"""Tests for AnalogContext data attribution (PR #717). + +Verifies that analog_ctx.data reflects the actual weight matrix stored in the +tile, rather than being an empty scalar tensor. 
+""" + +from unittest import SkipTest + +from torch import zeros, randn, allclose, Tensor, Size, manual_seed +from torch.nn import Parameter +from torch.nn import Linear as TorchLinear, Sequential, Conv2d as TorchConv2d + +from aihwkit.nn import AnalogLinear, AnalogConv2d +from aihwkit.nn.conversion import convert_to_analog +from aihwkit.optim.context import AnalogContext +from aihwkit.optim.weight_view import ReadOnlyWeightView +from aihwkit.simulator.configs import ( + FloatingPointRPUConfig, + InferenceRPUConfig, + SingleRPUConfig, + TorchInferenceRPUConfig, +) +from aihwkit.simulator.configs.devices import ConstantStepDevice + +from .helpers.decorators import parametrize_over_layers +from .helpers.layers import Linear, LinearCuda, LinearMapped, LinearMappedCuda +from .helpers.testcases import ParametrizedTestCase, SKIP_CUDA_TESTS +from .helpers.tiles import FloatingPoint, TorchInference + + +@parametrize_over_layers( + layers=[Linear, LinearMapped], + tiles=[FloatingPoint, TorchInference], + biases=["analog", "digital", None], +) +class AnalogCtxDataAttributionTest(ParametrizedTestCase): + """Tests that analog_ctx.data has the correct shape and values.""" + + def _get_analog_tile(self, model): + """Return the first analog tile from a model.""" + return next(model.analog_tiles()) + + def test_ctx_data_shape_matches_weights(self): + """analog_ctx.size() must return the tile weight shape, not torch.Size([]).""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + ctx = tile.analog_ctx + + # The old implementation returned torch.Size([]) — a scalar. + # The new one must return the tile weight matrix shape. 
+ self.assertNotEqual(ctx.size(), Size([])) + self.assertEqual(len(ctx.size()), 2) + + expected_rows = tile.out_size + in_size = tile.in_size + (1 if tile.analog_bias else 0) + self.assertEqual(ctx.size(), Size([expected_rows, in_size])) + + def test_ctx_data_values_match_tile_weights(self): + """analog_ctx.data must reflect the actual tile weights.""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + weights_from_tile = tile.tile.get_weights() + ctx_data = tile.analog_ctx.data + + self.assertEqual(ctx_data.shape, weights_from_tile.shape) + + def test_ctx_norm_is_meaningful(self): + """analog_ctx.norm() should reflect the weight magnitude, not 1.0.""" + manual_seed(42) + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + # With randomly initialized weights, the norm should be > 0 + # and should NOT be exactly 1.0 (which the old scalar ones(()) returned). + norm_val = tile.analog_ctx.norm().item() + self.assertGreater(norm_val, 0.0) + + def test_ctx_nonzero_works(self): + """analog_ctx.nonzero() should return indices of nonzero weights.""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + # With random initialization, most weights are nonzero. + nz = tile.analog_ctx.nonzero() + self.assertGreater(len(nz), 0) + + def test_ctx_comparison_ops(self): + """Comparison operators on analog_ctx should work on actual weights.""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + # Weights are initialized near zero with std ~1, so most are < 10. 
+ mask = tile.analog_ctx > 10 + self.assertIsInstance(mask, Tensor) + self.assertEqual(mask.shape, tile.analog_ctx.shape) + + def test_ctx_is_parameter(self): + """analog_ctx should be a torch.nn.Parameter.""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + self.assertIsInstance(tile.analog_ctx, Parameter) + self.assertIsInstance(tile.analog_ctx, AnalogContext) + + def test_ctx_after_set_weights(self): + """analog_ctx.data should remain valid after set_weights.""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + # Set new weights + new_weight = randn(tile.out_size, tile.in_size) + new_bias = randn(tile.out_size) if tile.analog_bias else None + tile.set_weights(new_weight, new_bias) + + # ctx should still have a valid non-scalar shape + self.assertNotEqual(tile.analog_ctx.size(), Size([])) + self.assertEqual(len(tile.analog_ctx.size()), 2) + + +@parametrize_over_layers( + layers=[LinearCuda, LinearMappedCuda], + tiles=[FloatingPoint, TorchInference], + biases=["analog", "digital", None], +) +class AnalogCtxDataAttributionCudaTest(ParametrizedTestCase): + """Tests that analog_ctx.data is correct after moving to CUDA.""" + + def _get_analog_tile(self, model): + """Return the first analog tile from a model.""" + return next(model.analog_tiles()) + + def test_ctx_shape_after_cuda(self): + """analog_ctx.data should retain correct shape after .cuda().""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + self.assertTrue(tile.analog_ctx.is_cuda) + self.assertNotEqual(tile.analog_ctx.size(), Size([])) + self.assertEqual(len(tile.analog_ctx.size()), 2) + + def test_ctx_device_after_cuda(self): + """analog_ctx.device should be CUDA after .cuda().""" + model = self.get_layer(in_features=4, out_features=6) + tile = self._get_analog_tile(model) + + self.assertEqual(tile.analog_ctx.device.type, "cuda") + self.assertEqual(tile.device.type, 
"cuda") + + +class AnalogCtxBackwardCompatibilityTest(ParametrizedTestCase): + """Tests for backward compatibility with old checkpoints.""" + + use_cuda = False + + def test_old_checkpoint_empty_ctx_loads(self): + """Checkpoints with empty-size analog_ctx should load without error.""" + model = AnalogLinear(4, 6, bias=True, rpu_config=FloatingPointRPUConfig()) + + # Simulate an old checkpoint where analog_ctx was torch.Size([]) + state = model.state_dict() + for key in list(state.keys()): + if "analog_ctx" in key: + state[key] = zeros(()) # Simulate old empty scalar + + # Should load without error (non-strict mode for old ctx) + model.load_state_dict(state, strict=False, load_rpu_config=False) + + def test_new_checkpoint_loads(self): + """Checkpoints with properly-shaped analog_ctx should load correctly.""" + model = AnalogLinear(4, 6, bias=True, rpu_config=FloatingPointRPUConfig()) + state = model.state_dict() + + model2 = AnalogLinear(4, 6, bias=True, rpu_config=FloatingPointRPUConfig()) + model2.load_state_dict(state, strict=True, load_rpu_config=False) + + +class AnalogCtxConversionTest(ParametrizedTestCase): + """Tests that convert_to_analog produces valid analog_ctx.""" + + use_cuda = False + + def test_conversion_ctx_shape(self) -> None: + """Converted model should have correct analog_ctx shape.""" + digital_model = Sequential(TorchLinear(8, 4), TorchLinear(4, 2)) + analog_model = convert_to_analog( + digital_model, FloatingPointRPUConfig(), ensure_analog_root=False + ) + + for tile in analog_model.analog_tiles(): + ctx = tile.analog_ctx + self.assertNotEqual(ctx.size(), Size([])) + self.assertEqual(len(ctx.size()), 2) + + def test_conversion_conv2d_ctx_shape(self) -> None: + """Converted Conv2d should have correct analog_ctx shape.""" + digital_conv = TorchConv2d(3, 16, kernel_size=3, padding=1, bias=True) + analog_conv = AnalogConv2d.from_digital(digital_conv, FloatingPointRPUConfig()) + + for tile in analog_conv.analog_tiles(): + ctx = tile.analog_ctx + 
self.assertNotEqual(ctx.size(), Size([])) + self.assertEqual(len(ctx.size()), 2) + + +class AnalogCtxDevicePropertyTest(ParametrizedTestCase): + """Tests that tile.device and tile.is_cuda are computed from analog_ctx.""" + + use_cuda = False + + def test_device_property_cpu(self): + """tile.device should return CPU device for CPU tiles.""" + model = AnalogLinear(4, 6, rpu_config=FloatingPointRPUConfig()) + tile = next(model.analog_tiles()) + + self.assertEqual(tile.device.type, "cpu") + self.assertFalse(tile.is_cuda) + + def test_device_property_cuda(self): + """tile.device should return CUDA device after .cuda().""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + + model = AnalogLinear(4, 6, rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + + self.assertEqual(tile.device.type, "cuda") + self.assertTrue(tile.is_cuda) + + +class AnalogCtxSyncAfterSetWeightsTest(ParametrizedTestCase): + """Reviewer concern #1: analog_ctx.data must stay in sync after set_weights.""" + + use_cuda = False + + def _test_sync(self, rpu_config, use_cuda): + """Helper: verify ctx.data matches tile weights after set_weights.""" + model = AnalogLinear(4, 6, bias=False, rpu_config=rpu_config) + if use_cuda: + model = model.cuda() + tile = next(model.analog_tiles()) + + new_w = randn(6, 4) + tile.set_weights(new_w, None) + + w_from_tile, _ = tile.get_weights() + ctx_data = tile.analog_ctx.data.detach().cpu() + self.assertTrue(allclose(ctx_data, w_from_tile), + "analog_ctx.data out of sync after set_weights") + + def test_sync_torch_inference_cpu(self): + """TorchInference CPU: ctx stays in sync after set_weights.""" + self._test_sync(TorchInferenceRPUConfig(), use_cuda=False) + + def test_sync_floating_point_cpu(self): + """FloatingPoint CPU: ctx stays in sync after set_weights.""" + self._test_sync(FloatingPointRPUConfig(), use_cuda=False) + + def test_sync_torch_inference_cuda(self): + """TorchInference CUDA: ctx stays in sync after 
set_weights.""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + self._test_sync(TorchInferenceRPUConfig(), use_cuda=True) + + def test_sync_floating_point_cuda(self): + """FloatingPoint CUDA: ctx stays in sync after set_weights.""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + self._test_sync(FloatingPointRPUConfig(), use_cuda=True) + + def test_sync_after_multiple_set_weights(self): + """ctx should stay in sync after multiple consecutive set_weights.""" + model = AnalogLinear(4, 6, bias=False, rpu_config=TorchInferenceRPUConfig()) + tile = next(model.analog_tiles()) + + for _ in range(5): + new_w = randn(6, 4) + tile.set_weights(new_w, None) + w_from_tile, _ = tile.get_weights() + ctx_data = tile.analog_ctx.data.detach().cpu() + self.assertTrue(allclose(ctx_data, w_from_tile)) + + def test_sync_after_cuda_move(self): + """ctx should sync after CPU->CUDA->CPU round-trip.""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + + model = AnalogLinear(4, 6, bias=False, rpu_config=TorchInferenceRPUConfig()) + tile = next(model.analog_tiles()) + + new_w = randn(6, 4) + tile.set_weights(new_w, None) + + # Move to CUDA + model.cuda() + tile = next(model.analog_tiles()) + w_cuda, _ = tile.get_weights() + ctx_cuda = tile.analog_ctx.data.detach().cpu() + self.assertTrue(allclose(ctx_cuda, w_cuda), + "ctx out of sync after cuda()") + + +class AnalogCtxGetWeightsConventionTest(ParametrizedTestCase): + """Reviewer concern #3: get_weights default returns detached CPU copy.""" + + use_cuda = False + + def test_get_weights_returns_cpu_torch_inference(self): + """TorchInference: get_weights() returns CPU tensor by default.""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + model = AnalogLinear( + 4, 6, bias=False, rpu_config=TorchInferenceRPUConfig() + ).cuda() + tile = next(model.analog_tiles()) + w, _ = tile.get_weights() + self.assertEqual(w.device.type, "cpu", + 
"get_weights() should return CPU tensor by default") + + def test_get_weights_returns_cpu_floating_point(self): + """FloatingPoint: get_weights() returns CPU tensor by default.""" + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + model = AnalogLinear( + 4, 6, bias=False, rpu_config=FloatingPointRPUConfig() + ).cuda() + tile = next(model.analog_tiles()) + w, _ = tile.get_weights() + self.assertEqual(w.device.type, "cpu", + "get_weights() should return CPU tensor by default") + + def test_get_weights_is_detached(self): + """get_weights() result should not have grad_fn (detached).""" + model = AnalogLinear( + 4, 6, bias=False, rpu_config=TorchInferenceRPUConfig() + ) + tile = next(model.analog_tiles()) + w, _ = tile.get_weights() + self.assertFalse(w.requires_grad, + "get_weights() should return detached tensor") + self.assertIsNone(w.grad_fn, + "get_weights() should return detached tensor") + + def test_get_weights_is_copy(self): + """Modifying get_weights() result should NOT change tile weights.""" + model = AnalogLinear( + 4, 6, bias=False, rpu_config=TorchInferenceRPUConfig() + ) + tile = next(model.analog_tiles()) + w_before, _ = tile.get_weights() + w_copy, _ = tile.get_weights() + w_copy.fill_(999.0) + w_after, _ = tile.get_weights() + self.assertTrue(allclose(w_before, w_after), + "get_weights() should return a copy, not a reference") + + +class AnalogCtxTileGetWeightsRefTest(ParametrizedTestCase): + """Test tile.get_weights(as_ref=...) 
reference vs copy semantics.""" + + use_cuda = False + + def _get_tile(self, rpu_config): + model = AnalogLinear(4, 6, bias=False, rpu_config=rpu_config) + return next(model.analog_tiles()) + + def test_as_ref_true_shares_storage(self): + """as_ref=True should return tensors sharing the same storage.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + ref1 = tile.tile.get_weights(as_ref=True) + ref2 = tile.tile.get_weights(as_ref=True) + self.assertEqual(ref1.data_ptr(), ref2.data_ptr(), + "as_ref=True should return the same data pointer") + + def test_clone_does_not_share_storage(self): + """clone() of as_ref=True should NOT share storage.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + ref = tile.tile.get_weights(as_ref=True) + clone = ref.clone() + self.assertNotEqual(ref.data_ptr(), clone.data_ptr(), + "clone should have a different data pointer") + + def test_as_ref_true_modification_propagates(self): + """Modifying as_ref=True tensor should change tile weights.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + ref = tile.tile.get_weights(as_ref=True) + original = ref[0, 0].item() + ref[0, 0] += 999.0 + check = tile.tile.get_weights(as_ref=True) + self.assertAlmostEqual(check[0, 0].item(), original + 999.0, places=2, + msg="as_ref=True modification should propagate to tile") + + def test_as_ref_false_modification_does_not_propagate(self): + """Modifying as_ref=False tensor should NOT change tile weights.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + original = tile.tile.get_weights(as_ref=True)[0, 0].item() + copy = tile.tile.get_weights(as_ref=False) + copy[0, 0] += 999.0 + check = tile.tile.get_weights(as_ref=True) + self.assertAlmostEqual(check[0, 0].item(), original, places=5, + msg="as_ref=False modification should NOT propagate to tile") + + def test_clone_modification_does_not_propagate(self): + """Modifying clone of as_ref=True should NOT change tile weights.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + original = 
tile.tile.get_weights(as_ref=True)[0, 0].item() + clone = tile.tile.get_weights(as_ref=True).clone() + clone[0, 0] += 999.0 + check = tile.tile.get_weights(as_ref=True) + self.assertAlmostEqual(check[0, 0].item(), original, places=5, + msg="clone modification should NOT propagate to tile") + + +class AnalogCtxSharedWeightsZeroCopyTest(ParametrizedTestCase): + """Tests that C++ tiles use zero-copy shared weights with analog_ctx. + + After ``_bind_shared_weights``, the C++ tile's internal weight storage + and ``analog_ctx.data`` share the same memory. ``tile.update()`` and + ``tile.set_weights()`` must be visible through the shared tensor + without any explicit sync call. + """ + + use_cuda = False + + def _get_tile(self, rpu_config): + model = AnalogLinear(4, 6, bias=False, rpu_config=rpu_config) + return next(model.analog_tiles()) + + # -- FloatingPoint (C++ tile) tests ---------------------------------------- + + def test_shared_tensor_exists_for_cpp_tile(self): + """C++ tiles should have a non-None _shared_weight_tensor.""" + tile = self._get_tile(FloatingPointRPUConfig()) + self.assertIsNotNone(tile._shared_weight_tensor, + "C++ tile should have a shared weight tensor") + + def test_shared_tensor_none_for_python_tile(self): + """Pure Python tiles should NOT use _shared_weight_tensor.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + self.assertIsNone(tile._shared_weight_tensor, + "Python tile should not use shared weight tensor") + + def test_ctx_data_shares_memory_with_cpp_tile(self): + """analog_ctx.data and _shared_weight_tensor should share memory.""" + tile = self._get_tile(FloatingPointRPUConfig()) + ctx_ptr = tile.analog_ctx.data.data_ptr() + shared_ptr = tile._shared_weight_tensor.data_ptr() + self.assertEqual(ctx_ptr, shared_ptr, + "analog_ctx.data and shared tensor should have same data_ptr") + + def test_update_reflects_in_ctx_without_sync(self): + """After tile.update(), analog_ctx.data should reflect changes (zero-copy).""" + tile = 
self._get_tile(FloatingPointRPUConfig()) + tile.tile.set_learning_rate(0.1) + + snapshot = tile.analog_ctx.data.detach().clone() + + x = randn(1, 4) + d = randn(1, 6) + tile.tile.update(x, d, False) + + # analog_ctx.data should have changed without explicit sync + self.assertFalse( + allclose(tile.analog_ctx.data.detach(), snapshot), + "analog_ctx.data should change after tile.update() without sync") + + # And it should match get_weights() + w_from_tile = tile.tile.get_weights() + self.assertTrue( + allclose(tile.analog_ctx.data.detach(), w_from_tile), + "analog_ctx.data should match get_weights() after update") + + def test_set_weights_reflects_in_ctx_without_sync(self): + """After tile.set_weights(), analog_ctx.data should reflect changes.""" + tile = self._get_tile(FloatingPointRPUConfig()) + new_w = randn(6, 4) + tile.tile.set_weights(new_w) + + self.assertTrue( + allclose(tile.analog_ctx.data.detach(), new_w), + "analog_ctx.data should match new weights after set_weights()") + + def test_multiple_updates_stay_in_sync(self): + """analog_ctx.data should stay in sync across 10 consecutive updates.""" + tile = self._get_tile(FloatingPointRPUConfig()) + tile.tile.set_learning_rate(0.01) + + for _ in range(10): + x = randn(4, 4) + d = randn(4, 6) + tile.tile.update(x, d, False) + + w_from_tile = tile.tile.get_weights() + self.assertTrue( + allclose(tile.analog_ctx.data.detach(), w_from_tile), + "analog_ctx.data drifted from tile weights during updates") + + def test_get_weights_ref_returns_shared_tensor(self): + """_get_tile_weights_ref should return the shared tensor for C++ tiles.""" + tile = self._get_tile(FloatingPointRPUConfig()) + ref = tile._get_tile_weights_ref() + self.assertEqual(ref.data_ptr(), tile._shared_weight_tensor.data_ptr(), + "_get_tile_weights_ref should return shared tensor") + + # -- as_ref=True after update (the key new test) --------------------------- + + def test_as_ref_true_reflects_update_python_tile(self): + """Python tile: as_ref=True 
weight should reflect tile.update() changes.""" + tile = self._get_tile(TorchInferenceRPUConfig()) + ref = tile.tile.get_weights(as_ref=True) + snapshot = ref.clone() + + # Manually modify via the ref (simulating what update would do) + ref[0, 0] += 100.0 + check = tile.tile.get_weights(as_ref=True) + self.assertAlmostEqual(check[0, 0].item(), snapshot[0, 0].item() + 100.0, places=2, + msg="Python tile: as_ref write should propagate") + + def test_as_ref_true_reflects_update_cpp_tile(self): + """C++ tile: _get_tile_weights_ref should reflect tile.update() changes. + + This is the key test: for C++ tiles, the shared weight tensor + (returned by _get_tile_weights_ref) must automatically reflect + updates performed by the C++ tile.update() — zero-copy. + """ + tile = self._get_tile(FloatingPointRPUConfig()) + tile.tile.set_learning_rate(0.1) + + ref = tile._get_tile_weights_ref() + snapshot = ref.clone() + + x = randn(1, 4) + d = randn(1, 6) + tile.tile.update(x, d, False) + + # ref should have been modified in-place by the C++ update + self.assertFalse( + allclose(ref, snapshot), + "C++ tile: shared weight ref should change after update (zero-copy)") + + # And the ref should match get_weights + w_copy = tile.tile.get_weights() + self.assertTrue( + allclose(ref, w_copy), + "C++ tile: shared weight ref should match get_weights() after update") + + +class AnalogCtxReadOnlyTest(ParametrizedTestCase): + """Tests for ReadOnlyWeightView and the readonly flag on AnalogContext.""" + + use_cuda = False + + def _make_model(self, readonly=True): + """Create an AnalogLinear model with the given readonly setting.""" + rpu_config = TorchInferenceRPUConfig() + rpu_config.mapping.readonly_weights = readonly + return AnalogLinear(4, 6, bias=False, rpu_config=rpu_config) + + def _get_ctx(self, model): + tile = next(model.analog_tiles()) + return tile.analog_ctx + + # -- default behaviour ---------------------------------------------------- + + def test_default_readonly_true(self): + 
"""By default, analog_ctx.data should be a ReadOnlyWeightView.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + self.assertTrue(ctx.readonly) + self.assertIsInstance(ctx.data, ReadOnlyWeightView) + + def test_default_readonly_false_via_config(self): + """Setting readonly_weights=False in config should disable protection.""" + model = self._make_model(readonly=False) + ctx = self._get_ctx(model) + self.assertFalse(ctx.readonly) + self.assertNotIsInstance(ctx.data, ReadOnlyWeightView) + + # -- read operations work transparently ----------------------------------- + + def test_read_ops_work_when_readonly(self): + """size, norm, nonzero, comparisons should all work on readonly data.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + + self.assertEqual(ctx.size(), Size([6, 4])) + self.assertGreater(ctx.norm().item(), 0.0) + self.assertGreater(len(ctx.nonzero()), 0) + mask = ctx > 10 + self.assertEqual(mask.shape, ctx.shape) + + # -- in-place ops blocked when readonly ----------------------------------- + + def test_add_inplace_blocked(self): + """ctx.data.add_() should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with self.assertRaises(RuntimeError): + ctx.data.add_(1.0) + + def test_mul_inplace_blocked(self): + """ctx.data.mul_() should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with self.assertRaises(RuntimeError): + ctx.data.mul_(2.0) + + def test_copy_inplace_blocked(self): + """ctx.data.copy_() should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with self.assertRaises(RuntimeError): + ctx.data.copy_(randn(6, 4)) + + def test_fill_inplace_blocked(self): + """ctx.data.fill_() should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with 
self.assertRaises(RuntimeError): + ctx.data.fill_(0.0) + + def test_zero_inplace_blocked(self): + """ctx.data.zero_() should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with self.assertRaises(RuntimeError): + ctx.data.zero_() + + def test_setitem_blocked(self): + """ctx.data[0, 0] = ... should raise RuntimeError when readonly.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + with self.assertRaises(RuntimeError): + ctx.data[0, 0] = 999.0 + + # -- in-place ops allowed when writable ----------------------------------- + + def test_add_inplace_allowed_when_not_readonly(self): + """ctx.data.add_() should work when readonly=False.""" + model = self._make_model(readonly=False) + ctx = self._get_ctx(model) + ctx.data.add_(1.0) # should not raise + + def test_setitem_allowed_when_not_readonly(self): + """ctx.data[0,0] = ... should work when readonly=False.""" + model = self._make_model(readonly=False) + ctx = self._get_ctx(model) + ctx.data[0, 0] = 999.0 # should not raise + + # -- flag toggling -------------------------------------------------------- + + def test_toggle_readonly_on(self): + """Switching readonly from False to True should wrap data.""" + model = self._make_model(readonly=False) + ctx = self._get_ctx(model) + self.assertNotIsInstance(ctx.data, ReadOnlyWeightView) + + ctx.readonly = True + self.assertIsInstance(ctx.data, ReadOnlyWeightView) + with self.assertRaises(RuntimeError): + ctx.data.add_(1.0) + + def test_toggle_readonly_off(self): + """Switching readonly from True to False should unwrap data.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + self.assertIsInstance(ctx.data, ReadOnlyWeightView) + + ctx.readonly = False + self.assertNotIsInstance(ctx.data, ReadOnlyWeightView) + ctx.data.add_(1.0) # should not raise + + # -- context manager ------------------------------------------------------ + + def 
test_writable_context_manager(self): + """writable() should temporarily allow in-place ops.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + + with ctx.writable(): + self.assertFalse(ctx.readonly) + ctx.data.add_(1.0) # should not raise + + # Readonly restored + self.assertTrue(ctx.readonly) + with self.assertRaises(RuntimeError): + ctx.data.add_(1.0) + + def test_writable_context_manager_restores_on_exception(self): + """writable() should restore readonly even if an exception occurs.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + + try: + with ctx.writable(): + raise ValueError("test exception") + except ValueError: + pass + + self.assertTrue(ctx.readonly) + + # -- data assignment auto-wraps ------------------------------------------- + + def test_data_assignment_auto_wraps(self): + """Assigning to ctx.data should auto-wrap when readonly=True.""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + + ctx.data = randn(6, 4) + self.assertIsInstance(ctx.data, ReadOnlyWeightView) + + def test_data_assignment_no_wrap_when_writable(self): + """Assigning to ctx.data should NOT wrap when readonly=False.""" + model = self._make_model(readonly=False) + ctx = self._get_ctx(model) + + ctx.data = randn(6, 4) + self.assertNotIsInstance(ctx.data, ReadOnlyWeightView) + + # -- set_data respects readonly ------------------------------------------- + + def test_set_data_works_when_readonly(self): + """set_data() should succeed even when readonly (uses assignment).""" + model = self._make_model(readonly=True) + ctx = self._get_ctx(model) + + new_data = randn(6, 4) + ctx.set_data(new_data) + self.assertIsInstance(ctx.data, ReadOnlyWeightView) + self.assertTrue(allclose(ctx.data.detach(), new_data)) + + # -- convert_to_analog readonly parameter --------------------------------- + + def test_convert_to_analog_readonly_override_false(self): + """convert_to_analog(readonly=False) should set all 
ctx.readonly=False.""" + digital_model = Sequential(TorchLinear(8, 4), TorchLinear(4, 2)) + analog_model = convert_to_analog( + digital_model, TorchInferenceRPUConfig(), + ensure_analog_root=False, readonly=False, + ) + for param in analog_model.parameters(): + if isinstance(param, AnalogContext): + self.assertFalse(param.readonly) + self.assertNotIsInstance(param.data, ReadOnlyWeightView) + + def test_convert_to_analog_readonly_override_true(self): + """convert_to_analog(readonly=True) should set all ctx.readonly=True.""" + rpu_config = TorchInferenceRPUConfig() + rpu_config.mapping.readonly_weights = False # config says writable + digital_model = Sequential(TorchLinear(8, 4), TorchLinear(4, 2)) + analog_model = convert_to_analog( + digital_model, rpu_config, + ensure_analog_root=False, readonly=True, + ) + for param in analog_model.parameters(): + if isinstance(param, AnalogContext): + self.assertTrue(param.readonly) + self.assertIsInstance(param.data, ReadOnlyWeightView) + + def test_convert_to_analog_readonly_default_from_config(self): + """convert_to_analog() without readonly uses rpu_config.mapping value.""" + rpu_config = TorchInferenceRPUConfig() + rpu_config.mapping.readonly_weights = False + digital_model = Sequential(TorchLinear(8, 4)) + analog_model = convert_to_analog( + digital_model, rpu_config, ensure_analog_root=False, + ) + for param in analog_model.parameters(): + if isinstance(param, AnalogContext): + self.assertFalse(param.readonly) + + +class SharedWeightsCudaBindingTest(ParametrizedTestCase): + """Tests that shared weight binding works correctly after .cuda(). + + ``_bind_shared_weights()`` provides zero-copy weight access, but originally + only worked on CPU. After ``.cuda()``, the binding was lost because + ``RPUCudaSimulatorTileWrapper.cuda()`` never re-bound it. + + These tests cover: + + - ``_bind_shared_weights()`` must allocate on the correct device + with the correct layout (CUDA uses transposed ``(x_size, d_size)``). 
+ - ``.cuda()`` / ``.cpu()`` must re-bind shared weights for + FloatingPoint-family tiles. + - ``CudaAnalogTile.set_shared_weights()`` breaks the ``is_perfect`` forward path, + so tiles with ``is_perfect=True`` must NOT be bound on CUDA. + - Training must still converge after shared weight binding. + """ + + use_cuda = False # We manually skip inside each test + + def _skip_if_no_cuda(self): + if SKIP_CUDA_TESTS: + raise SkipTest("not compiled with CUDA support") + + # -- FloatingPoint: shared binding survives .cuda() ----------------------- + + def test_shared_tensor_exists_after_cuda_floating_point(self): + """FloatingPoint tile should have shared tensor after .cuda().""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + self.assertIsNotNone(tile._shared_weight_tensor, + "shared tensor should exist after .cuda()") + self.assertTrue(tile._shared_weight_tensor.is_cuda, + "shared tensor should be on CUDA") + + def test_shared_tensor_device_matches_tile(self): + """Shared tensor device should match the CUDA tile device.""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + self.assertEqual(tile._shared_weight_tensor.device.type, "cuda") + + def test_shared_tensor_has_correct_values_after_cuda(self): + """Shared tensor should contain actual weights, not zeros.""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()) + tile_cpu = next(model.analog_tiles()) + w_cpu = tile_cpu.tile.get_weights().clone() + + model = model.cuda() + tile = next(model.analog_tiles()) + + # Shared tensor values should match the original CPU weights. + # The shared tensor is transposed on CUDA, so compare sorted values.
+ shared_vals = tile._shared_weight_tensor.cpu().flatten().sort()[0] + cpu_vals = w_cpu.flatten().sort()[0] + self.assertTrue(allclose(shared_vals, cpu_vals, atol=1e-5), + "shared tensor should contain the original weights") + + def test_update_reflects_in_shared_tensor_cuda(self): + """tile.update() should modify the shared tensor on CUDA (zero-copy).""" + self._skip_if_no_cuda() + model = AnalogLinear(4, 6, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + tile.tile.set_learning_rate(0.1) + + snapshot = tile._shared_weight_tensor.clone() + x = randn(2, 4, device="cuda") + d = randn(2, 6, device="cuda") + tile.tile.update(x, d, False) + + self.assertFalse(allclose(tile._shared_weight_tensor, snapshot), + "shared tensor should change after tile.update()") + + def test_set_weights_reflects_in_shared_tensor_cuda(self): + """tile.set_weights() should sync to the shared tensor on CUDA.""" + self._skip_if_no_cuda() + model = AnalogLinear(4, 6, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + + new_w = randn(6, 4) + tile.tile.set_weights(new_w) + + # Compare values (shared is transposed on CUDA) + shared_vals = tile._shared_weight_tensor.cpu().flatten().sort()[0] + expected_vals = new_w.flatten().sort()[0] + self.assertTrue(allclose(shared_vals, expected_vals, atol=1e-5), + "shared tensor should reflect set_weights()") + + # -- FloatingPoint: training convergence after binding -------------------- + + def test_training_converges_floating_point_cuda(self): + """Training with FloatingPoint on CUDA should converge after binding.""" + self._skip_if_no_cuda() + from torch.nn.functional import mse_loss + from aihwkit.optim import AnalogSGD + + manual_seed(42) + model = AnalogLinear(4, 2, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + x = randn(10, 4, device="cuda") + y = randn(10, 2, device="cuda") + + initial_loss = mse_loss(model(x), y).item() + + opt = 
AnalogSGD(model.parameters(), lr=0.1) + opt.regroup_param_groups(model) + for _ in range(50): + opt.zero_grad() + loss = mse_loss(model(x), y) + loss.backward() + opt.step() + + final_loss = mse_loss(model(x), y).item() + self.assertLess(final_loss, initial_loss, + "training should reduce loss on CUDA FloatingPoint") + + # -- CPU round-trip: .cuda() then .cpu() ---------------------------------- + + def test_shared_tensor_survives_cpu_round_trip(self): + """Shared tensor should be re-bound after .cuda() -> .cpu().""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()) + model = model.cuda() + model = model.cpu() + tile = next(model.analog_tiles()) + + self.assertIsNotNone(tile._shared_weight_tensor, + "shared tensor should exist after round-trip") + self.assertFalse(tile._shared_weight_tensor.is_cuda, + "shared tensor should be on CPU after .cpu()") + + def test_shared_tensor_survives_cuda_cpu_cuda(self): + """Shared tensor should survive CPU -> CUDA -> CPU -> CUDA.""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()) + model = model.cuda().cpu().cuda() + tile = next(model.analog_tiles()) + + self.assertIsNotNone(tile._shared_weight_tensor, + "shared tensor should exist after double round-trip") + self.assertTrue(tile._shared_weight_tensor.is_cuda, + "shared tensor should be on CUDA") + + # -- ConstantStep / Inference: CUDA binding -------------------------------- + # + # CudaAnalogTile.set_shared_weights() works for normal forward, but + # corrupts the is_perfect forward path. 
Verify that: + # - ConstantStep and Inference (is_perfect=False) ARE bound on CUDA + # - Inference with is_perfect=True is NOT bound (known C++ issue) + + def test_shared_binding_for_constant_step_cuda(self): + """ConstantStep tile on CUDA should have shared weight binding.""" + self._skip_if_no_cuda() + rpu = SingleRPUConfig(device=ConstantStepDevice()) + model = AnalogLinear(4, 6, bias=False, rpu_config=rpu).cuda() + tile = next(model.analog_tiles()) + self.assertIsNotNone(tile._shared_weight_tensor, + "ConstantStep CUDA should bind shared weights") + self.assertTrue(tile._shared_weight_tensor.is_cuda) + + def test_shared_binding_for_inference_default_cuda(self): + """Inference (is_perfect=False) on CUDA should have shared binding.""" + self._skip_if_no_cuda() + model = AnalogLinear(4, 6, bias=False, + rpu_config=InferenceRPUConfig()).cuda() + tile = next(model.analog_tiles()) + self.assertIsNotNone(tile._shared_weight_tensor, + "Inference (default) CUDA should bind shared weights") + + def test_no_shared_binding_for_is_perfect_cuda(self): + """Inference with is_perfect=True should NOT bind shared weights.""" + self._skip_if_no_cuda() + rpu = InferenceRPUConfig() + rpu.forward.is_perfect = True + model = AnalogLinear(4, 6, bias=False, rpu_config=rpu).cuda() + tile = next(model.analog_tiles()) + self.assertIsNone(tile._shared_weight_tensor, + "is_perfect=True should skip shared binding") + + def test_inference_is_perfect_forward_nonzero_cuda(self): + """Inference tile with is_perfect=True should produce non-zero output.""" + self._skip_if_no_cuda() + rpu = InferenceRPUConfig() + rpu.forward.is_perfect = True + manual_seed(42) + model = AnalogLinear(4, 2, bias=False, rpu_config=rpu).cuda() + x = randn(1, 4, device="cuda") + out = model(x) + self.assertTrue(out.abs().sum() > 0, + "is_perfect forward should produce non-zero output") + + def test_training_converges_constant_step_cuda(self): + """Training with ConstantStep on CUDA should converge after binding.""" + 
self._skip_if_no_cuda() + from torch.nn.functional import mse_loss + from aihwkit.optim import AnalogSGD + + rpu = SingleRPUConfig(device=ConstantStepDevice()) + manual_seed(42) + model = AnalogLinear(4, 2, bias=False, rpu_config=rpu).cuda() + x = randn(10, 4, device="cuda") + y = randn(10, 2, device="cuda") + + initial_loss = mse_loss(model(x), y).item() + + opt = AnalogSGD(model.parameters(), lr=0.1) + opt.regroup_param_groups(model) + for _ in range(50): + opt.zero_grad() + loss = mse_loss(model(x), y) + loss.backward() + opt.step() + + final_loss = mse_loss(model(x), y).item() + self.assertLess(final_loss, initial_loss, + "training should reduce loss on CUDA ConstantStep") + + def test_training_converges_inference_is_perfect_cuda(self): + """Training Inference+is_perfect on CUDA should converge (no binding).""" + self._skip_if_no_cuda() + from torch.nn.functional import mse_loss + from aihwkit.optim import AnalogSGD + + rpu = InferenceRPUConfig() + rpu.forward.is_perfect = True + manual_seed(4321) + model = AnalogLinear(4, 2, bias=False, rpu_config=rpu).cuda() + x = randn(10, 4, device="cuda") + y = randn(10, 2, device="cuda") + + initial_loss = mse_loss(model(x), y).item() + + opt = AnalogSGD(model.parameters(), lr=0.5) + opt.regroup_param_groups(model) + for _ in range(100): + opt.zero_grad() + loss = mse_loss(model(x), y) + loss.backward() + opt.step() + + final_loss = mse_loss(model(x), y).item() + self.assertLess(final_loss, initial_loss, + "training should reduce loss for Inference CUDA") + + # -- Transposed layout verification --------------------------------------- + + def test_cuda_get_weights_cuda_binding(self): + """get_weights_cuda() binding should return [x_size, d_size] CUDA tensor.""" + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + + self.assertTrue(hasattr(tile.tile, "get_weights_cuda"), + "CudaFloatingPointTile should have get_weights_cuda 
binding") + out = tile.tile.get_weights_cuda() + d = tile.tile.get_d_size() + x = tile.tile.get_x_size() + self.assertTrue(out.is_cuda, "get_weights_cuda() should return CUDA tensor") + self.assertEqual(out.shape[0], x, + f"dim 0 should be x_size={x}") + self.assertEqual(out.shape[1], d, + f"dim 1 should be d_size={d}") + # Values must match get_weights() (which returns [d_size, x_size] on CPU). + w = tile.tile.get_weights() + self.assertTrue(allclose(out.t().cpu(), w), + "get_weights_cuda().t().cpu() should match get_weights()") + + def test_get_tile_weights_ref_returns_transposed_view_for_cuda(self): + """_get_tile_weights_ref should return a CUDA tensor in standard (d_size, x_size) layout. + + CUDA C++ tiles store weights in transposed layout (x_size, d_size). + _get_tile_weights_ref uses get_weights_cuda().t() so callers see the + standard (d_size, x_size) shape without a CPU round-trip. + """ + self._skip_if_no_cuda() + model = AnalogLinear(8, 4, bias=False, + rpu_config=FloatingPointRPUConfig()).cuda() + tile = next(model.analog_tiles()) + + ref = tile._get_tile_weights_ref() + # Should be a CUDA tensor, not a CPU copy. + self.assertTrue(ref.is_cuda, + "_get_tile_weights_ref should return CUDA tensor") + # Shape should be standard (d_size, x_size), not transposed. + w = tile.tile.get_weights() + self.assertEqual(ref.shape, w.shape, + "ref shape should match get_weights() shape") + # Values should match. + self.assertTrue(allclose(ref.cpu(), w), + "_get_tile_weights_ref should match get_weights()") + + # -- Raw tile .to('cuda') ------------------------------------------------- + + def test_raw_tile_to_cuda_shared_binding(self): + """Raw tile (not AnalogLinear) should retain shared binding after .to('cuda'). + + Reproduces the scenario where a tile is constructed directly via + rpu.get_default_tile_module_class() and moved with .to('cuda'). 
+ """ + self._skip_if_no_cuda() + rpu = SingleRPUConfig(device=ConstantStepDevice()) + cls = rpu.get_default_tile_module_class(16, 8) + tile = cls(16, 8, rpu, False) + + # CPU: shared binding from __init__ + self.assertIsNotNone(tile._shared_weight_tensor) + r1 = tile._get_tile_weights_ref() + r2 = tile._get_tile_weights_ref() + self.assertEqual(r1.data_ptr(), r2.data_ptr(), + "CPU: _get_tile_weights_ref should return same ptr") + + # .to('cuda'): shared binding must survive + tile = tile.to("cuda") + self.assertIsNotNone(tile._shared_weight_tensor, + "shared tensor should exist after .to('cuda')") + self.assertTrue(tile._shared_weight_tensor.is_cuda) + r1 = tile._get_tile_weights_ref() + r2 = tile._get_tile_weights_ref() + self.assertEqual(r1.data_ptr(), r2.data_ptr(), + "CUDA: _get_tile_weights_ref should return same ptr") diff --git a/tests/test_layers_linear.py b/tests/test_layers_linear.py index c87f6982..8c4ad88a 100644 --- a/tests/test_layers_linear.py +++ b/tests/test_layers_linear.py @@ -133,8 +133,12 @@ def test_seed(self): weight1, bias1 = layer1.get_weights() weight2, bias2 = layer2.get_weights() + if self.use_cuda: + weight1, weight2 = weight1.cpu(), weight2.cpu() assert_array_almost_equal(weight1, weight2) if bias1 is not None: + if self.use_cuda: + bias1, bias2 = bias1.cpu(), bias2.cpu() assert_array_almost_equal(bias1, bias2) def test_several_analog_layers(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 790e8b36..e1603583 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -108,14 +108,18 @@ def train_model(model, loss_func, x_b, y_b): @staticmethod def get_layer_and_tile_weights(model): """Return the weights and biases of the model and the tile and whether - it automatically syncs""" + it automatically syncs + + Note: All the weights and biases are detached and converted to numpy format.""" if isinstance(model, AnalogLinearMapped): weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), 
bias.detach().cpu().numpy() return weight, bias, weight, bias, True if isinstance(model, AnalogConv2dMapped): weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy() return weight, bias, weight, bias, True if model.weight is not None: @@ -123,6 +127,7 @@ def get_layer_and_tile_weights(model): else: # we do not sync anymore weight, bias = model.get_weights() + weight, bias = weight.detach().cpu().numpy(), bias.detach().cpu().numpy() return weight, bias, weight, bias, True if model.bias is not None: @@ -408,6 +413,7 @@ def test_save_load_model_cross_device(self): self.assertIsInstance(new_analog_tile.analog_ctx.analog_tile, analog_tile.__class__) self.assertTrue(new_analog_tile.is_cuda != analog_tile.is_cuda) + self.assertTrue(new_analog_tile.device.type == map_location) if analog_tile.shared_weights is not None: self.assertTrue(new_analog_tile.shared_weights.device.type == map_location) @@ -879,7 +885,11 @@ def test_load_state_dict_conversion(self): state1 = new_state_dict[key] state2 = state_dict[key] - assert_array_almost_equal(state1["analog_tile_weights"], state2["analog_tile_weights"]) + weights1 = state1["analog_tile_weights"] + weights2 = state2["analog_tile_weights"] + if self.use_cuda: + weights1, weights2 = weights1.cpu(), weights2.cpu() + assert_array_almost_equal(weights1, weights2) # assert_array_almost_equal(state1['analog_alpha_scale'], # state2['analog_alpha_scale'])