Skip to content

Commit d7f97b2

Browse files
committed
keep attention simple
1 parent 76fd76a commit d7f97b2

File tree

2 files changed

+20
-74
lines changed

2 files changed

+20
-74
lines changed

vllm_spyre/platform.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,6 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
5959
block_size: int, use_v1: bool,
6060
use_mla: bool) -> str:
6161
logger.info("Using Torch SDPA backend.")
62-
#return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
63-
#return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"
6462
return ("vllm_spyre.v1.attention.backends.spyre.SpyreSDPABackend")
6563

6664
@classmethod

vllm_spyre/v1/attention/backends/spyre.py

Lines changed: 20 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def get_impl_cls() -> Type["SpyreSDPABackendImpl"]:
3333
return SpyreSDPABackendImpl
3434

3535
@staticmethod
36-
def get_metadata_cls() -> Type["AttentionMetadata"]:
36+
def get_metadata_cls() -> Type["SpyreSDPAMetadata"]:
3737
return SpyreSDPAMetadata
3838

3939
@staticmethod
@@ -51,13 +51,14 @@ def get_kv_cache_shape(
5151
num_kv_heads: int,
5252
head_size: int,
5353
) -> Tuple[int, ...]:
54-
if block_size % 16 != 0:
55-
raise ValueError("Block size must be a multiple of 16.")
5654
return (2, num_blocks, block_size, num_kv_heads, head_size)
5755

56+
# Should also define swap, copy methods for spyre
57+
5858
@dataclass
5959
class SpyreSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
60-
seq_lens: Optional[List[int]] = None # Non-chunked prefill
60+
# Assuming non-chunked prefill
61+
seq_lens: Optional[List[int]] = None
6162
seq_lens_tensor: Optional[List[int]]
6263

6364
@property
@@ -154,16 +155,14 @@ def __init__(
154155

155156
assert self.num_heads % self.num_kv_heads == 0
156157
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
157-
supported_head_sizes = PagedAttention.get_supported_head_sizes()
158-
if head_size not in supported_head_sizes:
159-
raise ValueError(
160-
f"Head size {head_size} is not supported by PagedAttention. "
161-
f"Supported head sizes are: {supported_head_sizes}.")
162-
163-
if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
164-
raise NotImplementedError(
165-
"Spyre SDPA backend FP8 KV cache requires "
166-
"intel_extension_for_pytorch support.")
158+
159+
# Check for supported head sizes
160+
if alibi_slopes is not None:
161+
raise NotImplementedError("Alibi slopes is not supported.")
162+
if kv_cache_dtype != "auto":
163+
raise NotImplementedError("FP8 KV cache dtype is not supported.")
164+
if blocksparse_params is not None:
165+
raise NotImplementedError("Blocksparse is not supported.")
167166
self.attn_type = attn_type
168167
if attn_type != AttentionType.DECODER:
169168
raise NotImplementedError("Encoder self-attention and "
@@ -187,7 +186,7 @@ def forward(
187186
query: shape = [num_tokens, num_heads * head_size]
188187
key: shape = [num_tokens, num_kv_heads * head_size]
189188
value: shape = [num_tokens, num_kv_heads * head_size]
190-
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
189+
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
191190
NOTE: kv_cache will be an empty tensor with shape [0]
192191
for profiling run.
193192
attn_metadata: Metadata for attention.
@@ -210,63 +209,12 @@ def forward(
210209
else:
211210
assert value is None
212211

213-
key_cache = kv_cache
214-
value_cache = kv_cache
215-
#print("kv cache: ", kv_cache.shape)
216-
if kv_cache.numel() > 0:
217-
key_cache, value_cache = PagedAttention.split_kv_cache(
218-
kv_cache, self.num_kv_heads, self.head_size)
219-
220-
if (key is not None) and (value is not None):
221-
PagedAttention.write_to_paged_cache(
222-
key, value, key_cache, value_cache, attn_metadata.slot_mapping,
223-
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
224-
225-
# Decoder self-attention supports chunked prefill.
226-
# Encoder/decoder cross-attention requires no chunked
227-
# prefill (100% prefill or 100% decode tokens, no mix)
228-
num_prefill_tokens = attn_metadata.num_prefill_tokens
229-
num_decode_tokens = attn_metadata.num_decode_tokens
230-
231-
if attn_type == AttentionType.DECODER:
232-
# Only enforce this shape-constraint for decoder
233-
# self-attention
234-
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
235-
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
236-
237-
output = torch.empty_like(query)
238-
if prefill_meta := attn_metadata.prefill_metadata:
239-
assert attn_metadata.seq_lens is not None
240-
self._run_sdpa_forward(output,
241-
query,
242-
key,
243-
value,
244-
prefill_meta,
245-
attn_type=attn_type)
246-
247-
if decode_meta := attn_metadata.decode_metadata:
248-
assert attn_type != AttentionType.ENCODER_ONLY, (
249-
"Encoder-only models should not have decode metadata.")
250-
# Decoding run.
251-
seq_lens_arg = attn_metadata.seq_lens_tensor
252-
max_seq_len_arg = attn_metadata.max_decode_seq_len
253-
block_tables_arg = attn_metadata.block_tables
254-
255-
PagedAttention.forward_decode(
256-
output[attn_metadata.num_prefill_tokens:, :, :],
257-
query[attn_metadata.num_prefill_tokens:, :, :],
258-
key_cache,
259-
value_cache,
260-
block_tables_arg,
261-
seq_lens_arg,
262-
max_seq_len_arg,
263-
self.kv_cache_dtype,
264-
self.num_kv_heads,
265-
self.scale,
266-
None,
267-
layer._k_scale,
268-
layer._v_scale,
269-
)
212+
self._run_sdpa_forward(output,
213+
query,
214+
key,
215+
value,
216+
attn_metadata,
217+
attn_type=attn_type)
270218

271219
# Reshape the output tensor.
272220
return output.view(-1, self.num_heads * self.head_size)

0 commit comments

Comments (0)