@@ -680,7 +680,10 @@ def _update_metadata_chunked_prefill(self,
                                             attn_metadata.num_prefills)
         attn_bias = None
         if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert batch_size == 1, "Chunked prefill with dynamic block_list only supports batch_size=1"
+            assert batch_size == 1, (
+                "Chunked prefill with dynamic block_list "
+                "only supports bs=1"
+            )
         for i in range(batch_size):
             single_attn_bias = self._set_attn_bias_chunked(
                 int(seq_len), context_lens_t[i], query_lens_t[i], device,
@@ -1864,9 +1867,9 @@ def _prepare_prompt(
 
             computed_block_nums = seq_group_metadata.computed_block_nums
             if (self.scheduler_config is not None
+                    and self.scheduler_config is not None
                     and self.scheduler_config.chunked_prefill_enabled
-                    and not (computed_block_nums is None
-                             or computed_block_nums == [])):
+                    and self.cache_config.enable_prefix_caching):
                 raise RuntimeError(
                     "chunked prefill cannot be used with prefix caching "
                     "now.")
@@ -1896,7 +1899,10 @@ def _prepare_prompt(
                 # Prefill has chunked before.
                 block_table = seq_group_metadata.block_tables[seq_id]
                 if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                    assert context_len % self.block_size == 0, "context len must be multiple of block size in dynamic chunked prefill mode"
+                    assert context_len % self.block_size == 0, (
+                        "context len must be multiple of block size in "
+                        "dynamic chunked prefill mode"
+                    )
                     prefix_blocks = context_len // self.block_size
                     prefix_block_tables.append(block_table[:prefix_blocks])
                 else:
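
Reviewer note (illustrative, not part of the patch): in the dynamic-input branch above, only whole blocks of already-computed context are reused, so the assertion and the slice work out as in this standalone sketch (block_size and the sample values are assumed):

    block_size = 128                       # assumed example value
    context_len = 384                      # tokens prefilled by earlier chunks
    block_table = [7, 12, 5, 9]            # toy block table for one sequence
    assert context_len % block_size == 0   # required in dynamic chunked prefill mode
    prefix_blocks = context_len // block_size
    print(block_table[:prefix_blocks])     # [7, 12, 5] -> reused as prefix blocks
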
@@ -2030,11 +2036,16 @@ def _prepare_prompt(
                                  for _ in range(batch_size_padding))
 
         real_num_seqs = len(query_lens)
-        if self.scheduler_config.chunked_prefill_enabled and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert target_query_len <= self.max_num_batched_tokens, f"{target_query_len=} exceeds {self.max_num_batched_tokens=} for chunked prefill"
+        bs = len(seq_group_metadata_list)
+        if (self.scheduler_config.chunked_prefill_enabled
+                and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT):
+            assert target_query_len <= self.max_num_batched_tokens, (
+                f"{target_query_len=} exceeds "
+                f"{self.max_num_batched_tokens=} for chunked prefill"
+            )
+
             max_prompt_len = self.max_num_batched_tokens
         else:
-            bs = len(seq_group_metadata_list)
             if bs > 1 and self.use_merged_prefill:
                 bs = 1
             max_prompt_len = max(
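
Reviewer note (illustrative, not part of the patch): with VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT the padded prompt length is pinned to the scheduler token budget rather than a per-batch bucket; a toy version of the new check, with assumed values:

    max_num_batched_tokens = 2048          # assumed scheduler budget
    target_query_len = 1536                # tokens scheduled for this prefill chunk
    assert target_query_len <= max_num_batched_tokens, (
        f"{target_query_len=} exceeds {max_num_batched_tokens=} for chunked prefill")
    max_prompt_len = max_num_batched_tokens   # dynamic path pads to the full budget
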
@@ -2069,7 +2080,9 @@ def _prepare_prompt(
         if self.vllm_config.cache_config.enable_prefix_caching:
             assert self.scheduler_config.max_num_prefill_seqs == 1
             assert bs == 1, (
-                "Prefix caching or chunked prefill with multiple sequences is not supported yet.")
+                "Prefix caching or chunked prefill with multiple sequences "
+                "is not supported yet."
+            )
             # prefix caching or chunked prefill
 
         max_num_block = max(len(bt) for bt in prefix_block_tables)
@@ -2079,9 +2092,10 @@ def _prepare_prompt(
                              ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
                              for bt in prefix_block_tables))
 
-        if self.scheduler_config.chunked_prefill_enabled and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            if max_prompt_len < max_num_block * self.block_size:
-                max_prompt_len = max_num_block * self.block_size
+        if (self.scheduler_config.chunked_prefill_enabled
+                and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT
+                and max_prompt_len < max_num_block * self.block_size):
+            max_prompt_len = max_num_block * self.block_size
         pad_len = len(prefix_block_list)
         prefix_block_list = pad_list(prefix_block_list, pad_len,
                                      _PAD_BLOCK_ID)
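
Reviewer note (illustrative, not part of the patch): the merged condition above keeps the old behaviour for the non-dynamic case, bumping the padded prompt length so it covers the longest prefix block table; numerically, with assumed values:

    block_size = 128                       # assumed
    max_num_block = 5                      # longest prefix block table in the batch
    max_prompt_len = 512
    if max_prompt_len < max_num_block * block_size:
        max_prompt_len = max_num_block * block_size
    print(max_prompt_len)                  # 640
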
@@ -2765,6 +2779,16 @@ def prepare_input_tensors(
                 prefill_reqs.append(seq_group_meta)
             else:
                 decode_reqs.append(seq_group_meta)
+        if self.scheduler_config.enable_chunked_prefill and len(decode_reqs) != 0:
+            decode_reqs, real_decode_batch_size, decode_batch_size_padded = (
+                self._add_dummy_seq(decode_reqs, False, align_worker))
+            seq_group_metadata_list = []
+            if len(prefill_reqs) != 0:
+                for req in prefill_reqs:
+                    seq_group_metadata_list.append(req)
+            for req in decode_reqs:
+                seq_group_metadata_list.append(req)
+            batch_size_padded = len(seq_group_metadata_list)
 
         # Prepare input tensors.
         (
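
Reviewer note (illustrative, not part of the patch): the new block reorders the batch so prefill groups come first and the decode groups, padded to a bucket by _add_dummy_seq, follow; a list-level sketch with toy stand-ins for the metadata objects:

    prefill_reqs = ["prefill_0", "prefill_1"]
    decode_reqs = ["decode_0", "decode_1", "decode_2"]
    padded_decode_bs = 4                   # assumed decode bucket after padding
    decode_reqs += ["dummy"] * (padded_decode_bs - len(decode_reqs))
    seq_group_metadata_list = [*prefill_reqs, *decode_reqs]
    batch_size_padded = len(seq_group_metadata_list)
    print(seq_group_metadata_list, batch_size_padded)
    # ['prefill_0', 'prefill_1', 'decode_0', 'decode_1', 'decode_2', 'dummy'] 6
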
@@ -3571,6 +3595,145 @@ def warmup_scenario(self,
         if not is_dummy_run:
             gc.collect()
 
+    def warmup_scenario_mix(self,
+                            batch_size,
+                            seq_len,
+                            ctx,
+                            is_prompt,
+                            kv_caches,
+                            is_pt_profiler_run=False,
+                            is_lora_profile_run=False,
+                            temperature=0,
+                            img_args=None,
+                            num_iters=3,
+                            align_worker=False,
+                            is_dummy_run=False) -> None:
+        phase = 'mix'
+        use_graphs = is_dummy_run or self._use_graphs(batch_size, seq_len)
+        buckets = self.bucketing_manager.decode_buckets
+        num_candidates = len(buckets)
+        for idx, (decode_bs, _, decode_ctx) in enumerate(reversed(buckets)):
+            scenario_name = ("warmup_"
+                             f"{phase}_"
+                             f"prefill_bs{batch_size}_"
+                             f"prefill_seq{seq_len}_"
+                             f"prefill_ctx{ctx}_"
+                             f"decode_bs{decode_bs}_"
+                             f"decode_ctx{decode_ctx}_"
+                             f"graphs{'T' if use_graphs else 'F'}")
+
+            self.log_warmup(f"Graph/{'mix'}/{'decode'}", idx, num_candidates, decode_bs, 1,
+                            decode_ctx)
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config and is_lora_profile_run:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_local_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(batch_size)
+                    ]
+            self.profiler.start('internal', scenario_name)
+            times = num_iters if use_graphs or is_pt_profiler_run else 1
+            seqs = []
+            seqs_prefill = self.create_dummy_seq_group_metadata(
+                0,
+                seq_len + ctx * self.block_size,
+                True,
+                lora_request=dummy_lora_requests_per_seq[0]
+                if dummy_lora_requests_per_seq else None,
+                img_args=img_args,
+                temperature=temperature,
+                ctx=ctx)
+
+            seqs.append(seqs_prefill)
+            blocks: list[int] = [decode_ctx // decode_bs for _ in range(decode_bs)]
+            blocks[0] += decode_ctx % decode_bs
+            for i, b in enumerate(blocks):
+                seqs_decode = self.create_dummy_seq_group_metadata(
+                    i,  # type: ignore[has-type]
+                    b * self.block_size - 1,
+                    False,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature,
+                    ctx=decode_ctx)
+                seqs.append(seqs_decode)
+
+            if not is_dummy_run:
+                torch.hpu.synchronize()
+            profiler = None
+            if is_pt_profiler_run and self.is_driver_worker:
+                profiler = setup_profiler()
+                profiler.start()
+            for time_index in range(times):
+                inputs = self.prepare_model_input_align_worker(
+                    seqs, align_worker=align_worker)
+                # Chendi: Necessary fix for warmup with TP>1
+                if time_index == 0:
+                    if self.is_driver_worker:
+                        broadcast_tensor_dict(
+                            {"input_tokens": inputs.input_tokens}, src=0)
+                    else:
+                        broadcast_tensor_dict(src=0)
+                if self._is_fla_model():
+                    self.add_fla_dummy_data(inputs)
+                if is_prompt or self.is_single_step:
+                    intermediate_tensors = None
+                    if not get_pp_group().is_first_rank:
+                        intermediate_tensors = \
+                            self.model.make_empty_intermediate_tensors(
+                                batch_size=batch_size,
+                                context_size=seq_len if is_prompt else 1,
+                                dtype=self.model_config.dtype,
+                                device=self.device)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       intermediate_tensors=intermediate_tensors,
+                                       warmup_mode=True,
+                                       ctx_blocks=ctx,
+                                       is_dummy_run=is_dummy_run,
+                                       is_pt_profiler_run=is_pt_profiler_run)
+                else:  # decode with multi-step
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=True,
+                                                 is_last_step=False)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=False,
+                                                 is_last_step=True)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                if not is_dummy_run:
+                    torch.hpu.synchronize()
+                if profiler:
+                    profiler.step()
+            if profiler:
+                profiler.stop()
+            self.profiler.end()
+            if not is_dummy_run:
+                gc.collect()
+
+
     def remove_all_loras(self):
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
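
Reviewer note (illustrative, not part of the patch): warmup_scenario_mix spreads a decode bucket's context blocks across the dummy decode sequences and folds the remainder into the first one; the arithmetic in isolation, with an assumed bucket:

    decode_bs, decode_ctx = 4, 10          # assumed decode bucket: batch size, total ctx blocks
    block_size = 128                       # assumed
    blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]   # [2, 2, 2, 2]
    blocks[0] += decode_ctx % decode_bs                            # [4, 2, 2, 2]
    assert sum(blocks) == decode_ctx
    seq_lens = [b * block_size - 1 for b in blocks]                # dummy decode lengths
    print(blocks, seq_lens)                # [4, 2, 2, 2] [511, 255, 255, 255]
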
@@ -3665,6 +3828,30 @@ def warmup_graphs(self,
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        if self.scheduler_config.chunked_prefill_enabled and is_prompt:
+            for idx, (batch_size, query_len, ctx) in enumerate(reversed(buckets)):
+                # Graph memory usage is proportional to seq dimension in a batch
+                phase = f"Graph/{'mix'}/{'prompt'}"
+                seq_len = query_len + ctx * self.block_size
+                batch_seq = batch_size * seq_len
+                self.log_warmup(phase, idx, num_candidates, batch_size, query_len,
+                                ctx)
+                with HabanaMemoryProfiler() as mem_prof:
+                    self.warmup_scenario_mix(
+                        batch_size,
+                        query_len,
+                        ctx,
+                        is_prompt,
+                        kv_caches,
+                        temperature=1.0
+                        if batch_size not in warmed_random_sampler_bs else 0,
+                    )
+                warmed_random_sampler_bs.add(batch_size)
+                used_mem = align_workers(mem_prof.consumed_device_memory,
+                                         torch.distributed.ReduceOp.MAX)
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
         if is_prompt and self.is_mm_run():
             #For multimodal total_batch_seq and total_mem, we store it in the
             #attribute for now.
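
Reviewer note (illustrative, not part of the patch): the mix-warmup loop derives the warmup prompt length from the query bucket plus the cached prefix, and uses batch_size * seq_len to apportion graph memory, e.g. with an assumed bucket and block size:

    block_size = 128                       # assumed
    batch_size, query_len, ctx = 1, 512, 4 # assumed prompt bucket (bs, query tokens, ctx blocks)
    seq_len = query_len + ctx * block_size # 1024 tokens including the cached prefix
    batch_seq = batch_size * seq_len       # 1024, accumulated into total_batch_seq
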
@@ -4110,9 +4297,10 @@ def _phase(self, attn_metadata):
     def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
                       warmup_mode):
         is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        is_chunked_prefill = self.vllm_config.scheduler_config.enable_chunked_prefill
         cfg: Optional[tuple] = None
         assert cfg is None, "Configs changed between 2D and 3D"
-        if is_prefix_caching:
+        if is_prefix_caching or is_chunked_prefill:
             phase = self._phase(attn_metadata)
             num_blocks = self._num_blocks(attn_metadata)
             cfg = (batch_size, seq_len, num_blocks, phase)
@@ -4123,8 +4311,8 @@ def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!",
-                           (phase.value, batch_size, seq_len,
-                            num_blocks) if is_prefix_caching else
+                           (phase.value, batch_size, seq_len, num_blocks)
+                           if is_prefix_caching or is_chunked_prefill else
                            (phase, batch_size, seq_len))
 
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
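
Reviewer note (illustrative, not part of the patch): with chunked prefill enabled the warmup-tracking key now carries the block count and phase, just like the prefix-caching path; a minimal sketch of the bookkeeping:

    seen_configs = set()
    is_prefix_caching, is_chunked_prefill = False, True
    batch_size, seq_len, num_blocks, phase = 2, 1024, 8, "prefill"   # toy values
    if is_prefix_caching or is_chunked_prefill:
        cfg = (batch_size, seq_len, num_blocks, phase)   # 4-tuple key
    else:
        cfg = (phase, batch_size, seq_len)               # legacy 3-tuple key
    seen = cfg in seen_configs
    seen_configs.add(cfg)
    print(seen, cfg)    # False (2, 1024, 8, 'prefill') -> an un-warmed config would be logged
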