Skip to content

define data attribution for AnalogContext#717

Closed
Zhaoxian-Wu wants to merge 10 commits intoIBM:masterfrom
Zhaoxian-Wu:master
Closed

define data attribution for AnalogContext#717
Zhaoxian-Wu wants to merge 10 commits intoIBM:masterfrom
Zhaoxian-Wu:master

Conversation

@Zhaoxian-Wu
Copy link
Copy Markdown

@Zhaoxian-Wu Zhaoxian-Wu commented Mar 12, 2025

Related issues

NA

Description

In the current version, the data attribution is replaced by a dummy tensor, leading to confusiing behavior (related code). For example, when accessing the size of analog context (which is a subclass of Pytorch Parameter and Tensor), the output is an empty size, while we expect it outputs the correct size.

from aihwkit.nn import AnalogLinear

model = AnalogLinear(6, 4, False)

analog_ctx = next(model.parameters())
print(f'analog_ctx.size() = {analog_ctx.size()}')
# ============ output ============ 
# analog_ctx.size() = torch.Size([])
# ============ expected output ============ 
# analog_ctx.size() = torch.Size([4, 6])

Details

To ensure all related components work correctly, I adapt a serious of code.
Furthermore, I adopt some conventions to make the style uniformly across the library and try to follow the PyTorch style.

1. AnalogContext is a valid torch.nn.Parameter

To support the correct behavior, I bind the weight matrix from the binding analog tile by using self.data = analog_tile.tile.get_weights().

I understand that this feature introduces an additional degree of freedom for programmers and it is different from the real weights reading mechanism. So I add some comments to encourage the users to read and write the weights in "realistic" way in the comments.

2. [convention] get_weights and set_weights inplace

In the TileModuleArray class, we adapt the convention that all the reading and writing are done in-place.

In get_weights, we don't convert the Tensor back to np.Array automatically or move back to cpu, i.e., we do

return self.weight.data.detach().cpu()

instead of

return weight.data

Since the movement across devices is supposed to done explicitly by user to avoid confusion.

In set_weights, we do

self.weight.data = self.weight.data.copy_(weight)

instead of

self.weight.data = weight.clone().to(self.weight.device)

3. [convention] Run-time computing of device and is_cuda

For a analog tile class (any subclass of SimulatorTileWrapper), the device and is_cuda is in essence the one of its analog_ctx attribute.
In case someone replaces the tile in analog_ctx without remembering to update them, I introduce two properties function to ensure the consistency.

@property
def device(self) -> torch_device:
    """Return the device of the tile."""
    return self.analog_ctx.device

@property
def is_cuda(self) -> bool:
    """Return whether the tile is on CUDA."""
    return self.analog_ctx.is_cuda

Let me know if I miss anything from the library design perspective.

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the master branch 2 times, most recently from 2ef4d52 to 586450f Compare March 18, 2025 18:44
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
Signed-off-by: Zhaoxian-Wu <wuzhx23@mail2.sysu.edu.cn>
@PabloCarmona PabloCarmona added enhancement New feature or request design Related to improvement in the code design labels Mar 21, 2025
@PabloCarmona
Copy link
Copy Markdown
Collaborator

@Zhaoxian-Wu can you please update and sync the branch with master? We are gonna proceed with the review of it. Thank you for the enhancement!

@anu-pub anu-pub requested a review from maljoras March 21, 2025 15:35
@Zhaoxian-Wu
Copy link
Copy Markdown
Author

Hi @PabloCarmona , I sync the latest update, feel free to let any comment or discussion

Copy link
Copy Markdown
Contributor

@maljoras-sony maljoras-sony left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks for the contribution, @Zhaoxian-Wu !

However, I don't quite see the need for these low-level changes and I am not convinced that all parts of the code base will actually still work. Considering all the various cases for the RPUCuda part of the code is a bit tricky, since the weight is stored in the C++ part and only a copy is returned to the user when get_weights is called. This in part is by design, as we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality. While some tiles will work with your formulation, the ctx.data will be out-of-sync for other and this would be more confusing. The user should not be changing anything related to analog_ctx.

