@@ -33,7 +33,7 @@ def get_impl_cls() -> Type["SpyreSDPABackendImpl"]:
3333 return SpyreSDPABackendImpl
3434
3535 @staticmethod
36- def get_metadata_cls () -> Type ["AttentionMetadata " ]:
36+ def get_metadata_cls () -> Type ["SpyreSDPAMetadata " ]:
3737 return SpyreSDPAMetadata
3838
3939 @staticmethod
@@ -51,13 +51,14 @@ def get_kv_cache_shape(
5151 num_kv_heads : int ,
5252 head_size : int ,
5353 ) -> Tuple [int , ...]:
54- if block_size % 16 != 0 :
55- raise ValueError ("Block size must be a multiple of 16." )
5654 return (2 , num_blocks , block_size , num_kv_heads , head_size )
5755
56+ # Should also define swap, copy methods for spyre
57+
5858@dataclass
5959class SpyreSDPAMetadata (AttentionMetadata , PagedAttentionMetadata ):
60- seq_lens : Optional [List [int ]] = None # Non-chunked prefill
60+ # Assuming non-chunked prefill
61+ seq_lens : Optional [List [int ]] = None
6162 seq_lens_tensor : Optional [List [int ]]
6263
6364 @property
@@ -154,16 +155,14 @@ def __init__(
154155
155156 assert self .num_heads % self .num_kv_heads == 0
156157 self .num_queries_per_kv = self .num_heads // self .num_kv_heads
157- supported_head_sizes = PagedAttention .get_supported_head_sizes ()
158- if head_size not in supported_head_sizes :
159- raise ValueError (
160- f"Head size { head_size } is not supported by PagedAttention. "
161- f"Supported head sizes are: { supported_head_sizes } ." )
162-
163- if is_quantized_kv_cache (kv_cache_dtype ) and not _use_ipex :
164- raise NotImplementedError (
165- "Spyre SDPA backend FP8 KV cache requires "
166- "intel_extension_for_pytorch support." )
158+
159+ # Check for supported head sizes
160+ if alibi_slopes is not None :
161+ raise NotImplementedError ("Alibi slopes is not supported." )
162+ if kv_cache_dtype != "auto" :
163+ raise NotImplementedError ("FP8 KV cache dtype is not supported." )
164+ if blocksparse_params is not None :
165+ raise NotImplementedError ("Blocksparse is not supported." )
167166 self .attn_type = attn_type
168167 if attn_type != AttentionType .DECODER :
169168 raise NotImplementedError ("Encoder self-attention and "
@@ -187,7 +186,7 @@ def forward(
187186 query: shape = [num_tokens, num_heads * head_size]
188187 key: shape = [num_tokens, num_kv_heads * head_size]
189188 value: shape = [num_tokens, num_kv_heads * head_size]
190- kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
189+ kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
191190 NOTE: kv_cache will be an empty tensor with shape [0]
192191 for profiling run.
193192 attn_metadata: Metadata for attention.
@@ -210,63 +209,12 @@ def forward(
210209 else :
211210 assert value is None
212211
213- key_cache = kv_cache
214- value_cache = kv_cache
215- #print("kv cache: ", kv_cache.shape)
216- if kv_cache .numel () > 0 :
217- key_cache , value_cache = PagedAttention .split_kv_cache (
218- kv_cache , self .num_kv_heads , self .head_size )
219-
220- if (key is not None ) and (value is not None ):
221- PagedAttention .write_to_paged_cache (
222- key , value , key_cache , value_cache , attn_metadata .slot_mapping ,
223- self .kv_cache_dtype , layer ._k_scale , layer ._v_scale )
224-
225- # Decoder self-attention supports chunked prefill.
226- # Encoder/decoder cross-attention requires no chunked
227- # prefill (100% prefill or 100% decode tokens, no mix)
228- num_prefill_tokens = attn_metadata .num_prefill_tokens
229- num_decode_tokens = attn_metadata .num_decode_tokens
230-
231- if attn_type == AttentionType .DECODER :
232- # Only enforce this shape-constraint for decoder
233- # self-attention
234- assert key .shape [0 ] == num_prefill_tokens + num_decode_tokens
235- assert value .shape [0 ] == num_prefill_tokens + num_decode_tokens
236-
237- output = torch .empty_like (query )
238- if prefill_meta := attn_metadata .prefill_metadata :
239- assert attn_metadata .seq_lens is not None
240- self ._run_sdpa_forward (output ,
241- query ,
242- key ,
243- value ,
244- prefill_meta ,
245- attn_type = attn_type )
246-
247- if decode_meta := attn_metadata .decode_metadata :
248- assert attn_type != AttentionType .ENCODER_ONLY , (
249- "Encoder-only models should not have decode metadata." )
250- # Decoding run.
251- seq_lens_arg = attn_metadata .seq_lens_tensor
252- max_seq_len_arg = attn_metadata .max_decode_seq_len
253- block_tables_arg = attn_metadata .block_tables
254-
255- PagedAttention .forward_decode (
256- output [attn_metadata .num_prefill_tokens :, :, :],
257- query [attn_metadata .num_prefill_tokens :, :, :],
258- key_cache ,
259- value_cache ,
260- block_tables_arg ,
261- seq_lens_arg ,
262- max_seq_len_arg ,
263- self .kv_cache_dtype ,
264- self .num_kv_heads ,
265- self .scale ,
266- None ,
267- layer ._k_scale ,
268- layer ._v_scale ,
269- )
212+ self ._run_sdpa_forward (output ,
213+ query ,
214+ key ,
215+ value ,
216+ attn_metadata ,
217+ attn_type = attn_type )
270218
271219 # Reshape the output tensor.
272220 return output .view (- 1 , self .num_heads * self .head_size )
0 commit comments