 import gc
 import inspect
 import logging
-import re
 import shutil
 import subprocess
 import warnings
⋮
 
 from QEfficient.base.onnx_transforms import (
     BaseOnnxTransform,
-    CustomOpTransform,
     OnnxTransformPipeline,
-    RenameFunctionOutputsTransform,
 )
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.transformers.cache_utils import InvalidIndexProvider
-from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
 from QEfficient.utils import (
     constants,
     create_json,
     create_model_params,
     dump_qconfig,
-    export_wrapper,
     generate_mdp_partition_config,
     hash_dict_params,
     load_json,
 )
-from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+from QEfficient.utils.export_utils import export_wrapper
 
 logger = logging.getLogger(__name__)
 
@@ -125,9 +119,35 @@ def _model_offloaded_check(self) -> None:
             logger.error(error_msg)
             raise RuntimeError(error_msg)
 
+    @property
+    def model_name(self) -> str:
+        """
+        Get the model class name without QEff/QEFF prefix.
+
+        This property extracts the underlying model's class name and removes
+        any QEff or QEFF prefix that may have been added during wrapping.
+
+        Returns:
+            str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+        """
+        mname = self.model.__class__.__name__
+        if mname.startswith("QEff") or mname.startswith("QEFF"):
+            mname = mname[4:]
+        return mname
+
     @property
     @abstractmethod
-    def model_name(self) -> str: ...
+    def get_model_config(self) -> Dict:
+        """
+        Get the model configuration as a dictionary.
+
+        This is an abstract property that must be implemented by all subclasses.
+        Typically returns: self.model.config.__dict__
+
+        Returns:
+            Dict: The configuration dictionary of the underlying model
+        """
+        pass
 
     @abstractmethod
     def export(self, export_dir: Optional[str] = None) -> Path:
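A quick, self-contained sketch of the behavior documented above (the stand-in classes below are invented for illustration and are not QEfficient code): model_name strips the QEff/QEFF prefix from the wrapped model's class name, and get_model_config is typically satisfied by returning self.model.config.__dict__.

# Illustrative stand-ins only; not part of this commit.
from typing import Dict


class DummyConfig:
    def __init__(self):
        self.hidden_size = 512
        self.num_hidden_layers = 2


class QEffCLIPTextModel:
    """Stand-in for an underlying model class that gained a QEff prefix during wrapping."""

    def __init__(self):
        self.config = DummyConfig()


class ToyQEffWrapper:
    """Toy class mimicking the two base-class properties added in this commit."""

    def __init__(self, model):
        self.model = model

    @property
    def model_name(self) -> str:
        # Same prefix-stripping logic as the new model_name property above.
        mname = self.model.__class__.__name__
        if mname.startswith("QEff") or mname.startswith("QEFF"):
            mname = mname[4:]
        return mname

    @property
    def get_model_config(self) -> Dict:
        # The docstring above suggests subclasses typically return self.model.config.__dict__.
        return self.model.config.__dict__


wrapper = ToyQEffWrapper(QEffCLIPTextModel())
print(wrapper.model_name)        # CLIPTextModel
print(wrapper.get_model_config)  # {'hidden_size': 512, 'num_hidden_layers': 2}
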
@@ -188,7 +208,6 @@ def _export(
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
-        use_onnx_subfunctions: bool = False,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -253,18 +272,8 @@ def _export(
                 input_names.append(param)
 
         try:
-            # Initialize the registry with your custom ops
+            # Export to ONNX
            export_kwargs = {} if export_kwargs is None else export_kwargs
-            if use_onnx_subfunctions:
-                warnings.warn(
-                    "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
-                )
-                apply_torch_patches()
-                InvalidIndexProvider.SUBFUNC_ENABLED = True
-                output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
-                export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
-                self._onnx_transforms.append(RenameFunctionOutputsTransform)
-                self._onnx_transforms.append(CustomOpTransform)
 
             torch.onnx.export(
                 self.model,
@@ -309,12 +318,6 @@ def _export(
         finally:
             shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
 
-        if use_onnx_subfunctions:
-            undo_torch_patches()
-            InvalidIndexProvider.SUBFUNC_ENABLED = False
-            self._onnx_transforms.remove(CustomOpTransform)
-            self._onnx_transforms.remove(RenameFunctionOutputsTransform)
-
         self.onnx_path = onnx_path
         return onnx_path