
dispatch_attention_fn silently ignores attn_mask for certain backends #12605

@zzlol63

Description

Describe the bug

When using the attention backends feature, it is possible to pass an attention mask through dispatch_attention_fn to a backend that does not support one (i.e. attn_mask is not in the signature of the function registered for that backend), and no error is thrown.

This can cause unintended side effects during training or inference, since the mask is silently dropped and attention is computed as full attention.

This is not validated either, even as part of the optional checks used for debugging purposes (DIFFUSERS_ATTN_CHECKS=yes).

Reproduction

import os
os.environ["DIFFUSERS_ATTN_CHECKS"] = "yes"

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend

query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
attn_mask = torch.zeros((1, 1, 10, 10), dtype=torch.bool, device="cuda")

with attention_backend("native"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)

    assert not torch.equal(output, output2), "native: These outputs should not be equal!"

with attention_backend("flash"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)

    assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
Traceback (most recent call last):
  File "C:\repro.py", line 22, in <module>
    assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: flash: These outputs should not be equal!
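
Possible mitigation

For reference, a minimal sketch of the kind of check the dispatcher (or the optional DIFFUSERS_ATTN_CHECKS path) could run before calling the backend: inspect the registered callable and raise if an attn_mask was passed but the backend cannot accept it. The names check_mask_supported and _flash_like_backend below are illustrative only, not part of diffusers.

import inspect
from typing import Callable, Optional

import torch


def check_mask_supported(backend_fn: Callable, attn_mask: Optional[torch.Tensor]) -> None:
    # Raise instead of silently dropping the mask when the registered backend
    # has no attn_mask parameter (and no **kwargs to absorb it).
    if attn_mask is None:
        return
    params = inspect.signature(backend_fn).parameters
    accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if "attn_mask" not in params and not accepts_var_kwargs:
        raise ValueError(
            f"Backend '{backend_fn.__name__}' does not accept attn_mask; "
            "the mask would be silently ignored."
        )


# Illustrative backend with a flash-attention-style signature (no attn_mask).
def _flash_like_backend(query, key, value, dropout_p=0.0, is_causal=False):
    ...


try:
    check_mask_supported(_flash_like_backend, torch.zeros(1, 1, 10, 10, dtype=torch.bool))
except ValueError as e:
    print(e)  # Backend '_flash_like_backend' does not accept attn_mask; ...

Whether this should be a hard error or only a warning under DIFFUSERS_ATTN_CHECKS=yes is up for discussion; either would be better than silently computing full attention.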

System Info

  • 🤗 Diffusers version: 0.35.2
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.1.2
  • Transformers version: not installed
  • Accelerate version: not installed
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
