
Commit 76fd76a

clean up input passing
1 parent fd56eaa commit 76fd76a

File tree: 2 files changed (+12, −50 lines)
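Taken together, the two diffs below replace the separate ModelConfig, ParallelConfig, and SchedulerConfig parameters with the single VllmConfig that already wraps them, and drop stale commented-out code. A minimal sketch of the access pattern the new code relies on (the helper function is hypothetical and only illustrates the attribute paths used in the changed lines):

    from vllm.config import VllmConfig

    def summarize(vllm_config: VllmConfig):  # hypothetical helper, for illustration only
        # Attribute paths as used in the changed code below.
        vocab_size = vllm_config.model_config.hf_config.vocab_size
        world_size = vllm_config.parallel_config.world_size
        max_batch = vllm_config.scheduler_config.max_num_seqs
        max_model_len = vllm_config.scheduler_config.max_model_len
        return vocab_size, world_size, max_batch, max_model_len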

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 12 additions & 45 deletions
@@ -6,7 +6,6 @@
 import torch._inductor.config
 import torch.distributed as dist
 import torch.nn as nn
-#from fms.models import get_model
 from vllm.model_executor.model_loader import get_model
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig
@@ -40,17 +39,14 @@ class SpyreCausalLM(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
     ) -> None:
         super().__init__()
 
         self.logits_processor = LogitsProcessor(
-            model_config.hf_config.vocab_size, logits_as_input=True)
+            vllm_config.model_config.hf_config.vocab_size, logits_as_input=True)
         self.sampler = get_sampler()
 
         # boolean tensor of length batch size with indices:
@@ -63,15 +59,9 @@ def __init__(
 
         # FMS Model
         if envs_spyre.VLLM_SPYRE_USE_CB:
-            self.model = ContinuousBatchingFmsModel(model_config,
-                                                    parallel_config,
-                                                    scheduler_config,
-                                                    vllm_config)
+            self.model = ContinuousBatchingFmsModel(vllm_config)
         else:
             self.model = StaticBatchingFmsModel(
-                model_config,
-                parallel_config,
-                scheduler_config,
                 vllm_config,
                 max_prompt_length,
                 max_decode_length,
@@ -81,8 +71,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        #masks: torch.Tensor,
-        #intermediate_tensors: Optional[IntermediateTensors],
         is_prompt: bool,
         current_tkv_mask: Optional[torch.Tensor] = None,
         left_padded_prompt_mask: Optional[torch.Tensor] = None,
@@ -109,7 +97,6 @@ def forward(
         logits = self.model(
             input_ids,
             positions=positions,
-            #only_last_token=not envs_spyre.VLLM_SPYRE_USE_CB,
             **extra_kwargs,
         )
 
@@ -147,16 +134,14 @@ class FmsModelBase(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         sendnn_dynamic: bool,
     ) -> None:
         super().__init__()
 
-        self.config: PretrainedConfig = model_config.hf_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
         self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
             'sendnn' else torch.float32
 
@@ -165,11 +150,11 @@ def __init__(
         self.vllm_config = vllm_config
 
         # Load the weights from the cached or downloaded files.
-        self.load_weights(model_config=model_config,
+        self.load_weights(model_config=vllm_config.model_config,
                           max_prompt_length=max_prompt_length,
                           max_decode_length=max_decode_length,
                           distributed_strategy="tp"
-                          if parallel_config.world_size > 1 else None,
+                          if vllm_config.parallel_config.world_size > 1 else None,
                           sendnn_dynamic=sendnn_dynamic)
 
 
@@ -224,7 +209,6 @@ def load_weights(
         # we can use fused weights unless running on Spyre
         fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn"
 
-        #self.model = get_model(architecture="hf_configured", variant=model_config.model, model_path=model_path, source=model_source, data_type=self.dtype, distributed_strategy=distributed_strategy, group=dist.group.WORLD, fused_weights=fused_weights, linear_config=linear_config)
         self.model = get_model(vllm_config=self.vllm_config)
 
         self.model.eval()
@@ -273,30 +257,26 @@ class ContinuousBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
         vllm_config: VllmConfig,
     ) -> None:
 
         BLOCK_SIZE = 64
-        max_batch = scheduler_config.max_num_seqs
-        max_model_len = scheduler_config.max_model_len
+        max_batch = vllm_config.scheduler_config.max_num_seqs
+        max_model_len = vllm_config.scheduler_config.max_model_len
 
         # edge case: prompt fills model length: can produce 1 token with prefill
         max_prompt_length = max_model_len
         # edge case: prompt will be padded to first block:
         # can produce 1 token with prefill plus rest of model length
         max_decode_length = max_model_len - BLOCK_SIZE + 1
 
-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          sendnn_dynamic=True)
 
         # physical KV cache on AIU Spyre: will eventually not live in this class
-        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config)
 
         if self.config.model_type in {'llama', 'granite'}:
             num_layers = self.config.num_hidden_layers
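The two edge-case comments above bound the sequence lengths: a prompt may fill the whole model length (prefill still yields one token), while the shortest prompt is padded to a single block, leaving the rest of the model length for decode. A worked example with an assumed model length (the 2048 is illustrative, not taken from the commit):

    BLOCK_SIZE = 64
    max_model_len = 2048                                 # assumed value for illustration
    max_prompt_length = max_model_len                    # prompt can fill the model length
    max_decode_length = max_model_len - BLOCK_SIZE + 1   # 2048 - 64 + 1 = 1985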
@@ -330,9 +310,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        mask: torch.Tensor,
         use_cache: bool,
-        only_last_token: bool,
         current_tkv_mask: torch.Tensor,
         left_padded_prompt_mask: torch.Tensor,
         block_table: torch.Tensor,
@@ -343,36 +321,29 @@ def forward(
         output = self.model(
             input_ids,
             positions=positions,
-            mask=mask,
             past_key_value_states=self.past_key_value_states,
             use_cache=use_cache,
-            only_last_token=only_last_token,
             current_tkv_mask=current_tkv_mask,
             left_padded_prompt_mask=left_padded_prompt_mask,
             block_table=block_table,
             slot_mapping=slot_mapping,
             **extra_kwargs,
         )
 
-        logits, self.past_key_value_states = output
+        self.past_key_value_states = output
 
-        return logits
+        return output
 
 
 class StaticBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        _: SchedulerConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
     ) -> None:
-        super().__init__(model_config,
-                         parallel_config,
-                         vllm_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          sendnn_dynamic=False)
@@ -385,20 +356,16 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        #mask: torch.Tensor,
         **extra_kwargs,
     ) -> torch.Tensor:
 
         output = self.model(
             input_ids,
             positions=positions,
-            #mask=mask,
             intermediate_tensors=self.past_key_value_states,
             **extra_kwargs,
         )
 
-        #logits, self.past_key_value_states = output
         self.past_key_value_states = output
-        #logits = self.model.compute_logits(output)
 
         return output
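Both forward paths above now cache the raw model output in past_key_value_states and return that same object, instead of unpacking a (logits, key-value cache) tuple from the FMS model. A minimal sketch of the new contract, under the assumption that the model obtained through vLLM's get_model returns a single logits-like output:

    # Sketch (assumption): self.model returns one output object, not a
    # (logits, cache) tuple, so forward() can both cache and return it.
    output = self.model(input_ids, positions=positions, **extra_kwargs)
    self.past_key_value_states = output  # reused on the next step
    return output                        # consumed as logits by SpyreCausalLM.forward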

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 0 additions & 5 deletions
@@ -8,8 +8,6 @@
 import torch
 from torch import nn
 from vllm_spyre.v1.attention.backends.spyre import SpyreSDPAMetadata, SpyreSDPABackend
-from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-from vllm.attention.backends.triton_mla import TritonMLABackend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
@@ -122,9 +120,6 @@ def load_model(self, prompt_lens: Iterable[int],
         max_pad_length = max(prompt_lens)
         max_decode_length = max(num_decode_tokens)
         self.model = SpyreCausalLM(
-            self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
             vllm_config=self.vllm_config,
             max_prompt_length=max_pad_length,
             max_decode_length=max_decode_length,
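A short note on this call site: the runner already holds self.vllm_config, so the three removed keyword arguments remain reachable through it. A hedged sketch of the equivalence this relies on (assuming, as is typical for vLLM model runners, that the per-config attributes are derived from the same VllmConfig):

    # Assumption: these runner attributes mirror the wrapped VllmConfig, so
    # dropping them from the SpyreCausalLM call loses no information.
    assert self.model_config is self.vllm_config.model_config
    assert self.parallel_config is self.vllm_config.parallel_config
    assert self.scheduler_config is self.vllm_config.scheduler_config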
