Commit aaa66c1

Enable warmup for chunked prefill
Signed-off-by: jkyu <[email protected]>
1 parent 64fdbc6 commit aaa66c1

2 files changed: +206, -15 lines changed

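For context, here is a minimal sketch of the configuration under which the code paths touched by this commit run. It is not part of the commit: the model name and token budget are placeholders, and VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT is the HPU-specific environment variable referenced in the diff (assumed to default to off).

    import os

    # Sketch only: enable chunked prefill and the HPU dynamic-input mode that
    # the assertions below guard. All values are illustrative placeholders.
    os.environ["VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT"] = "1"

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        enable_chunked_prefill=True,               # exercises the warmup added here
        max_num_batched_tokens=2048,               # placeholder token budget
    )
    print(llm.generate(["Hello"])[0].outputs[0].text)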

vllm/attention/backends/hpu_attn.py

Lines changed: 4 additions & 1 deletion
@@ -611,7 +611,10 @@ def forward_chunked_prefill(
         position_bias = None

         if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert prefill_batch_size == 1, "Only batch size 1 is supported for chunked prefill with dynamic block list."
+            assert prefill_batch_size == 1, (
+                "Only batch size 1 is supported for chunked prefill "
+                "with dynamic block list."
+            )
         key_attn = attn_data.key.view(kv_shape)
         value_attn = attn_data.value.view(kv_shape)
         common_args['need_context'] = True

vllm/worker/hpu_model_runner.py

Lines changed: 202 additions & 14 deletions
@@ -680,7 +680,10 @@ def _update_metadata_chunked_prefill(self,
                                  attn_metadata.num_prefills)
        attn_bias = None
        if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert batch_size == 1, "Chunked prefill with dynamic block_list only supports batch_size=1"
+            assert batch_size == 1, (
+                "Chunked prefill with dynamic block_list "
+                "only supports bs=1"
+            )
        for i in range(batch_size):
            single_attn_bias = self._set_attn_bias_chunked(
                int(seq_len), context_lens_t[i], query_lens_t[i], device,
@@ -1864,9 +1867,9 @@ def _prepare_prompt(

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
+                    and self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
-                    and not (computed_block_nums is None
-                             or computed_block_nums == [])):
+                    and self.cache_config.enable_prefix_caching):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")
@@ -1896,7 +1899,10 @@ def _prepare_prompt(
                # Prefill has chunked before.
                block_table = seq_group_metadata.block_tables[seq_id]
                if envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                    assert context_len % self.block_size == 0, "context len must be multiple of block size in dynamic chunked prefill mode"
+                    assert context_len % self.block_size == 0, (
+                        "context len must be multiple of block size in "
+                        "dynamic chunked prefill mode"
+                    )
                    prefix_blocks = context_len // self.block_size
                    prefix_block_tables.append(block_table[:prefix_blocks])
                else:
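The assertion in the hunk above exists because, in the dynamic mode, the already-computed context must map to a whole number of KV-cache blocks before the block table can be sliced. A small standalone illustration with made-up sizes (none of these values come from the commit):

    # Illustrative numbers only: block-aligned context in dynamic chunked prefill.
    block_size = 128
    context_len = 384                        # tokens prefilled by earlier chunks
    block_table = [17, 42, 96, 103, 110]     # hypothetical physical block ids

    assert context_len % block_size == 0, "context must be block-aligned"
    prefix_blocks = context_len // block_size         # 3
    prefix_block_table = block_table[:prefix_blocks]  # [17, 42, 96]
    print(prefix_blocks, prefix_block_table)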
@@ -2030,11 +2036,16 @@ def _prepare_prompt(
                             for _ in range(batch_size_padding))

        real_num_seqs = len(query_lens)
-        if self.scheduler_config.chunked_prefill_enabled and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-            assert target_query_len <= self.max_num_batched_tokens, f"{target_query_len=} exceeds {self.max_num_batched_tokens=} for chunked prefill"
+        bs = len(seq_group_metadata_list)
+        if (self.scheduler_config.chunked_prefill_enabled
+                and envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT):
+            assert target_query_len <= self.max_num_batched_tokens, (
+                f"{target_query_len=} exceeds "
+                f"{self.max_num_batched_tokens=} for chunked prefill"
+            )
+
            max_prompt_len = self.max_num_batched_tokens
        else:
-            bs = len(seq_group_metadata_list)
            if bs > 1 and self.use_merged_prefill:
                bs = 1
            max_prompt_len = max(
@@ -2069,7 +2080,9 @@ def _prepare_prompt(
        if self.vllm_config.cache_config.enable_prefix_caching:
            assert self.scheduler_config.max_num_prefill_seqs == 1
            assert bs == 1, (
-                "Prefix caching or chunked prefill with multiple sequences is not supported yet.")
+                "Prefix caching or chunked prefill with multiple sequences "
+                "is not supported yet."
+            )
            # prefix caching or chunked prefill

            max_num_block = max(len(bt) for bt in prefix_block_tables)
@@ -2079,9 +2092,10 @@ def _prepare_prompt(
                ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
                for bt in prefix_block_tables))

-            if self.scheduler_config.chunked_prefill_enabled and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT:
-                if max_prompt_len < max_num_block * self.block_size:
-                    max_prompt_len = max_num_block * self.block_size
+            if (self.scheduler_config.chunked_prefill_enabled
+                    and not envs.VLLM_HPU_CHUNKED_PREFILL_DYNAMIC_INPUT
+                    and max_prompt_len < max_num_block * self.block_size):
+                max_prompt_len = max_num_block * self.block_size
            pad_len = len(prefix_block_list)
            prefix_block_list = pad_list(prefix_block_list, pad_len,
                                         _PAD_BLOCK_ID)
@@ -2765,6 +2779,16 @@ def prepare_input_tensors(
                prefill_reqs.append(seq_group_meta)
            else:
                decode_reqs.append(seq_group_meta)
+        if self.scheduler_config.enable_chunked_prefill and len(decode_reqs) != 0:
+            decode_reqs, real_decode_batch_size, decode_batch_size_padded = (
+                self._add_dummy_seq(decode_reqs, False, align_worker))
+            seq_group_metadata_list=[]
+            if len(prefill_reqs) != 0:
+                for req in prefill_reqs:
+                    seq_group_metadata_list.append(req)
+            for req in decode_reqs:
+                seq_group_metadata_list.append(req)
+            batch_size_padded = len(seq_group_metadata_list)

        # Prepare input tensors.
        (
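In plain terms, when chunked prefill is enabled and the step contains decode requests, the hunk above pads the decode portion up to a warmed bucket size and then appends it after the prefill requests, rebuilding a single mixed batch. A self-contained sketch of that ordering; pad_to_bucket stands in for _add_dummy_seq and all values are invented:

    # Sketch only: how prefill and (padded) decode requests form one mixed batch.
    prefill_reqs = ["p0", "p1"]            # stand-ins for seq_group_metadata
    decode_reqs = ["d0", "d1", "d2"]

    def pad_to_bucket(reqs, bucket_size, pad_item="dummy"):
        # Stand-in for _add_dummy_seq: pad the decode batch to a warmed bucket.
        padded = reqs + [pad_item] * (bucket_size - len(reqs))
        return padded, len(reqs), len(padded)

    decode_reqs, real_decode_bs, decode_bs_padded = pad_to_bucket(decode_reqs, 4)
    seq_group_metadata_list = [*prefill_reqs, *decode_reqs]  # prefills first
    batch_size_padded = len(seq_group_metadata_list)         # 2 + 4 = 6
    print(seq_group_metadata_list, batch_size_padded)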
@@ -3571,6 +3595,145 @@ def warmup_scenario(self,
        if not is_dummy_run:
            gc.collect()

+    def warmup_scenario_mix(self,
+                            batch_size,
+                            seq_len,
+                            ctx,
+                            is_prompt,
+                            kv_caches,
+                            is_pt_profiler_run=False,
+                            is_lora_profile_run=False,
+                            temperature=0,
+                            img_args=None,
+                            num_iters=3,
+                            align_worker=False,
+                            is_dummy_run=False) -> None:
+        phase = 'mix'
+        use_graphs = is_dummy_run or self._use_graphs(batch_size, seq_len)
+        buckets = self.bucketing_manager.decode_buckets
+        num_candidates = len(buckets)
+        for idx, (decode_bs, _, decode_ctx) in enumerate(reversed(buckets)):
+            scenario_name = ("warmup_"
+                             f"{phase}_"
+                             f"prefill_bs{batch_size}_"
+                             f"prefill_seq{seq_len}_"
+                             f"prefill_ctx{ctx}_"
+                             f"decode_bs{decode_bs}_"
+                             f"decode_ctx{decode_ctx}_"
+                             f"graphs{'T' if use_graphs else 'F'}")
+
+            self.log_warmup(f"Graph/{'mix'}/{'decode'}", idx, num_candidates, decode_bs, 1,
+                            decode_ctx)
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config and is_lora_profile_run:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_local_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(batch_size)
+                    ]
+            self.profiler.start('internal', scenario_name)
+            times = num_iters if use_graphs or is_pt_profiler_run else 1
+            seqs=[]
+            seqs_prefill = self.create_dummy_seq_group_metadata(
+                0,
+                seq_len + ctx * self.block_size,
+                True,
+                lora_request=dummy_lora_requests_per_seq[i]
+                if dummy_lora_requests_per_seq else None,
+                img_args=img_args,
+                temperature=temperature,
+                ctx=ctx)
+
+            seqs.append(seqs_prefill)
+            blocks: list[int] = [decode_ctx // decode_bs for _ in range(decode_bs)]
+            blocks[0] += decode_ctx % decode_bs
+            for i, b in enumerate(blocks):
+                seqs_decode = self.create_dummy_seq_group_metadata(
+                    i,  # type: ignore[has-type]
+                    b * self.block_size - 1,
+                    False,
+                    lora_request=dummy_lora_requests_per_seq[i]
+                    if dummy_lora_requests_per_seq else None,
+                    temperature=temperature,
+                    ctx=decode_ctx)
+                seqs.append(seqs_decode)
+
+            if not is_dummy_run:
+                torch.hpu.synchronize()
+            profiler = None
+            if is_pt_profiler_run and self.is_driver_worker:
+                profiler = setup_profiler()
+                profiler.start()
+            for time_index in range(times):
+                inputs = self.prepare_model_input_align_worker(
+                    seqs, align_worker=align_worker)
+                # Chendi: Necessary fix for warmup with TP>1
+                if time_index == 0:
+                    if self.is_driver_worker:
+                        broadcast_tensor_dict(
+                            {"input_tokens": inputs.input_tokens}, src=0)
+                    else:
+                        broadcast_tensor_dict(src=0)
+                if self._is_fla_model():
+                    self.add_fla_dummy_data(inputs)
+                if is_prompt or self.is_single_step:
+                    intermediate_tensors = None
+                    if not get_pp_group().is_first_rank:
+                        intermediate_tensors = \
+                            self.model.make_empty_intermediate_tensors(
+                                batch_size=batch_size,
+                                context_size=seq_len if is_prompt else 1,
+                                dtype=self.model_config.dtype,
+                                device=self.device)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       intermediate_tensors=intermediate_tensors,
+                                       warmup_mode=True,
+                                       ctx_blocks=ctx,
+                                       is_dummy_run=is_dummy_run,
+                                       is_pt_profiler_run=is_pt_profiler_run)
+                else:  # decode with multi-step
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=True,
+                                                 is_last_step=False)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                    inputs = dataclasses.replace(inputs,
+                                                 is_first_multi_step=False,
+                                                 is_last_step=True)
+                    self.execute_model(inputs,
+                                       kv_caches,
+                                       warmup_mode=True,
+                                       num_steps=2,
+                                       seqs=seqs,
+                                       ctx_blocks=ctx)
+                if not is_dummy_run:
+                    torch.hpu.synchronize()
+                if profiler:
+                    profiler.step()
+            if profiler:
+                profiler.stop()
+            self.profiler.end()
+            if not is_dummy_run:
+                gc.collect()
+
+
    def remove_all_loras(self):
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
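One detail of warmup_scenario_mix worth spelling out is how the bucketed decode context is split across the dummy decode sequences: every sequence gets an equal share of blocks and the remainder goes to the first one. A standalone illustration with made-up numbers:

    # Illustrative only: decode context block split used by the mix warmup.
    decode_ctx, decode_bs, block_size = 10, 4, 128

    blocks = [decode_ctx // decode_bs for _ in range(decode_bs)]  # [2, 2, 2, 2]
    blocks[0] += decode_ctx % decode_bs                           # [4, 2, 2, 2]

    # Each dummy decode sequence is built with b * block_size - 1 tokens of
    # history, so together the sequences cover exactly decode_ctx blocks.
    seq_lens = [b * block_size - 1 for b in blocks]
    print(blocks, seq_lens, sum(blocks) == decode_ctx)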
@@ -3665,6 +3828,30 @@ def warmup_graphs(self,
            total_mem += used_mem
            total_batch_seq += batch_seq

+        if self.scheduler_config.chunked_prefill_enabled and is_prompt:
+            for idx, (batch_size, query_len, ctx) in enumerate(reversed(buckets)):
+                # Graph memory usage is proportional to seq dimension in a batch
+                phase = f"Graph/{'mix'}/{'prompt'}"
+                seq_len = query_len + ctx * self.block_size
+                batch_seq = batch_size * seq_len
+                self.log_warmup(phase, idx, num_candidates, batch_size, query_len,
+                                ctx)
+                with HabanaMemoryProfiler() as mem_prof:
+                    self.warmup_scenario_mix(
+                        batch_size,
+                        query_len,
+                        ctx,
+                        is_prompt,
+                        kv_caches,
+                        temperature=1.0
+                        if batch_size not in warmed_random_sampler_bs else 0,
+                    )
+                warmed_random_sampler_bs.add(batch_size)
+                used_mem = align_workers(mem_prof.consumed_device_memory,
+                                         torch.distributed.ReduceOp.MAX)
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
        if is_prompt and self.is_mm_run():
            #For multimodal total_batch_seq and total_mem, we store it in the
            #attribute for now.
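For each prompt bucket, the mix warmup loop above accounts for the full prompt as new query tokens plus the already-cached context, so the effective sequence length is query_len + ctx * block_size. A quick check with invented bucket values:

    # Illustrative only: sequence-length accounting for the mixed prompt warmup.
    block_size = 128
    batch_size, query_len, ctx = 1, 512, 4   # hypothetical prompt bucket

    seq_len = query_len + ctx * block_size   # 512 new + 512 cached = 1024
    batch_seq = batch_size * seq_len         # contribution to the memory estimate
    print(seq_len, batch_seq)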
@@ -4110,9 +4297,10 @@ def _phase(self, attn_metadata):
    def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
                      warmup_mode):
        is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        is_chunked_prefill = self.vllm_config.scheduler_config.enable_chunked_prefill
        cfg: Optional[tuple] = None
        assert cfg is None, "Configs changed between 2D and 3D"
-        if is_prefix_caching:
+        if is_prefix_caching or is_chunked_prefill:
            phase = self._phase(attn_metadata)
            num_blocks = self._num_blocks(attn_metadata)
            cfg = (batch_size, seq_len, num_blocks, phase)
@@ -4123,8 +4311,8 @@ def _check_config(self, batch_size, seq_len, ctx, attn_metadata,
        self.seen_configs.add(cfg)
        if not seen and not warmup_mode:
            logger.warning("Configuration: %s was not warmed-up!",
-                           (phase.value, batch_size, seq_len,
-                            num_blocks) if is_prefix_caching else
+                           (phase.value, batch_size, seq_len, num_blocks)
+                           if is_prefix_caching or is_chunked_prefill else
                           (phase, batch_size, seq_len))

    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
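The net effect of the _check_config change is that, with chunked prefill (as with prefix caching), the warmed-configuration key also includes the number of context blocks, so a shape that differs only in block count is reported as not warmed up. A minimal sketch of that bookkeeping, detached from the runner:

    # Sketch only: seen-config tracking with the block count in the key.
    seen_configs = set()

    def check_config(batch_size, seq_len, num_blocks, phase, warmup_mode):
        cfg = (batch_size, seq_len, num_blocks, phase)
        seen = cfg in seen_configs
        seen_configs.add(cfg)
        if not seen and not warmup_mode:
            print(f"Configuration: {cfg} was not warmed-up!")

    check_config(1, 1024, 8, "mix", warmup_mode=True)    # recorded during warmup
    check_config(1, 1024, 8, "mix", warmup_mode=False)   # already seen, silent
    check_config(1, 1024, 16, "mix", warmup_mode=False)  # new block count -> warns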
