
Commit 09f6f80

Author: Amit Raj (committed)
Handled: 1. Issue when exporting multiple times 2. Meta device error after first export 3. Hash changing after each export
Signed-off-by: Amit Raj <[email protected]>
1 parent b2ec576 commit 09f6f80

File tree

11 files changed: +16, -111 lines


QEfficient/base/modeling_qeff.py

Lines changed: 0 additions & 1 deletion
@@ -201,7 +201,6 @@ def _export(
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
-        use_onnx_subfunctions: bool = False,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms

QEfficient/diffusers/models/attention.py

Lines changed: 0 additions & 2 deletions
@@ -47,7 +47,6 @@ def forward(
             # "feed_forward_chunk_size" can be used to save memory
             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            # ff_output = self.ff(norm_hidden_states)
             ff_output = self.ff(norm_hidden_states, block_size=4096)
         ff_output = gate_mlp.unsqueeze(1) * ff_output

@@ -68,7 +67,6 @@ def forward(
                 self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
             )
         else:
-            # context_ff_output = self.ff_context(norm_encoder_hidden_states)
             context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
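
For readers unfamiliar with the pattern above: the block_size argument (4096 for hidden states, 333 for the context branch) belongs to QEfficient's patched feed-forward, and the helper below is only a minimal sketch of the general idea, assuming a plain MLP: split the sequence dimension into fixed-size chunks so peak activation memory stays bounded. blocked_feed_forward and the toy MLP are hypothetical names, not the repo's implementation.

import torch
import torch.nn as nn

def blocked_feed_forward(ff: nn.Module, x: torch.Tensor, block_size: int) -> torch.Tensor:
    # x: (batch, seq_len, dim). Run the MLP on chunks of at most `block_size` tokens
    # along the sequence dimension, so intermediate activations never span the full sequence.
    outputs = [ff(chunk) for chunk in torch.split(x, block_size, dim=1)]
    return torch.cat(outputs, dim=1)

# Usage sketch with a toy MLP standing in for the block's feed-forward.
ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(1, 10_000, 64)
assert blocked_feed_forward(ff, x, block_size=4096).shape == x.shape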

QEfficient/diffusers/models/normalization.py

Lines changed: 0 additions & 11 deletions
@@ -14,17 +14,9 @@ class QEffAdaLayerNormZero(AdaLayerNormZero):
     def forward(
         self,
         x: torch.Tensor,
-        timestep: Optional[torch.Tensor] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        hidden_dtype: Optional[torch.dtype] = None,
         shift_msa: Optional[torch.Tensor] = None,
         scale_msa: Optional[torch.Tensor] = None,
-        # emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # if self.emb is not None:
-        #     emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
-        #     emb = self.linear(self.silu(emb))
-        #     shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x

@@ -36,15 +28,12 @@ def forward(
         scale_msa: Optional[torch.Tensor] = None,
         shift_msa: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x


 class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
-        # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
-        # emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
         emb = conditioning_embedding
         scale, shift = torch.chunk(emb, 2, dim=1)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
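
The reworked signatures above take shift_msa/scale_msa as precomputed inputs instead of deriving them from timestep/class_labels inside the norm module. A minimal sketch of the modulation that the retained line performs, with toy shapes and a standalone LayerNorm assumed for illustration:

import torch

batch, seq_len, dim = 2, 16, 64
x = torch.randn(batch, seq_len, dim)
shift_msa = torch.randn(batch, dim)   # precomputed outside the norm module
scale_msa = torch.randn(batch, dim)

norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

# Same modulation as the retained line: normalize, then scale and shift per batch element.
out = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
assert out.shape == (batch, seq_len, dim)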

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 0 additions & 73 deletions
@@ -9,7 +9,6 @@

 import numpy as np
 import torch
-import torch.nn as nn
 from diffusers.models.attention_dispatch import dispatch_attention_fn
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_flux import (

@@ -21,11 +20,6 @@
     _get_qkv_projections,
 )

-from QEfficient.diffusers.models.normalization import (
-    QEffAdaLayerNormZero,
-    QEffAdaLayerNormZeroSingle,
-)
-

 def qeff_apply_rotary_emb(
     x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]

@@ -120,24 +114,6 @@ def __qeff_init__(self):


 class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
-    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
-        super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio)
-        self.mlp_hidden_dim = int(dim * mlp_ratio)
-        self.norm = QEffAdaLayerNormZeroSingle(dim)
-        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
-        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
-        self.attn = QEffFluxAttention(
-            query_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            out_dim=dim,
-            bias=True,
-            processor=QEffFluxAttnProcessor(),
-            eps=1e-6,
-            pre_only=True,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -163,33 +139,12 @@ def forward(
         gate = gate.unsqueeze(1)
         hidden_states = gate * self.proj_out(hidden_states)
         hidden_states = residual + hidden_states
-        # if hidden_states.dtype == torch.float16:
-        hidden_states = hidden_states.clip(-65504, 65504)

         encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
         return encoder_hidden_states, hidden_states


 class QEffFluxTransformerBlock(FluxTransformerBlock):
-    def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
-    ):
-        super().__init__(dim, num_attention_heads, attention_head_dim)
-
-        self.norm1 = QEffAdaLayerNormZero(dim)
-        self.norm1_context = QEffAdaLayerNormZero(dim)
-        self.attn = QEffFluxAttention(
-            query_dim=dim,
-            added_kv_proj_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            out_dim=dim,
-            context_pre_only=False,
-            bias=True,
-            processor=QEffFluxAttnProcessor(),
-            eps=eps,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -395,31 +350,3 @@ def forward(
             return (output,)

         return Transformer2DModelOutput(sample=output)
-
-
-class QEffFluxTransformer2DModelOF(QEffFluxTransformer2DModel):
-    def __qeff_init__(self):
-        self.transformer_blocks = nn.ModuleList()
-        self._block_classes = set()
-
-        for _ in range(self.config.num_layers):
-            BlockClass = QEffFluxTransformerBlock
-            block = BlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
-            )
-            self.transformer_blocks.append(block)
-            self._block_classes.add(BlockClass)
-
-        self.single_transformer_blocks = nn.ModuleList()
-
-        for _ in range(self.config.num_single_layers):
-            SingleBlockClass = QEffFluxSingleTransformerBlock
-            single_block = SingleBlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
-            )
-            self.single_transformer_blocks.append(single_block)
-            self._block_classes.add(SingleBlockClass)
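
Context for the deleted constructors: both QEff block classes now define only forward overrides, so they no longer build submodules of their own. Assuming the patched classes are applied by retargeting the class of an already-built diffusers block rather than constructing a new one (an assumption about the transform mechanism, not something this diff states), the loaded weights stay attached instead of being re-created. A minimal sketch with hypothetical Block/QEffBlock names:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

class QEffBlock(Block):
    # Only overrides behaviour; defines no new parameters of its own.
    def forward(self, x):
        return self.proj(x).clip(-65504, 65504)

original = Block(8)
weight_before = original.proj.weight.data.clone()

# Retargeting the class keeps the existing (already loaded) parameters...
original.__class__ = QEffBlock
assert (original.proj.weight.data == weight_before).all()

# ...whereas constructing a new QEffBlock would reinitialize them from scratch.
fresh = QEffBlock(8)
assert not (fresh.proj.weight.data == weight_before).all()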

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 4 additions & 0 deletions
@@ -485,12 +485,16 @@ def export(
         if use_onnx_subfunctions:
             export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}

+        # Sort _use_default_values in config to ensure consistent hash generation during export
+        self.model.config["_use_default_values"].sort()
+
         return self._export(
             example_inputs=inputs,
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             export_dir=export_dir,
             export_kwargs=export_kwargs,
+            offload_pt_weights=False,  # As weights are needed with AdaLN changes
         )

     def compile(self, specializations: List[Dict], **compiler_options) -> None:
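
Why the sort matters: _use_default_values is stored without a guaranteed order in the diffusers config, so two exports of the same model can serialize it differently and therefore hash differently. A minimal sketch of the failure mode and the fix; the sha256-over-JSON helper below is an illustrative assumption, not necessarily what create_export_hash does internally.

import hashlib
import json

def config_hash(cfg: dict) -> str:
    # Stable only if every value serializes deterministically.
    return hashlib.sha256(json.dumps(cfg, sort_keys=True).encode()).hexdigest()

run_a = {"_use_default_values": ["attention_head_dim", "num_layers"]}
run_b = {"_use_default_values": ["num_layers", "attention_head_dim"]}  # same entries, different order

assert config_hash(run_a) != config_hash(run_b)   # hash drifts between exports

for cfg in (run_a, run_b):
    cfg["_use_default_values"].sort()              # the fix applied before _export
assert config_hash(run_a) == config_hash(run_b)    # now stable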

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 0 additions & 9 deletions
@@ -198,10 +198,7 @@
 )
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
-    T5LayerCrossAttention,
-    T5LayerFF,
     T5LayerNorm,
-    T5LayerSelfAttention,
 )
 from transformers.models.whisper.modeling_whisper import (
     WhisperAttention,

@@ -425,10 +422,7 @@
 )
 from QEfficient.transformers.models.t5.modeling_t5 import (
     QEffT5Attention,
-    QEffT5LayerCrossAttention,
-    QEffT5LayerFF,
     QEffT5LayerNorm,
-    QEffT5LayerSelfAttention,
 )
 from QEfficient.transformers.models.whisper.modeling_whisper import (
     QEffWhisperAttention,

@@ -824,9 +818,6 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
 class T5ModelTransform(ModuleMappingTransform):
     # supported architectures
     _module_mapping = {
-        T5LayerFF: QEffT5LayerFF,
-        T5LayerSelfAttention: QEffT5LayerSelfAttention,
-        T5LayerCrossAttention: QEffT5LayerCrossAttention,
         T5Attention: QEffT5Attention,
         T5LayerNorm: QEffT5LayerNorm,
     }

QEfficient/utils/_utils.py

Lines changed: 0 additions & 9 deletions
@@ -530,15 +530,6 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
     """
     model_params = copy.deepcopy(kwargs)
     model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
-
-    # TODO: Refactor this configuration handling to occur during export phase
-    # This is necessary because diffusion models have a different way to change number of layers
-    # that isn't properly considered in the current implementation
-    model_params["config"] = (
-        qeff_model.model.config.to_diff_dict()
-        if hasattr(qeff_model.model.config, "to_diff_dict")
-        else qeff_model.model.config
-    )
     model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
     model_params["applied_transform_names"] = qeff_model._transform_names()
     return model_params

QEfficient/utils/export_utils.py

Lines changed: 9 additions & 0 deletions
@@ -122,6 +122,15 @@ def _generate_export_hash(qeff_model, args, kwargs, func):
     bound_args.apply_defaults()
     all_args = bound_args.arguments

+    # Use the model's current configuration for hashing to ensure any post-load modifications are captured
+    qeff_model.hash_params = {
+        "model_config": (
+            qeff_model.model.config.to_diff_dict()
+            if hasattr(qeff_model.model.config, "to_diff_dict")
+            else qeff_model.model.config
+        ),
+    }
+
     # Generate hash from relevant parameters
     export_hash, filtered_hash_params = create_export_hash(
         model_params=qeff_model.hash_params,
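
The added block rebuilds hash_params["model_config"] from the model's live config at export time instead of a snapshot taken at load time, so post-load changes (for example trimming the number of layers for a test model) are reflected in the hash. A rough sketch of the intent; DummyConfig and the hashing helper are assumptions for illustration, not the repo's code.

import hashlib
import json

class DummyConfig(dict):
    # Stand-in for a diffusers/transformers config; real configs may expose to_diff_dict().
    def to_diff_dict(self):
        return dict(self)

def export_hash(config) -> str:
    cfg = config.to_diff_dict() if hasattr(config, "to_diff_dict") else config
    return hashlib.sha256(json.dumps(cfg, sort_keys=True).encode()).hexdigest()

config = DummyConfig(num_layers=19, num_single_layers=38)
h_full = export_hash(config)

config["num_layers"] = 2          # post-load modification, e.g. a trimmed test model
h_trimmed = export_hash(config)

# Hashing the live config distinguishes the two exports; a load-time snapshot would not.
assert h_full != h_trimmed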

QEfficient/utils/hash_utils.py

Lines changed: 1 addition & 4 deletions
@@ -14,10 +14,7 @@

 def json_serializable(obj):
     if isinstance(obj, set):
-        return [cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]
-    # Handle objects with to_dict() method (e.g., transformers config objects)
-    if hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict")):
-        return obj.to_dict()
+        return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj])
     raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
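
Reasoning behind the surviving line: a Python set has no guaranteed iteration order, so serializing it unsorted makes the export hash differ from run to run; sorted() gives it a canonical order. A small sketch of how such a hook is typically passed to json.dumps (the calling code is an assumption; only json_serializable itself comes from this diff).

import json

def json_serializable(obj):
    if isinstance(obj, set):
        # Sorting gives the set a canonical order, so the JSON text (and any
        # hash derived from it) is identical across processes.
        return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj])
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

class QEffFluxTransformerBlock: ...
class QEffFluxSingleTransformerBlock: ...

export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}
print(json.dumps(export_kwargs, default=json_serializable))
# {"export_modules_as_functions": ["QEffFluxSingleTransformerBlock", "QEffFluxTransformerBlock"]}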

examples/diffusers/flux/flux_1_schnell.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@
     max_sequence_length=256,
     generator=torch.manual_seed(42),
     parallel_compile=True,
-    use_onnx_subfunctions=True,
+    use_onnx_subfunctions=False,
 )

 # Extract the generated image from the output
