@@ -626,13 +626,13 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
626626 """Last chunked prefill can be scheduled only if there is enough space
627627 in the decode batch, and if all the other spyre-related conditions
628628 are satisfied."""
629-
629+ decoding_requests = [
630+ r for r in self .running if r not in self .ongoing_prefills
631+ ]
630632 max_context_len = self .scheduler_config .max_model_len
631633
632634 # check that there is space in the current decode batch
633- num_running = len (self .running )
634- if request in self .running :
635- num_running -= 1
635+ num_running = len (decoding_requests )
636636 cond1 = num_running + len (self .waiting ) < self .max_num_running_reqs
637637
638638 # calculate new max tkv of the batch given the new sequence joins
@@ -649,7 +649,7 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
649649 # note that the -1 comes from the token we generate during prefill
650650 cond2 = request .max_tokens - 1 <= (max_context_len - new_req_tkv )
651651 # check cond2 for all other sequences in the current decode batch
652- for req in self . running :
652+ for req in decoding_requests :
653653 # current tkv of the (left aligned) decode sequence
654654 dec_req_tkv = n_blocks * self .block_size + \
655655 req .num_computed_tokens % self .block_size
@@ -670,7 +670,7 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
670670 cond3 = lambda : self .check_batch_tkv_limit_cp (request = request ,
671671 new_req_tkv = new_req_tkv ,
672672 n_blocks = n_blocks ,
673- running = self . running ,
673+ running = decoding_requests ,
674674 max_batch_tkv_limit = self .
675675 max_batch_tkv_limit )
676676
0 commit comments