 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import habana_frameworks.torch as htorch
 import torch
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
@@ -146,7 +147,21 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     conv_state_indices: Optional[torch.Tensor] = None
     mamba_cache_decode_indices: Optional[torch.Tensor] = None
     mamba_cache_prefill_indices: Optional[torch.Tensor] = None
-
+    decode_slot_mapping: Optional[torch.Tensor] = None
+    decode_block_list: Optional[torch.Tensor] = None
+    decode_attn_bias: Optional[torch.Tensor] = None
+    chunk_prefill_enabled: bool = False
+
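+# Plain container for the reshaped query/key/value tensors, their cached
+# views, and the shape metadata produced by preprocess_forward() below.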
+class HPUAttentionData:
+    query: Optional[torch.Tensor] = None
+    key: Optional[torch.Tensor] = None
+    value: Optional[torch.Tensor] = None
+    key_cache: Optional[torch.Tensor] = None
+    value_cache: Optional[torch.Tensor] = None
+    batch_size: int = 0
+    seq_len: int = 0
+    hidden_size: int = 0
+    seq_len_kv: int = 0
 
 @dataclass
 class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
@@ -461,6 +476,193 @@ def _maybe_init_alibi_biases(
                 dtype=self.alibi_slopes.dtype,
             )
 
+    def preprocess_forward(self, query: torch.Tensor, key: torch.Tensor,
+                           value: torch.Tensor, kv_cache: torch.Tensor,
+                           attn_metadata: HPUAttentionMetadata,
+                           is_prefill: bool) -> HPUAttentionData:
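+        """Reshape flat q/k/v into batched form and write k/v into the paged
+        KV cache, collecting the results in an HPUAttentionData."""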
+        attn_data: HPUAttentionData = HPUAttentionData()
+        # Defaults cover the decode path: one new token per sequence.
+        seq_len = 1
+        slot_mapping = attn_metadata.decode_slot_mapping.flatten(
+        ) if attn_metadata.decode_slot_mapping is not None else None
+        batch_size = attn_metadata.num_decode_tokens
+        if is_prefill:
+            seq_len = attn_metadata.num_prefill_tokens // \
+                attn_metadata.num_prefills
+            slot_mapping = attn_metadata.slot_mapping.flatten(
+            ) if attn_metadata.slot_mapping is not None else None
+            batch_size = attn_metadata.num_prefills
+        # Reshape the flat token inputs into [batch_size, seq_len, hidden_size].
+        hidden_size = query.shape[-1]
+        query = query.reshape(batch_size, seq_len, hidden_size)
+
+        hidden_size = key.shape[-1]
+        key = key.reshape(batch_size, seq_len, hidden_size)
+
+        hidden_size = value.shape[-1]
+        value = value.reshape(batch_size, seq_len, hidden_size)
+
+        # Record shapes, then flatten q/k/v to per-head layout before the
+        # key/value tensors are inserted into the KV cache.
+        attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size \
+            = query.shape
+        _, attn_data.seq_len_kv, _ = key.shape
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+
+            attn_data.key_cache = self.k_cache(key, key_cache, slot_mapping)
+            attn_data.value_cache = self.v_cache(value, value_cache,
+                                                 slot_mapping)
+        attn_data.key = key
+        attn_data.value = value
+        attn_data.query = query
+        return attn_data
+
+    def forward_chunked_prefill(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: HPUAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
541+ """Forward pass with xFormers and PagedAttention.
542+
543+ Args:
544+ query: shape = [num_tokens, num_heads * head_size]
545+ key: shape = [num_tokens, num_kv_heads * head_size]
546+ value: shape = [num_tokens, num_kv_heads * head_size]
547+ kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
548+ attn_metadata: Metadata for attention.
549+ Returns:
550+ shape = [num_tokens, num_heads * head_size]
551+ """
+        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+
+        if self.attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Chunked Prefill is enabled "
+                                      "only for Decoder")
+        prompt_output: Optional[torch.Tensor] = None
+        decode_output: Optional[torch.Tensor] = None
+        prefill_batch_size = 0
+        prefill_seq_len = 0
+        prefill_hidden_size = 0
+        decode_batch_size = 0
+        decode_seq_len = 0
+        decode_hidden_size = 0
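+        # Prefill tokens precede decode tokens in the flattened inputs, so the
+        # two phases are sliced apart and processed separately below.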
+        if attn_metadata.num_prefills > 0:
+            attn_data = self.preprocess_forward(
+                query[:attn_metadata.num_prefill_tokens],
+                key[:attn_metadata.num_prefill_tokens],
+                value[:attn_metadata.num_prefill_tokens], kv_cache,
+                attn_metadata, True)
+            # Prompt run.
+            prefill_batch_size = attn_data.batch_size
+            prefill_seq_len = attn_data.seq_len
+            prefill_hidden_size = attn_data.hidden_size
+            query_shape = (prefill_batch_size, prefill_seq_len,
+                           self.num_heads, self.head_size)
+            kv_shape = (prefill_batch_size, attn_data.seq_len_kv,
+                        self.num_kv_heads, self.head_size)
+
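+            # Without a block_list this is a plain causal prompt pass over the
+            # current tokens; with a block_list the previously cached context
+            # blocks are gathered from the KV cache and attended through an
+            # explicit attention bias instead of a causal mask.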
+            if attn_metadata is None or attn_metadata.block_list is None:
+
+                block_list = attn_metadata.block_list if attn_metadata \
+                    and attn_metadata.block_list is not None else None
+
+                common_args = self.common_attention_args(
+                    block_list, attn_data.key_cache, attn_data.value_cache,
+                    attn_metadata.block_size)
+                attn_bias = attn_metadata.attn_bias
+                position_bias = None
+
+                out = ops.prompt_attention(
+                    impl=self.prefill_impl,
+                    query=attn_data.query.view(query_shape),
+                    key=attn_data.key.view(kv_shape),
+                    value=attn_data.value.view(kv_shape),
+                    is_causal=True,
+                    position_bias=position_bias,
+                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                    **common_args)
+
+            else:
+                # TODO: enable FusedSDPA
+                block_list = attn_metadata.block_list if attn_metadata \
+                    and attn_metadata.block_list is not None else None
+
+                common_args = self.common_attention_args(
+                    block_list, attn_data.key_cache, attn_data.value_cache,
+                    attn_metadata.block_size)
+                attn_bias = attn_metadata.attn_bias
+                position_bias = None
+
+                out = ops.prompt_attention(
+                    impl=self.prefill_impl,
+                    query=attn_data.query.view(query_shape),
+                    key=self.k_cache.fetch_from_cache(
+                        attn_data.key_cache.unflatten(
+                            0, (-1, attn_metadata.block_size)),
+                        attn_metadata.block_list).view(kv_shape),
+                    value=self.v_cache.fetch_from_cache(
+                        attn_data.value_cache.unflatten(
+                            0, (-1, attn_metadata.block_size)),
+                        attn_metadata.block_list).view(kv_shape),
+                    is_causal=False,
+                    attn_bias=attn_bias,
+                    position_bias=position_bias,
+                    **common_args)
+
+            prompt_output = out.reshape(prefill_batch_size, prefill_seq_len,
+                                        prefill_hidden_size)
+            htorch.core.mark_step()
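+        # The mark_step() calls split the prefill and decode phases into
+        # separate lazy-mode graphs on HPU.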
+        if attn_metadata.num_decode_tokens > 0:
+            # Decoding run.
+            attn_data = self.preprocess_forward(
+                query[attn_metadata.num_prefill_tokens:],
+                key[attn_metadata.num_prefill_tokens:],
+                value[attn_metadata.num_prefill_tokens:], kv_cache,
+                attn_metadata, False)
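+            # The decode run consumes the decode-specific metadata fields
+            # (decode_slot_mapping, decode_block_list, decode_attn_bias).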
+            decode_batch_size = attn_data.batch_size
+            decode_seq_len = attn_data.seq_len
+            decode_hidden_size = attn_data.hidden_size
+            decode_output = HPUPagedAttention.forward_decode(
+                query=attn_data.query.view(attn_data.batch_size,
+                                           attn_data.seq_len,
+                                           attn_data.hidden_size),
+                block_mapping=attn_metadata.block_mapping,
+                block_bias=attn_metadata.decode_attn_bias,
+                block_groups=attn_metadata.block_groups,
+                position_bias=None,
+                **self.common_attention_args(attn_metadata.decode_block_list,
+                                             attn_data.key_cache,
+                                             attn_data.value_cache,
+                                             attn_metadata.block_size))
+            htorch.core.mark_step()
+        # Reshape the output tensor.
+        if decode_output is None:
+            prompt_output = prompt_output.view(
+                prefill_batch_size, prefill_seq_len, prefill_hidden_size)
+            return prompt_output
+        elif prompt_output is None:
+            return decode_output.view(decode_batch_size * decode_seq_len,
+                                      decode_hidden_size)
+        else:
+            prompt_output = prompt_output.view(
+                prefill_batch_size * prefill_seq_len, prefill_hidden_size)
+            decode_output = decode_output.view(
+                decode_batch_size * decode_seq_len, decode_hidden_size)
+            output = torch.cat((prompt_output, decode_output))
+            htorch.core.mark_step()
+            return output
     def forward(
         self,
         layer: AttentionLayer,
@@ -483,6 +685,16 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
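+        # Batches built for chunked prefill mix prefill and decode tokens and
+        # are routed to the dedicated chunked-prefill path.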
+        if attn_metadata.chunk_prefill_enabled:
+            return self.forward_chunked_prefill(
+                layer=layer,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                attn_metadata=attn_metadata,
+                output=output,
+            )
         if self.attn_type == AttentionType.ENCODER_DECODER:
             return self.forward_encoder_decoder(
                 query=query,
@@ -815,3 +1027,4 @@ def _make_decode_alibi_bias(
     per_head_bias.mul_(alibi_slopes[None, :, None])
 
     return per_head_bias
+