
Commit a036e97

Authored by ochougul, vbaddi (Vinayak Baddi), mamtsing, and quic-mamta
Prefill+decode gpt oss (#608)
# We should be using disaggregated serving for the GPT-OSS model for best performance
- GPT-OSS has a total_experts/experts_per_tok ratio of 128/4 for the 120B model and 32/4 for the 20B model
- We use the "read all experts only once" strategy in the prefill-only model
- We treat weights as activations, meaning we read only the chosen experts, in the decode-only model

# Prefill-only model

## Blocking
Default behaviour when `prefill_only=True` is passed in the compile API:
- NUM_Q_BLOCKS=<int> sets the number of Q blocks in attention
- NUM_FFN_BLOCKS=<int> sets the number of blocks in the FFN
- ENABLE_OPT_SWA=0 or 1 disables/enables optimized SWA; when enabled, only the valid KVs for a given block are used in attention, reducing MACs
- prefix_caching is not supported in this mode

## Chunking
Pass `enable_chunking=True` and `prefill_only=True` in the compile API:
- Optimized SWA, i.e. reading only the valid KV per the diagonal attention mask, is enabled by default in this version
- This model can be used for prefix_caching by passing `kv_cache_batch_size=<int>` in the compile API

# Decode-only model

## Retain sliding-window length of KV for sliding-window layers
Default behaviour when `prefill_seq_len=1` is passed in the compile API:
- This reduces the amount of DDR used by the model
- CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=<int>`, and optionally `kv_cache_batch_size=<int>` if needed

## Full KV for sliding-window layers
Pass `retain_full_kv=True` along with `prefill_seq_len=1` in the compile API:
- This uses more DDR, since we retain ctx_len KV even for sliding-window layers, but only sliding-window-length KV is read in attention
- CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=<int>`, and optionally `kv_cache_batch_size=<int>` if needed
- This is enabled for the multi-turn chat use case, where we run prefill -> decode and then use the combined prefill and decode cache to run prefill again, so we want to retain the full KV for sliding-window layers

NOTE:
* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it there
* The 120B model needs NPI; two versions of the NPI, one with and one without subfunctions, are uploaded here; pass it as `node_precision_info=<path to file>`
* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise compilation times are too high. With this flag the model is expected to export and then fail during compile, as it needs the assert SDK; the user should run the compilation manually by pasting the command printed in the error

---------

Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>
Signed-off-by: Mamta Singh <[email protected]>
Co-authored-by: Vinayak Baddi <[email protected]>
Co-authored-by: Mamta Singh <[email protected]>
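The two serving flows described above can be driven from the Python API. Below is a minimal, hedged sketch assuming the `QEFFAutoModelForCausalLM` entry point; the model card, sequence lengths, core counts, batch sizes, and block counts are illustrative placeholders, and the blocking environment variables are the ones documented in this commit message.

```python
import os

from QEfficient import QEFFAutoModelForCausalLM

# --- Prefill-only model (blocking variant) ---
# Blocking knobs from the description above; values are placeholders.
os.environ["NUM_Q_BLOCKS"] = "4"    # number of Q blocks in attention
os.environ["NUM_FFN_BLOCKS"] = "4"  # number of blocks in the FFN
os.environ["ENABLE_OPT_SWA"] = "1"  # read only valid KVs per block in attention

prefill_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
prefill_model.compile(
    prefill_only=True,
    prefill_seq_len=2048,
    ctx_len=4096,
    num_cores=16,
    use_onnx_subfunctions=True,  # advised for prefill-only; see NOTE above
)

# --- Decode-only model (retained sliding-window KV) with continuous batching ---
decode_model = QEFFAutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b", continuous_batching=True
)
decode_model.compile(
    prefill_seq_len=1,    # decode-only specialization
    ctx_len=4096,
    num_cores=16,
    full_batch_size=4,    # strictly required with continuous batching
    retain_full_kv=True,  # optional: keep ctx_len KV for multi-turn chat
)
```

For the chunked prefill variant, pass `enable_chunking=True` alongside `prefill_only=True`, plus `kv_cache_batch_size=<int>` if prefix caching is needed.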
1 parent 1b2fabe commit a036e97

26 files changed: +1805 −174 lines

QEfficient/__init__.py

Lines changed: 15 additions & 8 deletions
@@ -6,7 +6,17 @@
 # -----------------------------------------------------------------------------
 
 import os
-import warnings
+
+# ----------------------------------------------------------------------------- #
+# For faster downloads via hf_transfer
+# This code is put above import statements as this needs to be executed before
+# hf_transfer is imported (will happen on line 15 via leading imports)
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# DO NOT ADD ANY CODE ABOVE THIS LINE
+# Please contact maintainers if you must edit this file above this line.
+# ----------------------------------------------------------------------------- #
+# Placeholder for all non-transformer models registered in QEfficient
+import warnings  # noqa: I001
 
 import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.base import (
@@ -26,6 +36,10 @@
 from QEfficient.utils import custom_format_warning
 from QEfficient.utils.logging_utils import logger
 
+# custom warning for the better logging experience
+warnings.formatwarning = custom_format_warning
+
+
 # Users can use QEfficient.export for exporting models to ONNX
 export = qualcomm_efficient_converter
 __all__ = [
@@ -42,14 +56,7 @@
     "QEFFCommonLoader",
     "QEffFluxPipeline",
 ]
-# For faster downloads via hf_transfer
-# This code is put above import statements as this needs to be executed before
-# hf_transfer is imported (will happen on line 15 via leading imports)
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-# Placeholder for all non-transformer models registered in QEfficient
 
-# custom warning for the better logging experience
-warnings.formatwarning = custom_format_warning
 
 # Conditionally import QAIC-related modules if the SDK is installed
 __version__ = "0.0.1.dev0"

QEfficient/base/modeling_qeff.py

Lines changed: 70 additions & 12 deletions
@@ -60,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
         super().__init__()
         self.model = model
         self.hash_params = create_model_params(self, **kwargs)
+        self.prefill_onnx_path: Optional[str] = None
         self.onnx_path: Optional[str] = None
         self.qpc_path: Optional[str] = None
         self.qpc_session: Optional[QAICInferenceSession] = None
@@ -204,10 +205,11 @@ def _export(
         example_inputs: Dict[str, torch.Tensor],
         output_names: List[str],
         dynamic_axes: Dict[str, Dict[int, str]],
-        export_kwargs: Optional[Dict[str, any]] = None,
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
+        prefill_only: Optional[bool] = False,
+        **export_kwargs,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -232,11 +234,16 @@
             instance using from_pretrained() for re-export.
 
         """
+        # TODO: Hack for retain_full_kv, handle this outside
+        export_kwargs.pop("retain_full_kv", None)
         onnx_path = export_dir / f"{self.model_name}.onnx"
 
         # Return early if ONNX already exists
         if onnx_path.is_file():
-            self.onnx_path = onnx_path
+            if prefill_only:
+                self.prefill_onnx_path = onnx_path
+            else:
+                self.onnx_path = onnx_path
             return onnx_path
 
         # check if the model is in meta state or weights are offloaded
@@ -272,9 +279,6 @@
                 input_names.append(param)
 
         try:
-            # Export to ONNX
-            export_kwargs = {} if export_kwargs is None else export_kwargs
-
             torch.onnx.export(
                 self.model,
                 (example_inputs,),
@@ -318,9 +322,42 @@
         finally:
             shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
 
-        self.onnx_path = onnx_path
+        if prefill_only:
+            self.prefill_onnx_path = onnx_path
+        else:
+            self.onnx_path = onnx_path
         return onnx_path
 
+    def get_onnx_path(
+        self,
+        prefill_only: Optional[bool] = False,
+        enable_chunking: Optional[bool] = False,
+        specializations: Optional[List[Dict[str, int]]] = None,
+        offload_pt_weights: Optional[bool] = True,
+        use_onnx_subfunctions: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = False,
+    ):
+        kwargs = {
+            "offload_pt_weights": offload_pt_weights,
+            "use_onnx_subfunctions": use_onnx_subfunctions,
+            "retain_full_kv": retain_full_kv,
+        }
+        if prefill_only:
+            if self.prefill_onnx_path is None:
+                kwargs.update(
+                    {
+                        "prefill_only": prefill_only,
+                        "prefill_seq_len": specializations[0].get("seq_len"),
+                        "enable_chunking": enable_chunking,
+                    }
+                )
+                self.export(**kwargs)
+            return self.prefill_onnx_path
+        else:
+            if self.onnx_path is None:
+                self.export(**kwargs)
+            return self.onnx_path
+
     @dump_qconfig
     def _compile(
         self,
@@ -335,6 +372,10 @@
         enable_qnn: Optional[bool] = False,
         qnn_config: Optional[str] = None,
         use_onnx_subfunctions: bool = False,
+        prefill_only: Optional[str] = None,
+        offload_pt_weights: Optional[bool] = True,
+        enable_chunking: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -360,11 +401,18 @@
 
         For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
         """
-
-        if onnx_path is None and self.onnx_path is None:
-            self.export(use_onnx_subfunctions=use_onnx_subfunctions)
-
-        onnx_path = Path(onnx_path or self.onnx_path)
+        onnx_path = Path(
+            onnx_path
+            if onnx_path
+            else self.get_onnx_path(
+                prefill_only,
+                enable_chunking,
+                specializations,
+                offload_pt_weights,
+                use_onnx_subfunctions,
+                retain_full_kv,
+            )
+        )
         compile_dir = Path(compile_dir or onnx_path.parent)
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
@@ -426,6 +474,7 @@
             "mdp_ts_num_devices": mdp_ts_num_devices,
             "mdp_ts_json": mdp_ts_json,
             "num_speculative_tokens": num_speculative_tokens,
+            "prefill_only": prefill_only,
         }
         compile_hash = hash_dict_params(compile_hash_params)
 
@@ -465,6 +514,16 @@
 
         command.append(f"-aic-binary-dir={qpc_path}")
         logger.info(f"Running compiler: {' '.join(command)}")
+        if use_onnx_subfunctions:
+
+            class FeatureNotAvailableError(Exception):
+                pass
+
+            exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
+            raise FeatureNotAvailableError(
+                "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+                + f"\nRun following command manually with assert compiler:\n{exec_command}"
+            )
         try:
             subprocess.run(command, capture_output=True, check=True)
         except subprocess.CalledProcessError as e:
@@ -485,5 +544,4 @@ def _compile(
         logger.info("Hashed parameters exported successfully.")
 
         self.qpc_path = qpc_path
-
         return qpc_path
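Given the `FeatureNotAvailableError` raised above when `use_onnx_subfunctions=True`, the manual-compile workflow from the commit message looks roughly like the sketch below. This is an assumed usage pattern, not repository code; the model card and compile arguments are placeholders, and the exception is matched by name because the class is defined locally inside `_compile`.

```python
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # placeholder card
try:
    model.compile(
        prefill_only=True,
        prefill_seq_len=2048,
        ctx_len=4096,
        use_onnx_subfunctions=True,
    )
except Exception as exc:
    # FeatureNotAvailableError is defined locally inside _compile, so match by name.
    if type(exc).__name__ != "FeatureNotAvailableError":
        raise
    # The last line of the message is the exact command (including the
    # QAIC_COMPILER_OPTS_UNSUPPORTED prefix) to paste into a shell that has
    # the assert version of the apps SDK.
    print(str(exc).splitlines()[-1])
```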

QEfficient/base/onnx_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -95,12 +95,12 @@ class CustomOpTransform(BaseOnnxTransform):
         "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
         "CtxGatherFunc": (CtxGatherFunc, CtxGather),
         "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
-        "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
         "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
-        "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
         "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
         "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
         "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
+        "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
+        "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
     }
 
     @classmethod

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 1 addition & 0 deletions
@@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function):
     def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
         return data[batch_indices, head_indices, ctx_indices]
 
     @staticmethod

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
     def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
         return data[batch_indices, head_indices, ctx_indices]
 
     @staticmethod
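Both gather guards above follow the same pattern: indices that mark padded or invalid cache slots (the INT32_MAX sentinel in the non-CB op, out-of-range values in the CB op) are remapped to index 0 before fancy indexing, so the gather never reads outside the KV buffer. A small self-contained PyTorch demo of the non-CB variant, with toy shapes that are not repository values:

```python
import torch

# (batch, heads, ctx, head_dim) toy KV buffer
data = torch.arange(2 * 1 * 4 * 3, dtype=torch.float32).view(2, 1, 4, 3)
INVALID = torch.iinfo(torch.int32).max
# (batch, heads, comp_ctx) gather indices; INVALID marks padded slots
ctx_indices = torch.tensor([[[0, 1, INVALID]],
                            [[2, 3, INVALID]]])

batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
# Same guard as CtxGatherFunc.forward: sentinel -> index 0, a harmless in-bounds read
ctx_indices = torch.where(ctx_indices == INVALID, 0, ctx_indices)
out = data[batch_indices, head_indices, ctx_indices]
print(out.shape)  # torch.Size([2, 1, 3, 3])
```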

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 8 additions & 8 deletions
@@ -102,7 +102,7 @@ def export(
         output_names: List[str],
         dynamic_axes: Dict,
         export_dir: str = None,
-        export_kwargs: Dict = None,
+        export_kwargs: Dict = {},
     ) -> str:
         """
         Export the text encoder model to ONNX format.
@@ -122,7 +122,7 @@
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             export_dir=export_dir,
-            export_kwargs=export_kwargs,
+            **export_kwargs,
         )
 
     def compile(self, specializations: List[Dict], **compiler_options) -> None:
@@ -179,7 +179,7 @@
         output_names: List[str],
         dynamic_axes: Dict,
         export_dir: str = None,
-        export_kwargs: Dict = None,
+        export_kwargs: Dict = {},
     ) -> str:
         """
         Export the UNet model to ONNX format.
@@ -199,7 +199,7 @@
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             export_dir=export_dir,
-            export_kwargs=export_kwargs,
+            **export_kwargs,
         )
 
     def compile(self, specializations: List[Dict], **compiler_options) -> None:
@@ -292,7 +292,7 @@
         output_names: List[str],
         dynamic_axes: Dict,
         export_dir: str = None,
-        export_kwargs: Dict = None,
+        export_kwargs: Dict = {},
     ) -> str:
         """
         Export the VAE model to ONNX format.
@@ -312,7 +312,7 @@
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             export_dir=export_dir,
-            export_kwargs=export_kwargs,
+            **export_kwargs,
         )
 
     def compile(self, specializations: List[Dict], **compiler_options) -> None:
@@ -438,7 +438,7 @@
         output_names: List[str],
         dynamic_axes: Dict,
         export_dir: str = None,
-        export_kwargs: Dict = None,
+        export_kwargs: Dict = {},
         use_onnx_subfunctions: bool = False,
     ) -> str:
         """
@@ -466,8 +466,8 @@
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             export_dir=export_dir,
-            export_kwargs=export_kwargs,
             offload_pt_weights=False,  # As weights are needed with AdaLN changes
+            **export_kwargs,
         )
 
     def compile(self, specializations: List[Dict], **compiler_options) -> None:

QEfficient/peft/auto.py

Lines changed: 3 additions & 3 deletions
@@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
         return obj
 
-    def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+    def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
         """
         Export the model with the active adapter to ONNX format.
 
@@ -291,10 +291,10 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
             example_inputs,
             output_names,
             dynamic_axes,
-            export_kwargs={"do_constant_folding": False},  # To avoid merging adapter weights with base weights
+            do_constant_folding=False,  # To avoid merging adapter weights with base weights
             onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
             export_dir=export_dir,
-            use_onnx_subfunctions=use_onnx_subfunctions,
+            **kwargs,
         )
 
     def compile(

QEfficient/peft/lora/auto.py

Lines changed: 2 additions & 2 deletions
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
         # load_weight to model
         self._load_adapter_weights_to_model()
 
-    def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+    def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
         """
         Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.
 
@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
             output_names,
             dynamic_axes,
             export_dir=export_dir,
-            use_onnx_subfunctions=use_onnx_subfunctions,
+            **kwargs,
         )
 
     def generate(
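With `export(**kwargs)` forwarding in both PEFT classes above, callers pass `torch.onnx.export` options and QEfficient export flags directly instead of packing them into an `export_kwargs` dict. A hedged usage sketch; the class name is the one these diffs touch, and the adapter card is a placeholder:

```python
from QEfficient.peft import QEffAutoPeftModelForCausalLM

# Placeholder adapter card; any PEFT adapter id works here.
model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder")
# Extra keywords now flow straight through to _export; do_constant_folding
# is already forced to False internally to keep adapter weights separate.
model.export(export_dir="./onnx_cache", use_onnx_subfunctions=False)
```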
