Skip to content

Commit 8b8f3a0

Browse files
committed
Warmup for attn fix
1 parent 92eeda9 commit 8b8f3a0

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

vllm_spyre/v1/attention/backends/spyre.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def __init__(
169169
raise NotImplementedError("Encoder self-attention and "
170170
"encoder/decoder cross-attention "
171171
"are not implemented for "
172-
"PallasAttentionBackendImpl")
172+
"SpyreSDPABackendImpl")
173173

174174
def forward(
175175
self,

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,10 @@ def _prepare_decode(
387387
input_masks=self._mask,
388388
is_prompt=False,
389389
)
390-
self.attn_metadata = self.attn_metadata_builder.build(input)
390+
if self.warmup_mode:
391+
self.attn_metadata = None
392+
else:
393+
self.attn_metadata = self.attn_metadata_builder.build(input)
391394
return input
392395

393396
def _update_position_ids(self) -> None:
@@ -479,16 +482,12 @@ def execute_model(
479482
self.model.indices = self.input_batch.get_model_indices()
480483

481484
print("Attn_metadata:", self.attn_metadata)
482-
if self.warmup_mode:
483-
self.attn_metadata = None
484485

485-
with set_forward_context(self.attn_metadata, self.vllm_config, 0):
486+
with set_forward_context(self.attn_metadata, self.vllm_config, virtual_engine=0,):
486487
# Execute the model
487-
hidden_states = self.model(
488+
hidden_or_intermediate_states = self.model(
488489
input_ids=model_input.input_tokens,
489490
positions=model_input.input_positions,
490-
#masks=model_input.input_masks,
491-
#intermediate_tensors=None,
492491
is_prompt=model_input.is_prompt,
493492
)
494493

@@ -497,7 +496,7 @@ def execute_model(
497496
return []
498497

499498
# Compute the logits.
500-
logits = self.model.compute_logits(hidden_states, None)
499+
logits = self.model.compute_logits(hidden_or_intermediate_states, None)
501500

502501
# Sample the next token.
503502
output: SamplerOutput = self.model.sample(

0 commit comments

Comments (0)