
Commit 7a60181

YuJiankang and jzhoulon committed
Enable chunked prefill on aice 1.22
Co-authored-by: Jiang, Zhoulong <[email protected]>
Signed-off-by: jkyu <[email protected]>
1 parent 58b1f0b commit 7a60181

File tree

4 files changed: +473 -31 lines changed

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import os
+os.environ["VLLM_SKIP_WARMUP"] = "true"
+os.environ['VLLM_CONTIGUOUS_PA'] = 'false'
+os.environ['VLLM_MLA_DISABLE_REQUANTIZATION'] = '1'
+os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+os.environ['PT_HPU_WEIGHT_SHARING'] = '0'
+os.environ['VLLM_MLA_PERFORM_MATRIX_ABSORPTION'] = '0'
+os.environ['VLLM_MTP_PRINT_ACCPET_RATE'] = '0'
+os.environ['PT_HPU_LAZY_MODE'] = '1'
+os.environ['VLLM_DELAYED_SAMPLING'] = 'false'
+# os.environ['VLLM_USE_V1'] = '1'
+
+
+if __name__ == "__main__":
+
+    from vllm import LLM, SamplingParams
+
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
+
+    model_name = "/home/HF_models/llama-3-8b"
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              enforce_eager=True,
+              dtype="bfloat16",
+              use_v2_block_manager=True,
+              tensor_parallel_size=1,
+              max_model_len=1024,
+              num_scheduler_steps=1,
+              gpu_memory_utilization=0.5,
+              enable_chunked_prefill=True,
+              max_num_batched_tokens=128,
+              seed=2024)
+    # Generate texts from the prompts. The output is a list of RequestOutput
+    # objects that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
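As a rough illustration of what enable_chunked_prefill=True with max_num_batched_tokens=128 means in the script above: each scheduler step has a fixed token budget, so a prompt longer than the budget is prefilled over several steps, and budget left over in a step can be spent on decode tokens. The sketch below only mirrors that budgeting arithmetic with made-up token counts; it is not vLLM's scheduler.

import math

max_num_batched_tokens = 128   # per-step budget, as in the example above
prompt_lens = [7, 9, 300, 6]   # hypothetical prompt lengths in tokens

for plen in prompt_lens:
    # A prompt longer than the budget is split into ceil(len / budget) chunks.
    chunks = math.ceil(plen / max_num_batched_tokens)
    print(f"prompt of {plen} tokens -> {chunks} prefill chunk(s)")

# In a mixed step, budget left after decode tokens goes to a prefill chunk,
# e.g. 4 running decodes leave 128 - 4 = 124 tokens for prefill.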

vllm/attention/backends/hpu_attn.py

Lines changed: 214 additions & 1 deletion
@@ -9,6 +9,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

+import habana_frameworks.torch as htorch
 import torch
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
@@ -146,7 +147,21 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     conv_state_indices: Optional[torch.Tensor] = None
     mamba_cache_decode_indices: Optional[torch.Tensor] = None
     mamba_cache_prefill_indices: Optional[torch.Tensor] = None
-
+    decode_slot_mapping: Optional[torch.Tensor] = None
+    decode_block_list: Optional[torch.Tensor] = None
+    decode_attn_bias: Optional[torch.Tensor] = None
+    chunk_prefill_enabled: bool = False
+
+class HPUAttentionData:
+    query: torch.Tensor = None
+    key: torch.Tensor = None
+    value: torch.Tensor = None
+    key_cache: torch.Tensor = None
+    value_cache: torch.Tensor = None
+    batch_size: int = 0
+    seq_len: int = 0
+    hidden_size: int = 0
+    seq_len_kv: int = 0

 @dataclass
 class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
@@ -461,6 +476,193 @@ def _maybe_init_alibi_biases(
             dtype=self.alibi_slopes.dtype,
         )

+    def preprocess_forward(self, query: torch.Tensor, key: torch.Tensor,
+                           value: torch.Tensor, kv_cache: torch.Tensor,
+                           attn_metadata: HPUAttentionMetadata,
+                           is_prefill: bool) -> HPUAttentionData:
+        attn_data: HPUAttentionData = HPUAttentionData()
+        seq_len = 1
+        slot_mapping = attn_metadata.decode_slot_mapping.flatten(
+        ) if attn_metadata.decode_slot_mapping is not None else None
+        batch_size = attn_metadata.num_decode_tokens
+        if is_prefill:
+            seq_len = attn_metadata.num_prefill_tokens //\
+                attn_metadata.num_prefills
+            slot_mapping = attn_metadata.slot_mapping.flatten(
+            ) if attn_metadata.slot_mapping is not None else None
+            batch_size = attn_metadata.num_prefills
+        # Convert flat inputs into 2D inputs.
+        hidden_size = query.shape[-1]
+        query = query.reshape(batch_size, seq_len, hidden_size)
+
+        hidden_size = key.shape[-1]
+        key = key.reshape(batch_size, seq_len, hidden_size)
+
+        hidden_size = value.shape[-1]
+        value = value.reshape(batch_size, seq_len, hidden_size)
+
+        # Insert key and value into the KV cache.
+        attn_data.batch_size, attn_data.seq_len, attn_data.hidden_size\
+            = query.shape
+        _, attn_data.seq_len_kv, _ = key.shape
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            attn_data.key_cache = self.k_cache(key, key_cache, slot_mapping)
+            attn_data.value_cache = self.v_cache(value, value_cache,
+                                                 slot_mapping)
+        attn_data.key = key
+        attn_data.value = value
+        attn_data.query = query
+        return attn_data
+
+    def forward_chunked_prefill(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: HPUAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass for chunked prefill with HPU PagedAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+
+        if self.attn_type != AttentionType.DECODER:
+            raise NotImplementedError(
+                "Chunked prefill is enabled only for the decoder")
+        prompt_output: torch.Tensor = None
+        decode_output: torch.Tensor = None
+        prefill_batch_size = 0
+        prefill_seq_len = 0
+        prefill_hidden_size = 0
+        decode_batch_size = 0
+        decode_seq_len = 0
+        decode_hidden_size = 0
+        if attn_metadata.num_prefills > 0:
+            attn_data = self.preprocess_forward(
+                query[:attn_metadata.num_prefill_tokens],
+                key[:attn_metadata.num_prefill_tokens],
+                value[:attn_metadata.num_prefill_tokens], kv_cache,
+                attn_metadata, True)
+            # Prompt run.
+            prefill_batch_size = attn_data.batch_size
+            prefill_seq_len = attn_data.seq_len
+            prefill_hidden_size = attn_data.hidden_size
+            query_shape = (prefill_batch_size, prefill_seq_len,
+                           self.num_heads, self.head_size)
+            kv_shape = (prefill_batch_size, attn_data.seq_len_kv,
+                        self.num_kv_heads, self.head_size)
+
+            if attn_metadata is None or attn_metadata.block_list is None:
+                block_list = attn_metadata.block_list if attn_metadata \
+                    and attn_metadata.block_list is not None else None
+
+                common_args = self.common_attention_args(
+                    block_list, attn_data.key_cache, attn_data.value_cache,
+                    attn_metadata.block_size)
+                attn_bias = attn_metadata.attn_bias
+                position_bias = None
+
+                out = ops.prompt_attention(
+                    impl=self.prefill_impl,
+                    query=attn_data.query.view(query_shape),
+                    key=attn_data.key.view(kv_shape),
+                    value=attn_data.value.view(kv_shape),
+                    is_causal=True,
+                    position_bias=position_bias,
+                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                    **common_args)
+            else:
+                # TODO: enable FusedSDPA
+                block_list = attn_metadata.block_list if attn_metadata \
+                    and attn_metadata.block_list is not None else None
+
+                common_args = self.common_attention_args(
+                    block_list, attn_data.key_cache, attn_data.value_cache,
+                    attn_metadata.block_size)
+                attn_bias = attn_metadata.attn_bias
+                position_bias = None
+
+                out = ops.prompt_attention(
+                    impl=self.prefill_impl,
+                    query=attn_data.query.view(query_shape),
+                    key=self.k_cache.fetch_from_cache(
+                        attn_data.key_cache.unflatten(
+                            0, (-1, attn_metadata.block_size)),
+                        attn_metadata.block_list).view(kv_shape),
+                    value=self.v_cache.fetch_from_cache(
+                        attn_data.value_cache.unflatten(
+                            0, (-1, attn_metadata.block_size)),
+                        attn_metadata.block_list).view(kv_shape),
+                    is_causal=False,
+                    attn_bias=attn_bias,
+                    position_bias=position_bias,
+                    **common_args)
+
+            prompt_output = out.reshape(prefill_batch_size, prefill_seq_len,
+                                        prefill_hidden_size)
+            htorch.core.mark_step()
+        if attn_metadata.num_decode_tokens > 0:
+            # Decoding run.
+            attn_data = self.preprocess_forward(
+                query[attn_metadata.num_prefill_tokens:],
+                key[attn_metadata.num_prefill_tokens:],
+                value[attn_metadata.num_prefill_tokens:], kv_cache,
+                attn_metadata, False)
+            decode_batch_size = attn_data.batch_size
+            decode_seq_len = attn_data.seq_len
+            decode_hidden_size = attn_data.hidden_size
+            decode_output = HPUPagedAttention.forward_decode(
+                query=attn_data.query.view(attn_data.batch_size,
+                                           attn_data.seq_len,
+                                           attn_data.hidden_size),
+                block_mapping=attn_metadata.block_mapping,
+                block_bias=attn_metadata.decode_attn_bias,
+                block_groups=attn_metadata.block_groups,
+                position_bias=None,
+                **self.common_attention_args(attn_metadata.decode_block_list,
+                                             attn_data.key_cache,
+                                             attn_data.value_cache,
+                                             attn_metadata.block_size))
+            htorch.core.mark_step()
+        # Reshape the output tensor.
+        if decode_output is None:
+            prompt_output = prompt_output.view(
+                prefill_batch_size, prefill_seq_len, prefill_hidden_size)
+            return prompt_output
+        elif prompt_output is None:
+            return decode_output.view(decode_batch_size * decode_seq_len,
+                                      decode_hidden_size)
+        else:
+            prompt_output = prompt_output.view(
+                prefill_batch_size * prefill_seq_len, prefill_hidden_size)
+            decode_output = decode_output.view(
+                decode_batch_size * decode_seq_len, decode_hidden_size)
+            output = torch.cat((prompt_output, decode_output))
+        htorch.core.mark_step()
+        return output
     def forward(
         self,
         layer: AttentionLayer,
@@ -483,6 +685,16 @@ def forward(
         shape = [num_tokens, num_heads * head_size]
         """
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+        if attn_metadata.chunk_prefill_enabled:
+            return self.forward_chunked_prefill(
+                layer=layer,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                attn_metadata=attn_metadata,
+                output=output,
+            )
         if self.attn_type == AttentionType.ENCODER_DECODER:
             return self.forward_encoder_decoder(
                 query=query,
@@ -815,3 +1027,4 @@ def _make_decode_alibi_bias(
     per_head_bias.mul_(alibi_slopes[None, :, None])

     return per_head_bias
+
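The heart of forward_chunked_prefill above is slicing the flattened [num_tokens, hidden] inputs at num_prefill_tokens into a prefill half and a decode half, reshaping each to [batch, seq_len, hidden], running the prompt and decode kernels separately, and concatenating the flattened results. Below is a minimal, self-contained sketch of just that data movement, with invented sizes standing in for the HPUAttentionMetadata fields; it uses plain torch and omits the attention calls themselves.

import torch

# Stand-ins for attn_metadata fields (illustrative values only).
num_prefills = 2          # prompts being prefilled this step
num_prefill_tokens = 8    # total prompt tokens; divides evenly by num_prefills
num_decode_tokens = 3     # sequences decoding one token each
hidden = 16

tokens = torch.randn(num_prefill_tokens + num_decode_tokens, hidden)

# Prefill half -> [num_prefills, seq_len, hidden],
# as in preprocess_forward(is_prefill=True).
seq_len = num_prefill_tokens // num_prefills
prefill = tokens[:num_prefill_tokens].reshape(num_prefills, seq_len, hidden)

# Decode half -> [num_decode_tokens, 1, hidden],
# as in preprocess_forward(is_prefill=False).
decode = tokens[num_prefill_tokens:].reshape(num_decode_tokens, 1, hidden)

# After attention, both halves are flattened back and concatenated,
# mirroring the torch.cat at the end of forward_chunked_prefill.
out = torch.cat((prefill.reshape(-1, hidden), decode.reshape(-1, hidden)))
assert out.shape == tokens.shape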

vllm/engine/arg_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1207,6 +1207,8 @@ def create_engine_config(
             if speculative_config is None \
                 else speculative_config.num_lookahead_slots

+        if self.enable_chunked_prefill:
+            self.use_padding_aware_scheduling = False
         scheduler_config = SchedulerConfig(
             runner_type=model_config.runner_type,
             max_num_batched_tokens=self.max_num_batched_tokens,
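The two added lines above mean that enabling chunked prefill forces padding-aware scheduling off before SchedulerConfig is built. A minimal sketch of that precedence, using a toy stand-in class rather than vLLM's real EngineArgs:

from dataclasses import dataclass

@dataclass
class ToyEngineArgs:
    # Stand-ins for the real EngineArgs fields; defaults are illustrative.
    enable_chunked_prefill: bool = False
    use_padding_aware_scheduling: bool = True

    def create_engine_config(self) -> dict:
        # Mirrors the guard added above: chunked prefill wins.
        if self.enable_chunked_prefill:
            self.use_padding_aware_scheduling = False
        return {
            "enable_chunked_prefill": self.enable_chunked_prefill,
            "use_padding_aware_scheduling": self.use_padding_aware_scheduling,
        }

print(ToyEngineArgs(enable_chunked_prefill=True).create_engine_config())
# -> {'enable_chunked_prefill': True, 'use_padding_aware_scheduling': False}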

0 commit comments