@@ -387,7 +387,10 @@ def _prepare_decode(
387387 input_masks = self ._mask ,
388388 is_prompt = False ,
389389 )
390- self .attn_metadata = self .attn_metadata_builder .build (input )
390+ if self .warmup_mode :
391+ self .attn_metadata = None
392+ else :
393+ self .attn_metadata = self .attn_metadata_builder .build (input )
391394 return input
392395
393396 def _update_position_ids (self ) -> None :
@@ -479,16 +482,12 @@ def execute_model(
479482 self .model .indices = self .input_batch .get_model_indices ()
480483
481484 print ("Attn_metadata:" , self .attn_metadata )
482- if self .warmup_mode :
483- self .attn_metadata = None
484485
485- with set_forward_context (self .attn_metadata , self .vllm_config , 0 ):
486+ with set_forward_context (self .attn_metadata , self .vllm_config , virtual_engine = 0 , ):
486487 # Execute the model
487- hidden_states = self .model (
488+ hidden_or_intermediate_states = self .model (
488489 input_ids = model_input .input_tokens ,
489490 positions = model_input .input_positions ,
490- #masks=model_input.input_masks,
491- #intermediate_tensors=None,
492491 is_prompt = model_input .is_prompt ,
493492 )
494493
@@ -497,7 +496,7 @@ def execute_model(
497496 return []
498497
499498 # Compute the logits.
500- logits = self .model .compute_logits (hidden_states , None )
499+ logits = self .model .compute_logits (hidden_or_intermediate_states , None )
501500
502501 # Sample the next token.
503502 output : SamplerOutput = self .model .sample (
0 commit comments