Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 289 additions & 3 deletions flash-attn2/tests/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from flash_attn2.bert_padding import pad_input, unpad_input
from flash_attn2.flash_attn_interface import _get_block_size_n

from flash_attn2.layers.rotary import apply_rotary_emb

MAX_HEADDIM_SM8x = 192

Expand Down Expand Up @@ -1965,9 +1965,295 @@ def test_flash_attn_splitkv(
assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None])
@pytest.mark.parametrize("has_leftpad", [False, True])
# @pytest.mark.parametrize("has_leftpad", [True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
has_leftpad,
paged_kv_block_size,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
alibi,
new_kv,
mha_type,
num_splits,
dtype,
device,
):
if device == "cpu":
pytest.skip("kvcache not supported on CPU")
if device == "xpu":
if alibi:
pytest.skip("alibi not supported on xpu currently")
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
if has_batch_idx and paged_kv_block_size is not None:
pytest.skip()
if has_leftpad and paged_kv_block_size is not None:
pytest.skip()

# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 6
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
else:
k, v = None, None
if paged_kv_block_size is None:
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
block_table = None
else:
(
k_cache,
v_cache,
block_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
)
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1)
),
(batch_size,),
dtype=torch.int32,
device=device,
)
if has_leftpad:
cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
for i in range(batch_size)])
else:
cache_leftpad = None
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
if has_leftpad:
key_padding_mask = torch.logical_and(
key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
)
if has_batch_idx:
cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
:batch_size
]
else:
cache_batch_idx = None
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
)
else:
alibi_slopes, attn_bias = None, None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = (
torch.rand(
seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
out = flash_attn_with_kvcache(
q,
k_cache if paged_kv_block_size is None else k_cache_paged,
v_cache if paged_kv_block_size is None else v_cache_paged,
k,
v,
rotary_cos=cos,
rotary_sin=sin,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_batch_idx,
cache_leftpad=cache_leftpad,
block_table=block_table,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
alibi_slopes=alibi_slopes,
num_splits=num_splits,
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
attn_bias,
0.0,
None,
causal=causal,
window_size=window_size,
key_leftpad=cache_leftpad,
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
attn_bias,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
key_leftpad=cache_leftpad,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

# NOTE: test_flash_attn_kvcache was removed because it depended on
# flash_attn2.layers.rotary (upstream baggage not shipped with this kernel).
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
if paged_kv_block_size is None:
k_cache_select = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
)
v_cache_select = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
)
else:
k_cache_select = rearrange(
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache_select = rearrange(
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
mult = 3 if not alibi else 5
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
Expand Down
21 changes: 1 addition & 20 deletions flash-attn2/torch-ext/flash_attn2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import List, Optional

from typing import Optional, List
import torch

from ._ops import ops as flash_attn_ops
from .flash_attn_interface import (
flash_attn_func,
Expand All @@ -13,23 +11,6 @@
flash_attn_with_kvcache,
)

__all__ = [
# High-level API (with autograd support)
"flash_attn_func",
"flash_attn_kvpacked_func",
"flash_attn_qkvpacked_func",
"flash_attn_varlen_func",
"flash_attn_varlen_kvpacked_func",
"flash_attn_varlen_qkvpacked_func",
"flash_attn_with_kvcache",
# Low-level ops (no autograd)
"fwd",
"varlen_fwd",
"bwd",
"varlen_bwd",
"fwd_kvcache",
]


def fwd(
q: torch.Tensor,
Expand Down
Loading