Skip to content

Commit a82191a

Browse files
authored
[CB] refactor left padding removal (#211)
### [CB] refactor left padding removal - calls the function `reduce_left_padding` at every step (prefill and decode) - removes the dependency on cached requests - adjusts /tests for CB covering that exact case --------- Signed-off-by: Yannick Schnider <[email protected]>
1 parent 0d42959 commit a82191a

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

tests/e2e/test_spyre_cb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ def get_params_test_remove_left_padding():
701701
{
702702
# Prefill sequence 2
703703
"step": 42,
704-
"tkv": 103, # TODO expecting 39 for next implementation
704+
"tkv": 39, # left padding reduction: 103 - 64 (block size)
705705
"waiting": [],
706706
"running": ["2", "1"],
707707
"request_outputs": ["2"]

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def __init__(
581581
# TO DO: move to InputBatch
582582
self.req_ids2blocks: dict[str, deque[int]] = {}
583583
self.req_ids2left_pads: dict[str, int] = {}
584-
self.tkv = 0
584+
self.tkv: int = 0
585585
self.free_blocks = deque([i for i in range(NUM_BLOCKS)])
586586
self.dummy_req_ids2blocks: list[int] = []
587587

@@ -739,9 +739,6 @@ def _prepare_decode(
739739
dtype=torch.bool,
740740
device="cpu")
741741

742-
if envs_spyre.VLLM_SPYRE_RM_PADDED_BLOCKS:
743-
self.reduce_left_padding(cached_requests)
744-
745742
for cached_request in cached_requests:
746743
# TODO: Will this always just be one token ID if there's no spec
747744
# or jump decoding?
@@ -818,27 +815,30 @@ def _prepare_decode(
818815
is_prompt=False,
819816
)
820817

821-
def reduce_left_padding(self, requests: list[CachedRequestData]) -> None:
818+
def reduce_left_padding(self) -> None:
819+
820+
if len(self.req_ids2left_pads) == 0:
821+
return
822822

823-
min_left_pad = min(
824-
[self.req_ids2left_pads[r.req_id] for r in requests])
823+
min_left_pad = min(self.req_ids2left_pads.values())
825824
n_padded_blocks = min_left_pad // self.BLOCK_SIZE
825+
offset = n_padded_blocks * self.BLOCK_SIZE
826826

827-
if n_padded_blocks > 0:
827+
if offset > 0:
828828
logger.debug("Number of removed blocks due to left padding: %d",
829829
n_padded_blocks)
830830

831-
for req in requests:
832-
self.req_ids2left_pads[
833-
req.req_id] -= n_padded_blocks * self.BLOCK_SIZE
831+
for req_id in self.req_ids2left_pads:
832+
self.req_ids2left_pads[req_id] -= offset
834833

835834
# free blocks
836835
for _ in range(n_padded_blocks):
837-
freed_block_id = self.req_ids2blocks[req.req_id].popleft()
836+
freed_block_id = self.req_ids2blocks[req_id].popleft()
837+
logger.debug("Freeing block with id: %s", freed_block_id)
838838
self.free_blocks.append(freed_block_id)
839839

840840
# update tkv
841-
self.tkv -= n_padded_blocks * self.BLOCK_SIZE
841+
self.tkv -= offset
842842

843843
return
844844

@@ -905,6 +905,10 @@ def pad_input_ids(
905905
def prepare_model_input(
906906
self, scheduler_output: SchedulerOutput) -> ModelForwardInputs:
907907

908+
# remove left padding if applicable before next prefil/decode step
909+
if envs_spyre.VLLM_SPYRE_RM_PADDED_BLOCKS:
910+
self.reduce_left_padding()
911+
908912
# NOTE: We assume that all sequences in the group are all prompts or
909913
# all decodes.
910914
# Also assuming that new sequences are prefills

0 commit comments

Comments
 (0)