If you just want to make the size of the parameter of analog_ctx reflecting the shape of the analog time, why not simply add a "shape" property or custom size method to the analog context that returns the weight size rather than storing the full weights, which will then become out-of-sync with the actual trained weights?


# Recreate the tile.
self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
self.analog_ctx.data = self.tile.get_weights()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic here? Note that this will only the a copy of the current weights. So if you update the weights (using RPUCuda) the analog_ctx.data will not be synchronized correctly with the actual weight. Of course the size of the weight will not change, but it will be more confusing of one maintains two different version of the weight which are not synced, or not?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, there could be an out-of-sync concern here. Therefore, I also change the definition of self.tile.get_weights(). So far, the tile will return an original weight instead of a detached tensor here. Since the data here and the actual weight tenser are the same object in essence, there is no sync issue here.

def get_weights(self) -> Tensor:
        """Get the tile weights.
            matrix; and the second item is either the ``[out_size]`` bias vector
            or ``None`` if the tile is set not to use bias.
        """
        return self.weight.data

bias = from_numpy(array(bias))

self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device)
self.bias.data.copy_(bias)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this allow for setting the bias with the data type defined in the tile? While it is correct that torch defines the data type of a layer solely by the data type of the weight tensor, I found it more convenient to handle all the specialized code to have a dtype property on tile level (as this is essentially the "analog tensor" ). Do you suggest that the d_type should be removed from the tile, but now determined by the ctx.data tensor dtype?

