|
8 | 8 | import gc |
9 | 9 | import inspect |
10 | 10 | import logging |
11 | | -import re |
12 | 11 | import shutil |
13 | 12 | import subprocess |
14 | 13 | import warnings |
|
21 | 20 |
|
22 | 21 | from QEfficient.base.onnx_transforms import ( |
23 | 22 | BaseOnnxTransform, |
24 | | - CustomOpTransform, |
25 | 23 | OnnxTransformPipeline, |
26 | | - RenameFunctionOutputsTransform, |
27 | 24 | ) |
28 | 25 | from QEfficient.base.pytorch_transforms import PytorchTransform |
29 | 26 | from QEfficient.compile.qnn_compiler import compile as qnn_compile |
30 | 27 | from QEfficient.generation.cloud_infer import QAICInferenceSession |
31 | | -from QEfficient.transformers.cache_utils import InvalidIndexProvider |
32 | | -from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export |
33 | 28 | from QEfficient.utils import ( |
34 | 29 | constants, |
35 | 30 | create_json, |
36 | 31 | create_model_params, |
37 | 32 | dump_qconfig, |
38 | | - export_wrapper, |
39 | 33 | generate_mdp_partition_config, |
40 | 34 | hash_dict_params, |
41 | 35 | load_json, |
42 | 36 | ) |
43 | | -from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches |
| 37 | +from QEfficient.utils.export_utils import export_wrapper |
44 | 38 |
|
45 | 39 | logger = logging.getLogger(__name__) |
46 | 40 |
|
@@ -127,9 +121,35 @@ def _model_offloaded_check(self) -> None: |
127 | 121 | logger.error(error_msg) |
128 | 122 | raise RuntimeError(error_msg) |
129 | 123 |
|
| 124 | + @property |
| 125 | + def model_name(self) -> str: |
| 126 | + """ |
| 127 | + Get the model class name without QEff/QEFF prefix. |
| 128 | +
|
| 129 | + This property extracts the underlying model's class name and removes |
| 130 | + any QEff or QEFF prefix that may have been added during wrapping. |
| 131 | +
|
| 132 | + Returns: |
| 133 | + str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel") |
| 134 | + """ |
| 135 | + mname = self.model.__class__.__name__ |
| 136 | + if mname.startswith("QEff") or mname.startswith("QEFF"): |
| 137 | + mname = mname[4:] |
| 138 | + return mname |
| 139 | + |
130 | 140 | @property |
131 | 141 | @abstractmethod |
132 | | - def model_name(self) -> str: ... |
| 142 | + def get_model_config(self) -> Dict: |
| 143 | + """ |
| 144 | + Get the model configuration as a dictionary. |
| 145 | +
|
| 146 | + This is an abstract property that must be implemented by all subclasses. |
| 147 | + Typically returns: self.model.config.__dict__ |
| 148 | +
|
| 149 | + Returns: |
| 150 | + Dict: The configuration dictionary of the underlying model |
| 151 | + """ |
| 152 | + pass |
133 | 153 |
|
134 | 154 | @abstractmethod |
135 | 155 | def export(self, export_dir: Optional[str] = None) -> Path: |
@@ -259,18 +279,8 @@ def _export( |
259 | 279 | input_names.append(param) |
260 | 280 |
|
261 | 281 | try: |
262 | | - # Initialize the registry with your custom ops |
| 282 | + # Export to ONNX |
263 | 283 | export_kwargs = {} if export_kwargs is None else export_kwargs |
264 | | - if use_onnx_subfunctions: |
265 | | - warnings.warn( |
266 | | - "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." |
267 | | - ) |
268 | | - apply_torch_patches() |
269 | | - InvalidIndexProvider.SUBFUNC_ENABLED = True |
270 | | - output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] |
271 | | - export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) |
272 | | - self._onnx_transforms.append(RenameFunctionOutputsTransform) |
273 | | - self._onnx_transforms.append(CustomOpTransform) |
274 | 284 |
|
275 | 285 | torch.onnx.export( |
276 | 286 | self.model, |
|
0 commit comments