-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreference.py
More file actions
26 lines (23 loc) · 796 Bytes
/
reference.py
File metadata and controls
26 lines (23 loc) · 796 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Reference implementation -- PyTorch ground truth. DO NOT MODIFY."""
import torch
import torch.nn.functional as F
def flash_attention_ref(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    causal: bool = True,
    sm_scale: float | None = None,
) -> torch.Tensor:
    """Standard scaled dot-product attention. Used as the oracle.

    Computes ``softmax(Q @ K^T * sm_scale) @ V`` with plain PyTorch ops,
    optionally hiding "future" key positions with a causal mask before the
    softmax. The full score matrix of shape ``(..., seq_q, seq_k)`` is
    materialized, so memory grows with ``seq_q * seq_k``.

    Args:
        Q: Query tensor; its last two dims are read as ``(seq_q, head_dim)``.
            Leading dims are broadcast by ``torch.matmul`` (presumably
            ``(batch, heads)`` -- confirm against callers).
        K: Key tensor with last two dims ``(seq_k, head_dim)``; ``head_dim``
            must match ``Q.shape[-1]`` for the ``Q @ K^T`` matmul.
        V: Value tensor whose second-to-last dim must equal ``seq_k``.
        causal: If True, query position ``i`` may only attend to key
            positions ``j <= i``; scores with ``j > i`` are set to ``-inf``.
        sm_scale: Multiplier applied to the raw scores. Defaults to
            ``Q.shape[-1] ** -0.5``, i.e. ``1 / sqrt(head_dim)``.

    Returns:
        Tensor of shape ``(..., seq_q, V.shape[-1])``.

    Note:
        All ops run in the inputs' dtype; nothing is upcast here, so for
        fp16/bf16 inputs the oracle itself is computed in reduced precision.
        NOTE(review): confirm comparison tolerances account for this.
    """
    if sm_scale is None:
        # Default softmax scale: 1 / sqrt(head_dim).
        sm_scale = Q.shape[-1] ** -0.5
    # Raw attention scores, shape (..., seq_q, seq_k).
    attn = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
    if causal:
        seq_q, seq_k = Q.shape[-2], K.shape[-2]
        # True strictly above the main diagonal (j > i): those key positions
        # are hidden from query i. The 2-D mask broadcasts over leading dims.
        # NOTE(review): this is top-left alignment -- when seq_q != seq_k,
        # query i still sees keys 0..i. Some kernels (e.g. FlashAttention
        # >= 2.1) align the causal mask to the bottom-right instead; confirm
        # the kernel under test uses the same convention.
        mask = torch.triu(
            torch.ones(seq_q, seq_k, device=Q.device, dtype=torch.bool),
            diagonal=1,
        )
        # Key 0 is never masked, so the mask never leaves a row entirely
        # -inf and softmax cannot produce NaN from masking (given seq_k >= 1).
        attn = attn.masked_fill(mask, float("-inf"))
    # Normalize over the key dimension.
    attn = F.softmax(attn, dim=-1)
    return torch.matmul(attn, V)