Skip to content

Commit 5f1e4d5

Browse files
committed
fix: check only decoding requests in _satisfies_last_chunk_constraints
Signed-off-by: Travis Johnson <[email protected]>
1 parent 40e0639 commit 5f1e4d5

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

vllm_spyre/v1/core/scheduler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -626,13 +626,13 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
626626
"""Last chunked prefill can be scheduled only if there is enough space
627627
in the decode batch, and if all the other spyre-related conditions
628628
are satisfied."""
629-
629+
decoding_requests = [
630+
r for r in self.running if r not in self.ongoing_prefills
631+
]
630632
max_context_len = self.scheduler_config.max_model_len
631633

632634
# check that there is space in the current decode batch
633-
num_running = len(self.running)
634-
if request in self.running:
635-
num_running -= 1
635+
num_running = len(decoding_requests)
636636
cond1 = num_running + len(self.waiting) < self.max_num_running_reqs
637637

638638
# calculate new max tkv of the batch given the new sequence joins
@@ -649,7 +649,7 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
649649
# note that the -1 comes from the token we generate during prefill
650650
cond2 = request.max_tokens - 1 <= (max_context_len - new_req_tkv)
651651
# check cond2 for all other sequences in the current decode batch
652-
for req in self.running:
652+
for req in decoding_requests:
653653
# current tkv of the (left aligned) decode sequence
654654
dec_req_tkv = n_blocks * self.block_size + \
655655
req.num_computed_tokens % self.block_size
@@ -667,12 +667,12 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
667667
# check that batch size x tkv is smaller than the max supported number
668668
# Note: using max_tkv is a conservative upper bound here. For the
669669
# optimal check we need model runner to return per sequence tkvs
670-
cond3 = lambda: self.check_batch_tkv_limit_cp(request=request,
671-
new_req_tkv=new_req_tkv,
672-
n_blocks=n_blocks,
673-
running=self.running,
674-
max_batch_tkv_limit=self.
675-
max_batch_tkv_limit)
670+
cond3 = lambda: self.check_batch_tkv_limit_cp(
671+
request=request,
672+
new_req_tkv=new_req_tkv,
673+
n_blocks=n_blocks,
674+
running=decoding_requests,
675+
max_batch_tkv_limit=self.max_batch_tkv_limit)
676676

677677
return cond1 and cond2 and cond3()
678678

0 commit comments

Comments
 (0)