@@ -626,13 +626,13 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
626626 """Last chunked prefill can be scheduled only if there is enough space
627627 in the decode batch, and if all the other spyre-related conditions
628628 are satisfied."""
629-
629+ decoding_requests = [
630+ r for r in self .running if r not in self .ongoing_prefills
631+ ]
630632 max_context_len = self .scheduler_config .max_model_len
631633
632634 # check that there is space in the current decode batch
633- num_running = len (self .running )
634- if request in self .running :
635- num_running -= 1
635+ num_running = len (decoding_requests )
636636 cond1 = num_running + len (self .waiting ) < self .max_num_running_reqs
637637
638638 # calculate new max tkv of the batch given the new sequence joins
@@ -649,7 +649,7 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
649649 # note that the -1 comes from the token we generate during prefill
650650 cond2 = request .max_tokens - 1 <= (max_context_len - new_req_tkv )
651651 # check cond2 for all other sequences in the current decode batch
652- for req in self . running :
652+ for req in decoding_requests :
653653 # current tkv of the (left aligned) decode sequence
654654 dec_req_tkv = n_blocks * self .block_size + \
655655 req .num_computed_tokens % self .block_size
@@ -667,12 +667,12 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
667667 # check that batch size x tkv is smaller than the max supported number
668668 # Note: using max_tkv is a conservative upper bound here. For the
669669 # optimal check we need model runner to return per sequence tkvs
670- cond3 = lambda : self .check_batch_tkv_limit_cp (request = request ,
671- new_req_tkv = new_req_tkv ,
672- n_blocks = n_blocks ,
673- running = self . running ,
674- max_batch_tkv_limit = self .
675- max_batch_tkv_limit )
670+ cond3 = lambda : self .check_batch_tkv_limit_cp (
671+ request = request ,
672+ new_req_tkv = new_req_tkv ,
673+ n_blocks = n_blocks ,
674+ running = decoding_requests ,
675+ max_batch_tkv_limit = self . max_batch_tkv_limit )
676676
677677 return cond1 and cond2 and cond3 ()
678678
0 commit comments