@@ -581,7 +581,7 @@ def __init__(
581581 # TO DO: move to InputBatch
582582 self .req_ids2blocks : dict [str , deque [int ]] = {}
583583 self .req_ids2left_pads : dict [str , int ] = {}
584- self .tkv = 0
584+ self .tkv : int = 0
585585 self .free_blocks = deque ([i for i in range (NUM_BLOCKS )])
586586 self .dummy_req_ids2blocks : list [int ] = []
587587
@@ -739,9 +739,6 @@ def _prepare_decode(
739739 dtype = torch .bool ,
740740 device = "cpu" )
741741
742- if envs_spyre .VLLM_SPYRE_RM_PADDED_BLOCKS :
743- self .reduce_left_padding (cached_requests )
744-
745742 for cached_request in cached_requests :
746743 # TODO: Will this always just be one token ID if there's no spec
747744 # or jump decoding?
@@ -818,27 +815,30 @@ def _prepare_decode(
818815 is_prompt = False ,
819816 )
820817
821- def reduce_left_padding (self , requests : list [CachedRequestData ]) -> None :
818+ def reduce_left_padding (self ) -> None :
819+
820+ if len (self .req_ids2left_pads ) == 0 :
821+ return
822822
823- min_left_pad = min (
824- [self .req_ids2left_pads [r .req_id ] for r in requests ])
823+ min_left_pad = min (self .req_ids2left_pads .values ())
825824 n_padded_blocks = min_left_pad // self .BLOCK_SIZE
825+ offset = n_padded_blocks * self .BLOCK_SIZE
826826
827- if n_padded_blocks > 0 :
827+ if offset > 0 :
828828 logger .debug ("Number of removed blocks due to left padding: %d" ,
829829 n_padded_blocks )
830830
831- for req in requests :
832- self .req_ids2left_pads [
833- req .req_id ] -= n_padded_blocks * self .BLOCK_SIZE
831+ for req_id in self .req_ids2left_pads :
832+ self .req_ids2left_pads [req_id ] -= offset
834833
835834 # free blocks
836835 for _ in range (n_padded_blocks ):
837- freed_block_id = self .req_ids2blocks [req .req_id ].popleft ()
836+ freed_block_id = self .req_ids2blocks [req_id ].popleft ()
837+ logger .debug ("Freeing block with id: %s" , freed_block_id )
838838 self .free_blocks .append (freed_block_id )
839839
840840 # update tkv
841- self .tkv -= n_padded_blocks * self . BLOCK_SIZE
841+ self .tkv -= offset
842842
843843 return
844844
@@ -905,6 +905,10 @@ def pad_input_ids(
905905 def prepare_model_input (
906906 self , scheduler_output : SchedulerOutput ) -> ModelForwardInputs :
907907
908+ # remove left padding if applicable before next prefil/decode step
909+ if envs_spyre .VLLM_SPYRE_RM_PADDED_BLOCKS :
910+ self .reduce_left_padding ()
911+
908912 # NOTE: We assume that all sequences in the group are all prompts or
909913 # all decodes.
910914 # Also assuming that new sequences are prefills
0 commit comments