16 changes: 16 additions & 0 deletions src/aihwkit/nn/conversion.py
@@ -16,6 +16,7 @@
from torch.nn import Module, Linear, Conv1d, Conv2d, Conv3d, Sequential

from aihwkit.exceptions import ArgumentError
from aihwkit.optim.context import AnalogContext
from aihwkit.simulator.tiles.module import TileModule
from aihwkit.nn.modules.container import AnalogWrapper
from aihwkit.nn.modules.base import AnalogLayerBase
@@ -88,6 +89,7 @@ def convert_to_analog(
exclude_modules: Optional[List[str]] = None,
inplace: bool = False,
verbose: bool = False,
readonly: Optional[bool] = None,
) -> Module:
"""Convert a given digital model to its analog counterpart.

@@ -138,6 +140,13 @@

verbose: Increase verbosity. Will print converted layers.

readonly: If not ``None``, override the ``readonly`` flag on
every :class:`~aihwkit.optim.context.AnalogContext` after
conversion. When ``True``, in-place weight modifications
are blocked; when ``False``, they are allowed. If ``None``
(default), each tile uses the value from
``rpu_config.mapping.readonly_weights``.

Returns:
Module where all the digital layers are replaced with analog
mapped layers.
@@ -193,6 +202,7 @@ def convert_to_analog(
exclude_modules,
True,
verbose,
readonly,
)
continue

@@ -211,6 +221,12 @@
if ensure_analog_root and not module_name and not isinstance(module, AnalogLayerBase):
module = AnalogWrapper(module)

# Apply global readonly override if specified
if readonly is not None and not module_name:
for param in module.parameters():
if isinstance(param, AnalogContext):
param.readonly = readonly

return module


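To illustrate the new parameter, here is a minimal sketch of the intended usage, assuming a stock `InferenceRPUConfig`; only the `readonly` keyword itself is added by this diff:

from torch.nn import Linear, Sequential

from aihwkit.nn.conversion import convert_to_analog
from aihwkit.simulator.configs import InferenceRPUConfig

model = Sequential(Linear(4, 3), Linear(3, 1))

# Default (readonly=None): each tile follows rpu_config.mapping.readonly_weights.
analog_model = convert_to_analog(model, InferenceRPUConfig())

# Global override: unlock direct weight access on every AnalogContext at once.
writable_model = convert_to_analog(model, InferenceRPUConfig(), readonly=False)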
26 changes: 26 additions & 0 deletions src/aihwkit/nn/modules/base.py
@@ -7,10 +7,12 @@
"""Base class for adding functionality to analog layers."""
from typing import Any, List, Optional, Tuple, NamedTuple, Union, Generator, Callable, TYPE_CHECKING
from collections import OrderedDict
import warnings

from torch import Tensor
from torch.nn import Parameter
from torch import device as torch_device
from torch import Size as torch_size

from aihwkit.exceptions import ModuleError
from aihwkit.simulator.tiles.module import TileModule
@@ -244,6 +246,30 @@ class type when setting ``load_rpu_config`` to
``load_rpu_config=False``.

"""
keys_to_delete = []
for name, param in list(state_dict.items()):
if name.endswith("analog_ctx") and param.size() == torch_size([]):
keys_to_delete.append(name)

# In checkpoints saved by aihwkit before version 0.9.2, parameters of
# class `AnalogContext` are saved as empty-size tensors, since
# `AnalogContext`, which is a subclass of `torch.nn.Parameter`, uses an
# empty Parameter tensor to store the context.
if len(keys_to_delete) > 0:
strict = False
for key in keys_to_delete:
del state_dict[key]
warnings.warn(
"Some parameters in the loaded checkpoint have empty size "
"(param.size() == torch.Size([])). This typically happens when "
"the checkpoint was generated by an older version of aihwkit. "
"These parameters are skipped for compatibility reasons and the "
"loading mode is set to non-strict. It is recommended to re-save "
"the checkpoint with the latest version of aihwkit. "
"Related parameters: {}".format(keys_to_delete)
)

for analog_tile in self.analog_tiles():
analog_tile.set_load_rpu_config_state(load_rpu_config, strict_rpu_config_check)
return super().load_state_dict(state_dict, strict) # type: ignore
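As a sketch of the failure mode this guards against (the checkpoint path is hypothetical): loading a pre-0.9.2 checkpoint now warns, drops the empty-size context entries, and falls back to non-strict loading instead of raising a size mismatch:

import torch
from torch.nn import Linear

from aihwkit.nn.conversion import convert_to_analog
from aihwkit.simulator.configs import InferenceRPUConfig

analog_model = convert_to_analog(Linear(4, 2), InferenceRPUConfig())

# Checkpoints written by aihwkit < 0.9.2 store each AnalogContext as an
# empty-size tensor (param.size() == torch.Size([])).
state_dict = torch.load("old_checkpoint.pth")  # hypothetical old checkpoint
analog_model.load_state_dict(state_dict)  # warns and loads non-strictly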
4 changes: 4 additions & 0 deletions src/aihwkit/nn/modules/linear.py
@@ -82,6 +82,10 @@ def reset_parameters(self) -> None:
self.weight, self.bias = self.get_weights() # type: ignore
super().reset_parameters()
self.set_weights(self.weight, self.bias) # type: ignore
# AnalogLinear does not support accessing `weight` and `bias` directly,
# so remove the dense attributes again.
del self.weight, self.bias
# Deleting them first is necessary since torch forbids assigning `bias`
# (a bool) while `self.bias` is still registered as a tensor.
self.weight, self.bias = None, bias

def forward(self, x_input: Tensor) -> Tensor:
116 changes: 109 additions & 7 deletions src/aihwkit/optim/context.py
@@ -8,18 +8,49 @@

# pylint: disable=attribute-defined-outside-init

from contextlib import contextmanager
from typing import Optional, Type, Union, Any, TYPE_CHECKING

from torch import ones, dtype, Tensor, no_grad
from torch import dtype, Tensor, no_grad
from torch.nn import Parameter
from torch import device as torch_device

from aihwkit.optim.weight_view import ReadOnlyWeightView

if TYPE_CHECKING:
from aihwkit.simulator.tiles.base import SimulatorTileWrapper


class AnalogContext(Parameter):
"""Context for analog optimizer."""
"""Context for analog optimizer.

Note: the ``data`` attribute, inherited from ``torch.nn.Parameter``, is a
tensor of trainable parameters. If ``analog_bias`` (which is provided by
``analog_tile``) is ``False``, ``data`` has the same meaning as in
``torch.nn.Parameter``; if ``analog_bias`` is ``True``, the last column
of ``data`` is the ``bias`` term.

Even though this allows access to the weights directly, keep in mind that
it is intended for study purposes only. To simulate a realistic read,
call the ``read_weights`` method instead, i.e. given
``analog_ctx: AnalogContext``::

    estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()

Similarly, even though this feature allows updating the weights directly,
keep in mind that real RPU devices change their weights only through
pulsed updates. Therefore, use the following update methods instead of
writing to ``data`` directly in the analog optimizer::

    analog_ctx.analog_tile.update(...)
    analog_ctx.analog_tile.update_indexed(...)

The ``readonly`` flag (default ``True``) causes ``.data`` reads to
return a :class:`~aihwkit.optim.weight_view.ReadOnlyWeightView`
that blocks in-place mutations. Toggle it via the property or the
:meth:`writable` context manager.
"""

def __new__(
cls: Type["AnalogContext"],
@@ -28,33 +59,104 @@ def __new__(
) -> "AnalogContext":
# pylint: disable=signature-differs
if parameter is None:
weights_ref = analog_tile._get_tile_weights_ref()
return Parameter.__new__(
cls,
data=ones((), device=analog_tile.device, dtype=analog_tile.get_dtype()),
data=weights_ref,
requires_grad=True,
)
# analog_tile.tile can come from different classes:
#   aihwkit.simulator.rpu_base.devices.AnalogTile (C++)
#   TorchInferenceTile (Python)
# It stores the "weight" matrix; if analog_tile.analog_bias is True,
# it also stores the "bias" term.

parameter.__class__ = cls
return parameter

def __init__(
self, analog_tile: "SimulatorTileWrapper", parameter: Optional[Parameter] = None
): # pylint: disable=unused-argument
super().__init__()
self._readonly = self._default_readonly(analog_tile)
self.analog_tile = analog_tile
self.use_torch_update = False
self.use_indexed = False
self.analog_input = [] # type: list
self.analog_grad_output = [] # type: list
self.reset(analog_tile)

# -- readonly flag --------------------------------------------------------

@staticmethod
def _default_readonly(analog_tile: "SimulatorTileWrapper") -> bool:
"""Read the default ``readonly`` setting from ``rpu_config.mapping``."""
rpu_config = getattr(analog_tile, "rpu_config", None)
if rpu_config is not None:
mapping = getattr(rpu_config, "mapping", None)
if mapping is not None:
return getattr(mapping, "readonly_weights", True)
return True

@property
def readonly(self) -> bool:
"""Whether in-place modifications on ``data`` are blocked."""
try:
return object.__getattribute__(self, "_readonly")
except AttributeError:
return True

@readonly.setter
def readonly(self, value: bool) -> None:
self._readonly = value

def __getattribute__(self, name: str) -> Any:
"""Intercept ``.data`` reads: return a :class:`ReadOnlyWeightView`
when ``readonly`` is ``True``, otherwise the raw tensor."""
if name == "data":
raw = super().__getattribute__(name)
try:
readonly = object.__getattribute__(self, "_readonly")
except AttributeError:
return raw
if readonly:
return ReadOnlyWeightView(raw)
return raw
return super().__getattribute__(name)

@contextmanager
def writable(self):
"""Context manager that temporarily allows direct weight modification.

Example::

with analog_ctx.writable():
analog_ctx.data.add_(delta)
# readonly is restored automatically
"""
old = self.readonly
self.readonly = False
try:
yield self
finally:
self.readonly = old

# -- existing API ---------------------------------------------------------

def set_indexed(self, value: bool = True) -> None:
"""Set the context to forward_indexed."""
self.use_indexed = value

def set_data(self, data: Tensor) -> None:
"""Set the data value of the Tensor."""
with no_grad():
self.data.copy_(data)
# Unwrap source if it is a ReadOnlyWeightView so that
# copy_() does not trigger the in-place guard.
if isinstance(data, ReadOnlyWeightView):
data = data.as_writable()
# Access raw data directly (bypassing readonly wrap) to
# preserve storage sharing with the tile weight tensor.
super().__getattribute__("data").copy_(data)

def get_data(self) -> Tensor:
"""Get the data value of the underlying Tensor."""
@@ -77,11 +179,11 @@ def has_gradient(self) -> bool:
def __copy__(self) -> Parameter:
"""Turn off copying of the pointers. Context will be re-created
when tile is created"""
return Parameter(self.data)
return Parameter(super().__getattribute__("data"))

def __deepcopy__(self, memo: Any) -> Parameter:
"""Turn off deep copying. Context will be re-created when tile is created"""
return Parameter(self.data)
return Parameter(super().__getattribute__("data"))

def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "AnalogContext":
"""Move the context to a cuda device.
@@ -92,8 +194,8 @@ def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "AnalogContext":
Returns:
This context in the specified device.
"""
self.data = self.data.cuda(device) # type: Tensor
if not self.analog_tile.is_cuda:
self.data = self.analog_tile._get_tile_weights_ref() # type: Tensor
self.analog_tile = self.analog_tile.cuda(device)
self.reset(self.analog_tile)
return self
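Putting the pieces together, a small sketch of the intended behavior, assuming the tile exposes its context as `analog_ctx` as elsewhere in aihwkit:

from aihwkit.nn import AnalogLinear
from aihwkit.simulator.configs import SingleRPUConfig

layer = AnalogLinear(4, 3, rpu_config=SingleRPUConfig())
analog_ctx = next(layer.analog_tiles()).analog_ctx

weights = analog_ctx.data         # ReadOnlyWeightView sharing tile storage
print(weights.abs().sum())        # reads work transparently

try:
    weights.add_(1.0)             # in-place write on analog weights: blocked
except RuntimeError as exc:
    print(exc)

with analog_ctx.writable():       # temporary unlock, restored on exit
    analog_ctx.data.mul_(0.5)

analog_ctx.readonly = False       # persistent unlock for research use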
71 changes: 71 additions & 0 deletions src/aihwkit/optim/weight_view.py
@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Read-only tensor view for analog tile weights."""

from torch import Tensor
from torch.utils._pytree import tree_map


class ReadOnlyWeightView(Tensor):
"""A tensor that shares storage with tile weights but blocks in-place mutations.

All read operations (``size``, ``norm``, ``sum``, indexing, comparisons, etc.)
work transparently because this IS a real tensor sharing the same memory.
In-place write operations raise ``RuntimeError`` with guidance to use the
correct analog tile API.

This class is stateless — it always blocks in-place operations. The policy
of whether to wrap or unwrap is managed by :class:`AnalogContext` via its
``readonly`` flag.
"""

@staticmethod
def __new__(cls, data: Tensor) -> "ReadOnlyWeightView":
"""Create a ReadOnlyWeightView sharing storage with ``data``."""
if isinstance(data, ReadOnlyWeightView):
return data
return Tensor._make_subclass(cls, data)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
func_name = getattr(func, "__name__", "")

# PyTorch convention: in-place ops end with single '_' (add_, mul_, ...)
# Dunder methods (__repr__, __eq__, ...) end with '__' and must pass through
if func_name.endswith("_") and not func_name.endswith("__"):
raise RuntimeError(
f"In-place operation '{func_name}' is not allowed on analog weights. "
f"Analog weights cannot be modified directly — this would bypass "
f"the physical constraints of the analog device.\n"
f" - For programmatic writes: analog_tile.set_weights(new_weight)\n"
f" - For gradient updates: analog_tile.update(x_input, d_input)\n"
f" - To unlock direct access: analog_ctx.readonly = False"
)

# Unwrap to plain Tensor so downstream ops don't propagate our subclass
def unwrap(t):
return t.as_subclass(Tensor) if isinstance(t, ReadOnlyWeightView) else t

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
return func(*args, **kwargs)

def __setitem__(self, key, value):
"""Block item assignment (e.g., ``ctx.data[0, 0] = 999``)."""
raise RuntimeError(
"Direct item assignment on analog weights is not allowed. "
"Use analog_tile.set_weights() instead, "
"or set analog_ctx.readonly = False to unlock direct access."
)

def as_writable(self) -> Tensor:
"""Return the underlying plain Tensor (for internal tile use only).

This removes the read-only guard. Only tile internals should call this.
"""
return self.as_subclass(Tensor)
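The subclass mechanics can be exercised standalone with plain tensors; a minimal sketch:

import torch

from aihwkit.optim.weight_view import ReadOnlyWeightView

base = torch.ones(2, 2)
view = ReadOnlyWeightView(base)

assert view.size() == torch.Size([2, 2])  # reads pass through
assert (view + 1).sum().item() == 8.0     # out-of-place ops return plain Tensors

try:
    view.mul_(2)                  # trailing '_' marks an in-place op: blocked
except RuntimeError:
    pass

view.as_writable().mul_(2)        # unwrapping shares storage with `base`
assert base[0, 0].item() == 2.0

Gating on the trailing-underscore convention keeps the guard independent of any explicit op list, at the cost of also catching any custom torch function whose name ends in a single underscore.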
15 changes: 15 additions & 0 deletions src/aihwkit/simulator/parameters/mapping.py
@@ -100,6 +100,20 @@ class MappingParameter(_PrintableMixin):
:class:`aihwkit.nn.modules.linear_mapped.AnalogLinearMapped`.
"""

readonly_weights: bool = True
"""Whether the analog context data is read-only.

When ``True`` (default), ``analog_ctx.data`` is wrapped in a
:class:`~aihwkit.optim.weight_view.ReadOnlyWeightView` that blocks
in-place mutations (e.g. ``add_``, ``copy_``, ``__setitem__``).
This prevents users from accidentally modifying analog weights in
a way that bypasses the physical constraints of the analog device.

Set to ``False`` to allow direct weight access for research or
hardware-aware training scenarios. The flag can also be toggled at
runtime via ``analog_ctx.readonly``.
"""

def compatible_with(self, mapping: "MappingParameter") -> bool:
"""Checks compatiblity

@@ -117,6 +131,7 @@ def compatible_with(self, mapping: "MappingParameter") -> bool:
"weight_scaling_omega",
"weight_scaling_columnwise",
"weight_scaling_lr_compensation",
"readonly_weights",
]:
continue

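Finally, the config-level switch, sketched with a standard `SingleRPUConfig`:

from aihwkit.simulator.configs import SingleRPUConfig

rpu_config = SingleRPUConfig()
rpu_config.mapping.readonly_weights = False  # opt out at construction time

Note that `readonly_weights` is deliberately skipped in `compatible_with()`, so toggling it never makes two otherwise identical mappings incompatible.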