or ``None`` if the tile is set not to use bias.
"""
return self.weight.data.detach().cpu()
return self.weight.data
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the current convention is that get_weights always returns CPU weights. If you want to change this, the RPUCuda get_weights call need to change as well, as they producing CPU weights by default. Moreover, get_weights will always product a copy (without the backward trace) by design, to avoid implicit things that cannot be done with analog weights. Of course, hardware aware training is a special case, but for that we have a separate tile.

@Zhaoxian-Wu
Copy link
Copy Markdown
Author

Hi @maljoras-sony , thanks for your careful review and comments.


If you just want to make the size of the parameter of analog_ctx reflecting the shape of the analog time, why not simply add a "shape" property or custom size method to the analog context that returns the weight size rather than storing the full weights, which will then become out-of-sync with the actual trained weights?

The motivation for the enhancement mainly comes from the need for convenient weight access, but the dummy tensor in the AnalogContext incurs some conceptual confusion.

https://github.com/IBM/aihwkit/blob/62522c4c560ce30c5ecfed270a84e23101d7a664/src/aihwkit/optim/context.py#L24C1-L37C25

It may be a solution to define a series of attributions like shape to get the parameter size, but we also have various attributions related to the weights.

One example is the issue analog_ctx.size() = torch.Size([]). On top of that, basically, all the reading operations with respect to the weight have similar confusion:

from aihwkit.nn import AnalogLinear
model = AnalogLinear(6, 4, False)
analog_ctx = next(model.parameters())

print(analog_ctx.size())             # torch.Size([])
print(analog_ctx.nonzero())      # tensor([], size=(1, 0), dtype=torch.int64)
print(analog_ctx > 10)             # false, expected a boolean array
print(analog_ctx.norm())          # tensor(1., grad_fn=<LinalgVectorNormBackward0>)

Therefore, I believe it will be a straightforward way to adopt a pytorch-style convention. I.e., a Parameter object is a valid Tensor objective automatically. All the property-related methods, like norm(), sum(), or even the boolean operators like >, nonzero(), come from the Tensor class. Therefore, I bind the AnalogContext.data with the real underlined weight.


However, I don't quite see the need for these low-level changes and I am not convinced that all parts of the code base will actually still work. Considering all the various cases for the RPUCuda part of the code is a bit tricky, since the weight is stored in the C++ part and only a copy is returned to the user when get_weights is called. This in part is by design, as we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality. While some tiles will work with your formulation, the ctx.data will be out-of-sync for other and this would be more confusing. The user should not be changing anything related to analog_ctx.

I completely understand your concern. Therefore, I use a relatively safe approach to do that. In the AnalogContext initialization, I bind the data with its underlying weight

self.analog_ctx.data = self.tile.get_weights()

Note that I also make sure the self.tile.get_weights() returns a reference instead of a detached tensor

def get_weights(self) -> Tensor:
        """Get the tile weights.
            matrix; and the second item is either the ``[out_size]`` bias vector
            or ``None`` if the tile is set not to use bias.
        """
        return self.weight.data

By this way, even though the weight is updated, the weight can remind synchronous with the underlying weights.

@Zhaoxian-Wu
Copy link
Copy Markdown
Author

Zhaoxian-Wu commented Mar 24, 2025

I found there are still issues with my solution so far in the RPUCuda part, as you pointed out correctly. I may need some time to fix it. I will also consider your concern that the user should not have access to the actual weight unless necessary. However, I still want to retain the flexibility of accessing weight by AnalogContext. My preliminary idea is to

  1. introduce a flag to control whether the user can access the actual weight directly. Users turn on the flag only if they are clear that the access is not a physically available process.
  2. I will ensure the weight in AnalogContext always synchronizes with the actual weight.
  3. The weights returned by the C++ code is always a reference of the actual weights, no matter any movement between devices or update
    How do you feel about the plan? @maljoras-sony

@PabloCarmona
Copy link
Copy Markdown
Collaborator

Hello @Zhaoxian-Wu, we are getting back to this discussion and wanted to see if you are able or eager to take this back. Sorry for not reaching out promptly, and again, thank you for your interest in the aihwkit!

Also, @maljoras-sony @maljoras could you be interested in taking this back, taking a look, and continuing the discussion where it was left? Thanks!

1 similar comment
@PabloCarmona
Copy link
Copy Markdown
Collaborator

Hello @Zhaoxian-Wu, we are getting back to this discussion and wanted to see if you are able or eager to take this back. Sorry for not reaching out promptly, and again, thank you for your interest in the aihwkit!

Also, @maljoras-sony @maljoras could you be interested in taking this back, taking a look, and continuing the discussion where it was left? Thanks!

@Zhaoxian-Wu
Copy link
Copy Markdown
Author

Hi @PabloCarmona, I am definitely still interested in this issue. Last time I stopped at some unsolved corner cases, and haven't had time to fix them.

@PabloCarmona
Copy link
Copy Markdown
Collaborator

PabloCarmona commented Feb 23, 2026

That sounds great @Zhaoxian-Wu. We will try to make this possible. Please let us know if any help is needed or guidance, and we will see what we can do. Thanks for continuing your interest in this after such a long time!

Zhaoxian-Wu added a commit to Zhaoxian-Wu/aihwkit that referenced this pull request Mar 25, 2026
…nalogCtx tests

- Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of
  self.bias.data (Tensor), which caused TypeError when Conv layers re-assign
  bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution:
  correct shape, norm, nonzero, comparison ops, CUDA support, backward
  compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@Zhaoxian-Wu
Copy link
Copy Markdown
Author

Hi @maljoras-sony and @PabloCarmona, thanks for your patience — I know it's been a while since the original review, and I really appreciate you coming back to this discussion. I've taken the time to carefully address all the concerns raised.

@maljoras-sony, your review raised three key concerns, and this update addresses all of them. Feel free to let me know if there remains any concerns.

Concern 1: Out-of-sync weights for C++ tiles

"Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda) the analog_ctx.data will not be synchronized correctly with the actual weight."

Root cause: C++ tiles store weights in C++ memory. get_weights() can only return a copy — any as_ref=True approach that works for Python tiles (previous commit) does not apply here.

Solution — _bind_shared_weights(): At tile construction time, we allocate a torch.Tensor on the Python side and pass it to the C++ tile via set_shared_weights(). After this call, both Python and C++ operate on the same memory:

Python:  analog_ctx.data ──→ _shared_weight_tensor ←── C++ tile internal storage
                              (same data_ptr)
  • tile.update() / tile.set_weights() modify the tensor in-place — no explicit sync needed during normal training
  • Device moves (cpu() / cuda()) invalidate the old pointer and rebind via _bind_shared_weights()
  • __getstate__ / __setstate__ skip the shared tensor (rebuilt on load)

Concern 2: Unintended weight modification

"we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality."

Solution — ReadOnlyWeightView: A torch.Tensor subclass that blocks all in-place mutations. Instead of a fragile blocklist, we use PyTorch's own naming convention — all in-place ops end with _ — so the guard is simply func_name.endswith('_'). This automatically covers any new ops PyTorch adds in the future.

Three levels of control (all optional, default is read-only):

Level API Use case
Per-layer rpu_config.mapping.readonly_weights Fine-grained via specific_rpu_config_fun
Global convert_to_analog(readonly=False) Quick toggle for research
Runtime ctx.writable() context manager Temporary access in a code block
ctx = next(model.parameters())
ctx.data.norm()        # ✅ reads work
ctx.data.add_(1.0)     # ❌ RuntimeError

with ctx.writable():
    ctx.data.add_(1.0) # ✅ explicit opt-in

Concern 3: Breaking the get_weights convention

"the current convention is that get_weights always returns CPU weights... Moreover, get_weights will always produce a copy (without the backward trace) by design, to avoid implicit things that cannot be done with analog weights."

We preserve this convention fully. The public get_weights() API still returns a detached CPU copy by default (as_ref=False). The as_ref=True path is strictly internal — used only by _get_tile_weights_ref() and _bind_shared_weights() to set up the shared storage between analog_ctx.data and the tile. Users calling tile.get_weights() see no behavior change.

# Public API — unchanged, returns detached CPU copy
w, b = analog_tile.get_weights()     # as_ref=False (default)
w.add_(1.0)                          # safe — this is a copy, tile is unaffected

# Internal only — used to bind analog_ctx.data
ref = tile.tile.get_weights(as_ref=True)  # direct reference to tile storage

Test results

  • test_analog_ctx: 123 passed, 27 skipped (pre-existing)
  • Full suite: 3763 passed, 0 regressions

Note: I observed a few pre-existing test failures that also exist on the master branch and are not introduced by this PR (e.g., Conv3dCuda FloatingPoint::test_torch_original_layer and some RNN layer training tests with ~1e-4 numerical mismatches). I suspect these may be related to my GPU/CUDA setup and I may open a separate issue or PR to address them later. My environment: NVIDIA RTX PRO 6000 Blackwell (driver 580.95), CUDA toolkit 12.0, PyTorch 2.8.0+cu128, cuDNN 91002.


@PabloCarmona, could you take a look at these changes when you have a chance? I'd really appreciate any comments or suggestions. If the direction looks good, it would be great to move toward merging — happy to make any further adjustments needed. Thanks!

@Zhaoxian-Wu
Copy link
Copy Markdown
Author

I have created a new PR (#765), since the original PR was developed on the fork's master branch, which made it difficult to keep clean and rebase. It can be closed for clear purpose.

Zhaoxian-Wu added a commit to Zhaoxian-Wu/aihwkit that referenced this pull request Mar 29, 2026
…nalogCtx tests

- Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of
  self.bias.data (Tensor), which caused TypeError when Conv layers re-assign
  bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution:
  correct shape, norm, nonzero, comparison ops, CUDA support, backward
  compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@PabloCarmona
Copy link
Copy Markdown
Collaborator

PabloCarmona commented Mar 30, 2026

We close this one since it is out of sync with master and makes it difficult to merge right now, and the author @Zhaoxian-Wu has moved it to a new branch that is clearer and up to date with the current master here: #765

Zhaoxian-Wu added a commit to Zhaoxian-Wu/aihwkit that referenced this pull request Apr 1, 2026
…nalogCtx tests

- Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of
  self.bias.data (Tensor), which caused TypeError when Conv layers re-assign
  bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution:
  correct shape, norm, nonzero, comparison ops, CUDA support, backward
  compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

design Related to improvement in the code design enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants