Describe the bug
When using the attention backends feature, it is possible to pass an attention mask through dispatch_attention_fn to a backend that does not support masks (i.e. attn_mask is not in the signature of the function registered for that backend), and no error is raised.
This can cause unintended side effects during training or inference, because the mask is silently dropped and attention is computed as full attention.
This is not validated either as part of the optional checks used for debugging purposes (DIFFUSERS_ATTN_CHECKS=yes).
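For illustration, here is a minimal sketch of the kind of guard that could catch this, assuming the dispatcher has access to the callable registered for the active backend; the helper name _require_mask_support and the way the backend callable is obtained are hypothetical, not the library's actual API.
import inspect

def _require_mask_support(backend_fn, attn_mask):
    # Hypothetical guard: if a mask is supplied but the registered backend
    # callable has no attn_mask parameter, fail loudly instead of silently
    # computing full attention.
    if attn_mask is None:
        return
    params = inspect.signature(backend_fn).parameters
    if "attn_mask" not in params:
        raise ValueError(
            f"attn_mask was provided, but the backend function "
            f"'{backend_fn.__name__}' does not accept an attention mask; "
            "it would be silently ignored."
        )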
Reproduction
import os
os.environ["DIFFUSERS_ATTN_CHECKS"] = "yes"
import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend
query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
attn_mask = torch.zeros((1, 1, 10, 10), dtype=torch.bool, device="cuda")
with attention_backend("native"):
output = dispatch_attention_fn(query, key, value, attn_mask)
output2 = dispatch_attention_fn(query, key, value)
assert not torch.equal(output, output2), "native: These outputs should not be equal!"
with attention_backend("flash"):
output = dispatch_attention_fn(query, key, value, attn_mask)
output2 = dispatch_attention_fn(query, key, value)
assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
Traceback (most recent call last):
File "C:\repro.py", line 22, in <module>
assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: flash: These outputs should not be equal!
System Info
- 🤗 Diffusers version: 0.35.2
- Platform: Windows-11-10.0.26100-SP0
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.1.2
- Transformers version: not installed
- Accelerate version: not installed
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
No response