 import torch._inductor.config
 import torch.distributed as dist
 import torch.nn as nn
-#from fms.models import get_model
 from vllm.model_executor.model_loader import get_model
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig
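The thrust of this commit: every constructor now takes the aggregated VllmConfig instead of separate ModelConfig / ParallelConfig / SchedulerConfig arguments, and the sub-configs are read off it as attributes. A minimal sketch of the resulting access pattern, assuming an already-built vllm_config instance (all three attribute chains appear verbatim in the hunks below):

# Illustrative only, not part of the patch. `vllm_config` is assumed to be a
# fully constructed vllm.config.VllmConfig; the sub-configs hang off it.
vocab_size = vllm_config.model_config.hf_config.vocab_size    # model details
world_size = vllm_config.parallel_config.world_size           # TP ranks
max_num_seqs = vllm_config.scheduler_config.max_num_seqs      # batch bound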
@@ -40,17 +39,14 @@ class SpyreCausalLM(nn.Module):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
     ) -> None:
         super().__init__()

         self.logits_processor = LogitsProcessor(
-            model_config.hf_config.vocab_size, logits_as_input=True)
+            vllm_config.model_config.hf_config.vocab_size, logits_as_input=True)
         self.sampler = get_sampler()

         # boolean tensor of length batch size with indices:
@@ -63,15 +59,9 @@ def __init__(

         # FMS Model
         if envs_spyre.VLLM_SPYRE_USE_CB:
-            self.model = ContinuousBatchingFmsModel(model_config,
-                                                    parallel_config,
-                                                    scheduler_config,
-                                                    vllm_config)
+            self.model = ContinuousBatchingFmsModel(vllm_config)
         else:
             self.model = StaticBatchingFmsModel(
-                model_config,
-                parallel_config,
-                scheduler_config,
                 vllm_config,
                 max_prompt_length,
                 max_decode_length,
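For context, a hypothetical call site under the new SpyreCausalLM signature (the length values are made up for illustration):

# Hypothetical usage, assuming `vllm_config` was built by the engine; only the
# aggregated config plus the two Spyre length bounds are passed now.
model = SpyreCausalLM(
    vllm_config,
    max_prompt_length=64,   # illustrative value
    max_decode_length=20,   # illustrative value
)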
@@ -81,8 +71,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        #masks: torch.Tensor,
-        #intermediate_tensors: Optional[IntermediateTensors],
         is_prompt: bool,
         current_tkv_mask: Optional[torch.Tensor] = None,
         left_padded_prompt_mask: Optional[torch.Tensor] = None,
@@ -109,7 +97,6 @@ def forward(
         logits = self.model(
             input_ids,
             positions=positions,
-            #only_last_token=not envs_spyre.VLLM_SPYRE_USE_CB,
             **extra_kwargs,
         )

@@ -147,16 +134,14 @@ class FmsModelBase(nn.Module):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         sendnn_dynamic: bool,
     ) -> None:
         super().__init__()

-        self.config: PretrainedConfig = model_config.hf_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
         self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
             'sendnn' else torch.float32

@@ -165,11 +150,11 @@ def __init__(
         self.vllm_config = vllm_config

         # Load the weights from the cached or downloaded files.
-        self.load_weights(model_config=model_config,
+        self.load_weights(model_config=vllm_config.model_config,
                           max_prompt_length=max_prompt_length,
                           max_decode_length=max_decode_length,
                           distributed_strategy="tp"
-                          if parallel_config.world_size > 1 else None,
+                          if vllm_config.parallel_config.world_size > 1 else None,
                           sendnn_dynamic=sendnn_dynamic)


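The strategy selection above, which the diff splits across a removed/added line pair, collapses to a simple rule; a sketch with a hypothetical rank count:

# Illustrative restatement: tensor parallelism ("tp") is only requested when
# more than one rank is configured; single-rank runs pass None.
world_size = vllm_config.parallel_config.world_size   # e.g. 4 (hypothetical)
distributed_strategy = "tp" if world_size > 1 else None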
@@ -224,7 +209,6 @@ def load_weights(
         # we can use fused weights unless running on Spyre
         fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn"

-        #self.model = get_model(architecture="hf_configured", variant=model_config.model, model_path=model_path, source=model_source, data_type=self.dtype, distributed_strategy=distributed_strategy, group=dist.group.WORLD, fused_weights=fused_weights, linear_config=linear_config)
         self.model = get_model(vllm_config=self.vllm_config)

         self.model.eval()
@@ -273,30 +257,26 @@ class ContinuousBatchingFmsModel(FmsModelBase):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
         vllm_config: VllmConfig,
     ) -> None:

         BLOCK_SIZE = 64
-        max_batch = scheduler_config.max_num_seqs
-        max_model_len = scheduler_config.max_model_len
+        max_batch = vllm_config.scheduler_config.max_num_seqs
+        max_model_len = vllm_config.scheduler_config.max_model_len

         # edge case: prompt fills model length: can produce 1 token with prefill
         max_prompt_length = max_model_len
         # edge case: prompt will be padded to first block:
         # can produce 1 token with prefill plus rest of model length
         max_decode_length = max_model_len - BLOCK_SIZE + 1

-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          sendnn_dynamic=True)

         # physical KV cache on AIU Spyre: will eventually not live in this class
-        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config)

         if self.config.model_type in {'llama', 'granite'}:
             num_layers = self.config.num_hidden_layers
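The two bounds above encode the edge cases named in the comments; a worked example with a hypothetical model length:

# Worked example (max_model_len = 2048 is hypothetical). A prompt may occupy
# the full context and still yield one token from prefill; a prompt padded up
# to a single 64-token block leaves the rest of the context for decode.
BLOCK_SIZE = 64
max_model_len = 2048
max_prompt_length = max_model_len                    # 2048
max_decode_length = max_model_len - BLOCK_SIZE + 1   # 1985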
@@ -330,9 +310,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        mask: torch.Tensor,
         use_cache: bool,
-        only_last_token: bool,
         current_tkv_mask: torch.Tensor,
         left_padded_prompt_mask: torch.Tensor,
         block_table: torch.Tensor,
@@ -343,36 +321,29 @@ def forward(
         output = self.model(
             input_ids,
             positions=positions,
-            mask=mask,
             past_key_value_states=self.past_key_value_states,
             use_cache=use_cache,
-            only_last_token=only_last_token,
             current_tkv_mask=current_tkv_mask,
             left_padded_prompt_mask=left_padded_prompt_mask,
             block_table=block_table,
             slot_mapping=slot_mapping,
             **extra_kwargs,
         )

-        logits, self.past_key_value_states = output
+        self.past_key_value_states = output

-        return logits
+        return output


 class StaticBatchingFmsModel(FmsModelBase):

     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        _: SchedulerConfig,
         vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
     ) -> None:
-        super().__init__(model_config,
-                         parallel_config,
-                         vllm_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          sendnn_dynamic=False)
@@ -385,20 +356,16 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        #mask: torch.Tensor,
         **extra_kwargs,
     ) -> torch.Tensor:

         output = self.model(
             input_ids,
             positions=positions,
-            #mask=mask,
             intermediate_tensors=self.past_key_value_states,
             **extra_kwargs,
         )

-        #logits, self.past_key_value_states = output
         self.past_key_value_states = output
-        #logits = self.model.compute_logits(output)

         return output
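Both forwards now follow the same contract: the raw model output is cached as past_key_value_states and returned directly, instead of unpacking a (logits, kv_states) tuple. A sketch of the caller-side view, assuming the model returned by vLLM's get_model() yields a single output object per call:

# Sketch, not part of the patch; `fms_model` is a hypothetical
# StaticBatchingFmsModel instance. Its forward caches the raw output
# internally and returns the same object, which SpyreCausalLM.forward
# then binds as `logits` and hands to the logits processor.
logits = fms_model(input_ids, positions=positions)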