diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 7f63b34ca..2d8f72e0a 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -18,6 +18,7 @@
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
+from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -39,6 +40,7 @@
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
+ "QEffFluxPipeline",
]
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index ef7e83adf..ea347016b 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -8,7 +8,6 @@
import gc
import inspect
import logging
-import re
import shutil
import subprocess
import warnings
@@ -21,26 +20,21 @@
from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
- CustomOpTransform,
OnnxTransformPipeline,
- RenameFunctionOutputsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.transformers.cache_utils import InvalidIndexProvider
-from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
-from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+from QEfficient.utils.export_utils import export_wrapper
logger = logging.getLogger(__name__)
@@ -125,9 +119,35 @@ def _model_offloaded_check(self) -> None:
logger.error(error_msg)
raise RuntimeError(error_msg)
+ @property
+ def model_name(self) -> str:
+ """
+ Get the model class name without QEff/QEFF prefix.
+
+ This property extracts the underlying model's class name and removes
+ any QEff or QEFF prefix that may have been added during wrapping.
+
+ Returns:
+ str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+ """
+ mname = self.model.__class__.__name__
+ if mname.startswith("QEff") or mname.startswith("QEFF"):
+ mname = mname[4:]
+ return mname
+
@property
@abstractmethod
- def model_name(self) -> str: ...
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ This is an abstract property that must be implemented by all subclasses.
+ Typically returns: self.model.config.__dict__
+
+ Returns:
+ Dict: The configuration dictionary of the underlying model
+ """
+ pass
@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
@@ -188,7 +208,6 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
- use_onnx_subfunctions: bool = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -253,18 +272,8 @@ def _export(
input_names.append(param)
try:
- # Initialize the registry with your custom ops
+ # Export to ONNX
export_kwargs = {} if export_kwargs is None else export_kwargs
- if use_onnx_subfunctions:
- warnings.warn(
- "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
- )
- apply_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = True
- output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
- export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
- self._onnx_transforms.append(RenameFunctionOutputsTransform)
- self._onnx_transforms.append(CustomOpTransform)
torch.onnx.export(
self.model,
@@ -309,12 +318,6 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
- if use_onnx_subfunctions:
- undo_torch_patches()
- InvalidIndexProvider.SUBFUNC_ENABLED = False
- self._onnx_transforms.remove(CustomOpTransform)
- self._onnx_transforms.remove(RenameFunctionOutputsTransform)
-
self.onnx_path = onnx_path
return onnx_path
diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md
new file mode 100644
index 000000000..40d45e984
--- /dev/null
+++ b/QEfficient/diffusers/README.md
@@ -0,0 +1,95 @@
+
+
+
+
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+
+
+
+### **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale
+
+
+
+
+
+
+
+[Built on HuggingFace Diffusers](https://github.com/huggingface/diffusers)
+
+
+---
+
+## Overview
+
+QEfficient Diffusers brings state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, the optimized pipelines provide seamless export, compilation, and inference on Cloud AI 100 devices.
+
+## Installation
+
+### Prerequisites
+
+Ensure you have Python 3.8+ and the required dependencies:
+
+```bash
+# Create Python virtual environment (Recommended Python 3.10)
+sudo apt install python3.10-venv
+python3.10 -m venv qeff_env
+source qeff_env/bin/activate
+pip install -U pip
+```
+
+### Install QEfficient
+
+```bash
+# Install from GitHub (includes diffusers support)
+pip install git+https://github.com/quic/efficient-transformers
+
+# Or build from source
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install build wheel
+python -m build --wheel --outdir dist
+pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
+```
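+
+To sanity-check the install, the Flux pipeline class should be importable from the top-level package (assuming the `diffusers` dependency was pulled in by the install above):
+
+```python
+# Quick import check for the diffusers integration
+from QEfficient import QEffFluxPipeline
+
+print(QEffFluxPipeline)
+```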
+
+---
+
+## Supported Models
+- [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+
+---
+
+
+## Examples
+
+Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory:
+
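+A minimal usage sketch (illustrative values; the example scripts cover the full set of arguments):
+
+```python
+from QEfficient import QEffFluxPipeline
+
+# The first call exports the modules to ONNX, compiles them to QPC,
+# and then runs the full pipeline on Cloud AI 100
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+result = pipeline(
+    prompt="A girl laughing",
+    height=512,
+    width=512,
+    num_inference_steps=4,
+    guidance_scale=0.0,
+)
+result.images[0].save("girl_laughing.png")
+```
+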
+---
+
+## Contributing
+
+We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.
+
+
+
+---
+
+## Acknowledgments
+
+- **HuggingFace Diffusers**: For the excellent foundation library
+- **Stability AI**: For the amazing Stable Diffusion models
+---
+
+## Support
+
+- **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
+- **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)
+
+---
+
diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py
new file mode 100644
index 000000000..933832ed8
--- /dev/null
+++ b/QEfficient/diffusers/models/normalization.py
@@ -0,0 +1,40 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Optional
+
+import torch
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+class QEffAdaLayerNormZero(AdaLayerNormZero):
+ def forward(
+ self,
+ x: torch.Tensor,
+ shift_msa: Optional[torch.Tensor] = None,
+ scale_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_msa: Optional[torch.Tensor] = None,
+ shift_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ emb = conditioning_embedding
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
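+
+
+# Note: unlike the upstream diffusers modules, these forward overrides expect the AdaLN
+# shift/scale tensors to be precomputed outside the block; QEffFluxPipeline runs
+# `norm.linear(norm.silu(temb))` on the host and streams the chunks in as `adaln_emb`.
+# A minimal sketch with illustrative sizes:
+#
+#   norm = QEffAdaLayerNormZeroSingle(64)
+#   x = torch.randn(1, 16, 64)
+#   scale, shift = torch.randn(1, 64), torch.randn(1, 64)
+#   out = norm(x, scale_msa=scale, shift_msa=shift)  # shape: (1, 16, 64)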
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..d3c84ee63
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,56 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+)
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.normalization import (
+ QEffAdaLayerNormContinuous,
+ QEffAdaLayerNormZero,
+ QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxAttention,
+ QEffFluxAttnProcessor,
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformer2DModel,
+ QEffFluxTransformerBlock,
+)
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+ _module_mapping = {
+ RMSNorm: CustomRMSNormAIC,
+ nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm
+ }
+
+
+class AttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
+ FluxTransformerBlock: QEffFluxTransformerBlock,
+ FluxTransformer2DModel: QEffFluxTransformer2DModel,
+ FluxAttention: QEffFluxAttention,
+ FluxAttnProcessor: QEffFluxAttnProcessor,
+ }
+
+
+class NormalizationTransform(ModuleMappingTransform):
+ _module_mapping = {
+ AdaLayerNormZero: QEffAdaLayerNormZero,
+ AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
+ AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
+ }
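+
+
+# These ModuleMappingTransforms swap diffusers modules for their QEff counterparts in place.
+# A minimal sketch, assuming the usual `apply(model) -> (model, transformed)` classmethod
+# exposed by ModuleMappingTransform elsewhere in QEfficient:
+#
+#   model, _ = AttentionTransform.apply(model)
+#   model, _ = NormalizationTransform.apply(model)
+#   model, _ = CustomOpsTransform.apply(model)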
diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000..5cb44af45
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,327 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+ _get_qkv_projections,
+)
+
+from QEfficient.utils.logging_utils import logger
+
+
+def qeff_apply_rotary_emb(
+    x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the given query or key tensor ``x`` using the precomputed frequency tensors
+    ``freqs_cis``. The tensor is viewed as interleaved real/imaginary pairs, the cos/sin tensors are reshaped
+    for broadcasting compatibility, and the rotated result is returned as a real tensor of the same shape and dtype.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
+        freqs_cis (`Tuple[torch.Tensor]`):
+            Precomputed (cos, sin) frequency tensors, each of shape [S, D].
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
+    """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+ B, S, H, D = x.shape
+ x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
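+# Minimal shape sketch for qeff_apply_rotary_emb (illustrative sizes; in the model the
+# (cos, sin) pair comes from the transformer's `pos_embed`):
+#
+#   cos, sin = torch.ones(16, 8), torch.zeros(16, 8)   # [S, D]
+#   q = torch.randn(1, 16, 4, 8)                        # [B, S, H, D]
+#   q_rot = qeff_apply_rotary_emb(q, (cos, sin))        # equals q here, since sin == 0
+
+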
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+ _attention_backend = None
+ _parallel_config = None
+
+ def __call__(
+ self,
+ attn: "QEffFluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = qeff_apply_rotary_emb(query, image_rotary_emb)
+ key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class QEffFluxAttention(FluxAttention):
+ def __qeff_init__(self):
+ processor = QEffFluxAttnProcessor()
+ self.processor = processor
+
+
+class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ shift_msa, scale_msa, gate = torch.split(temb, 1)
+ residual = hidden_states
+ norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+        # Clamp to keep activations finite when the compiled graph runs in reduced precision
+        hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformerBlock(FluxTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb1 = tuple(torch.split(temb[:6], 1))
+ temb2 = tuple(torch.split(temb[6:], 1))
+ norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1])
+ gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:]
+
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1])
+
+ c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:]
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        # Clamp to the float16 range to avoid overflow when running in reduced precision
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformer2DModel(FluxTransformer2DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ adaln_emb: torch.Tensor = None,
+ adaln_single_emb: torch.Tensor = None,
+ adaln_out: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_single_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
+
+ hidden_states = self.norm_out(hidden_states, adaln_out)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/configs/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
new file mode 100644
index 000000000..511746469
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
@@ -0,0 +1,854 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+# TODO: Pipeline Architecture Improvements
+# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile,
+# and inference APIs across all diffusion pipelines, promoting code reusability
+# and consistent interface design.
+# 2. Implement persistent QPC session management strategy to retain/drop compiled model
+# sessions in memory across all pipeline modules.
+
+import os
+import time
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from tqdm import tqdm
+
+from QEfficient.diffusers.pipelines.pipeline_module import (
+ QEffFluxTransformerModel,
+ QEffTextEncoder,
+ QEffVAE,
+)
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ONNX_SUBFUNCTION_MODULE,
+ ModulePerf,
+ QEffPipelineOutput,
+ calculate_compressed_latent_dimension,
+ compile_modules_parallel,
+ compile_modules_sequential,
+ config_manager,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffFluxPipeline:
+ """
+ QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware.
+
+ This pipeline provides an optimized implementation of the Flux diffusion model specifically designed
+ for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX format and compiled
+ into Qualcomm Program Container (QPC) files for efficient inference.
+
+ The pipeline supports the complete Flux workflow including:
+ - Dual text encoding with CLIP and T5 encoders
+ - Transformer-based denoising with adaptive layer normalization
+ - VAE decoding for final image generation
+ - Performance monitoring and optimization
+
+ Attributes:
+ text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings
+ text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings
+ transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising
+ vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion
+ modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations
+ model (FluxPipeline): Original HuggingFace Flux model reference
+ tokenizer: CLIP tokenizer for text preprocessing
+ scheduler: Diffusion scheduler for timestep management
+
+ Example:
+        >>> from QEfficient import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> images = pipeline(
+ ... prompt="A beautiful sunset over mountains",
+ ... height=512,
+ ... width=512,
+ ... num_inference_steps=28
+ ... )
+ >>> images.images[0].save("generated_image.png")
+ """
+
+ _hf_auto_class = FluxPipeline
+
+ def __init__(self, model, *args, **kwargs):
+ """
+ Initialize the QEfficient Flux pipeline.
+
+ This pipeline provides an optimized implementation of the Flux text-to-image model
+ for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model
+ components with QEfficient-optimized versions that can be exported to ONNX and compiled
+ for QAIC devices.
+
+ Args:
+ model: Pre-loaded FluxPipeline model
+ **kwargs: Additional arguments including height and width
+ """
+
+ # Wrap model components with QEfficient optimized versions
+ self.model = model
+ self.text_encoder = QEffTextEncoder(model.text_encoder)
+ self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+ self.transformer = QEffFluxTransformerModel(model.transformer)
+ self.vae_decode = QEffVAE(model.vae, "decoder")
+
+ # Store all modules in a dictionary for easy iteration during export/compile
+ self.modules = {
+ "text_encoder": self.text_encoder,
+ "text_encoder_2": self.text_encoder_2,
+ "transformer": self.transformer,
+ "vae_decoder": self.vae_decode,
+ }
+
+ # Copy tokenizers and scheduler from the original model
+ self.tokenizer = model.tokenizer
+ self.text_encoder.tokenizer = model.tokenizer
+ self.text_encoder_2.tokenizer = model.tokenizer_2
+ self.tokenizer_max_length = model.tokenizer_max_length
+ self.scheduler = model.scheduler
+
+ # Override VAE forward method to use decode directly
+ self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
+ latent_sample, return_dict
+ )
+
+ # Sync max position embeddings between text encoders
+ self.text_encoder_2.model.config.max_position_embeddings = (
+ self.text_encoder.model.config.max_position_embeddings
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ **kwargs,
+ ):
+ """
+ Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations.
+
+ This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained
+ Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU
+ and wraps all components with QEfficient-optimized versions for QAIC deployment.
+
+ Args:
+ pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier
+ (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory.
+ **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained().
+
+ Returns:
+ QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components
+ ready for export, compilation, and inference on QAIC devices.
+
+ Raises:
+ ValueError: If the model path is invalid or model cannot be loaded
+ OSError: If there are issues accessing the model files
+ RuntimeError: If model initialization fails
+
+ Example:
+ >>> # Load from HuggingFace Hub
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>>
+ >>> # Load from local path
+ >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model")
+ >>>
+ >>> # Load with custom cache directory
+ >>> pipeline = QEffFluxPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-dev",
+ ... cache_dir="/custom/cache/dir"
+ ... )
+ """
+ # Load the base Flux model in float32 on CPU
+ model = cls._hf_auto_class.from_pretrained(
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ device_map="cpu",
+ **kwargs,
+ )
+
+ return cls(
+ model=model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ **kwargs,
+ )
+
+ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
+ """
+ Export all pipeline modules to ONNX format for deployment preparation.
+
+ This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder,
+ Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific
+ configuration including dynamic axes, input/output specifications, and optimization settings.
+
+ The export process prepares the models for subsequent compilation to QPC format, enabling
+ efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules
+ to optimize memory usage and performance.
+
+ Args:
+ export_dir (str, optional): Target directory for saving ONNX model files. If None,
+ uses the default export directory structure based on model name and configuration.
+ The directory will be created if it doesn't exist.
+ use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction
+                optimization for supported modules. This can optimize the graph and
+ improve compilation efficiency for models like the transformer.
+
+ Returns:
+ str: Absolute path to the export directory containing all ONNX model files.
+ Each module will have its own subdirectory with the exported ONNX file.
+
+ Raises:
+ RuntimeError: If ONNX export fails for any module
+ OSError: If there are issues creating the export directory or writing files
+ ValueError: If module configurations are invalid
+
+ Note:
+ - All models are exported in float32 precision for maximum compatibility
+ - Dynamic axes are configured to support variable batch sizes and sequence lengths
+ - The export process may take several minutes depending on model size
+ - Exported ONNX files can be large (several GB for complete pipeline)
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> export_path = pipeline.export(
+ ... export_dir="/path/to/export",
+ ... use_onnx_subfunctions=True
+ ... )
+ >>> print(f"Models exported to: {export_path}")
+ """
+ for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
+ # Get ONNX export configuration for this module
+ example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
+
+ export_params = {
+ "inputs": example_inputs,
+ "output_names": output_names,
+ "dynamic_axes": dynamic_axes,
+ "export_dir": export_dir,
+ }
+
+ if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE:
+ export_params["use_onnx_subfunctions"] = True
+
+ module_obj.export(**export_params)
+
+ @staticmethod
+ def get_default_config_path() -> str:
+ """
+        Get the path to the default Flux pipeline configuration file.
+
+        Returns:
+            str: Path to the flux_config.json file (relative to the repository root) containing
+                default compilation and device-allocation settings for the pipeline.
+ """
+ return "QEfficient/diffusers/pipelines/configs/flux_config.json"
+
+ def compile(
+ self,
+ compile_config: Optional[str] = None,
+ parallel: bool = False,
+ height: int = 512,
+ width: int = 512,
+ use_onnx_subfunctions: bool = False,
+ ) -> None:
+ """
+ Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware.
+
+ Args:
+ compile_config (str, optional): Path to a JSON configuration file containing
+ compilation settings, device mappings, and optimization parameters. If None,
+ uses the default configuration from get_default_config_path().
+ parallel (bool, default=False): Compilation mode selection:
+ - True: Compile modules in parallel using ThreadPoolExecutor for faster processing
+ - False: Compile modules sequentially for lower resource usage
+ height (int, default=512): Target image height in pixels.
+ width (int, default=512): Target image width in pixels.
+ use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX
+ subfunctions before compilation.
+
+ Raises:
+ RuntimeError: If compilation fails for any module or if QAIC compiler is not available
+ FileNotFoundError: If ONNX models haven't been exported or config file is missing
+ ValueError: If configuration parameters are invalid
+ OSError: If there are issues with file I/O during compilation
+
+ Example:
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> # Sequential compilation with default config
+ >>> pipeline.compile(height=1024, width=1024)
+ >>>
+ >>> # Parallel compilation with custom config
+ >>> pipeline.compile(
+ ... compile_config="/path/to/custom_config.json",
+ ... parallel=True,
+ ... height=512,
+ ... width=512
+ ... )
+ """
+ # Ensure all modules are exported to ONNX before compilation
+ if any(
+ path is None
+ for path in [
+ self.text_encoder.onnx_path,
+ self.text_encoder_2.onnx_path,
+ self.transformer.onnx_path,
+ self.vae_decode.onnx_path,
+ ]
+ ):
+ self.export(use_onnx_subfunctions=use_onnx_subfunctions)
+
+ # Load compilation configuration
+ config_manager(self, config_source=compile_config)
+
+ # Calculate compressed latent dimension using utility function
+ cl, latent_height, latent_width = calculate_compressed_latent_dimension(
+ height, width, self.model.vae_scale_factor
+ )
+
+ # Prepare dynamic specialization updates based on image dimensions
+ specialization_updates = {
+ "transformer": {"cl": cl},
+ "vae_decoder": {
+ "latent_height": latent_height,
+ "latent_width": latent_width,
+ },
+ }
+
+ # Use generic utility functions for compilation
+ if parallel:
+ compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+ else:
+ compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the T5 text encoder for detailed semantic understanding.
+
+ T5 provides rich sequence embeddings that capture fine-grained text details,
+ complementing CLIP's global representation in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ max_sequence_length (int): Maximum token sequence length (default: 512)
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (prompt_embeds, inference_time)
+ - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096]
+ - inference_time (float): T5 encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts with padding and truncation
+ text_inputs = self.text_encoder_2.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.text_encoder_2.tokenizer.batch_decode(
+ untruncated_ids[:, self.text_encoder_2.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ f"The following part of your input was truncated because `max_sequence_length` is set to "
+ f"{self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder_2.qpc_session is None:
+ self.text_encoder_2.qpc_session = QAICInferenceSession(
+ str(self.text_encoder_2.qpc_path), device_ids=device_ids
+ )
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_2_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model
+ ).astype(np.int32),
+ }
+ self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run T5 encoder inference and measure time
+ start_t5_time = time.perf_counter()
+ prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+ end_t5_time = time.perf_counter()
+ text_encoder_2_perf = end_t5_time - start_t5_time
+
+ # Duplicate embeddings for multiple images per prompt
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_encoder_2_perf
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode text prompts using the CLIP text encoder for global semantic representation.
+
+ CLIP provides pooled embeddings that capture high-level semantic meaning,
+ working alongside T5's detailed sequence embeddings in Flux's dual encoder setup.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (pooled_prompt_embeds, inference_time)
+ - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768]
+ - inference_time (float): CLIP encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ # Tokenize prompts
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ f"The following part of your input was truncated because CLIP can only handle sequences up to "
+ f"{self.tokenizer_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder.qpc_session is None:
+ self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_output = {
+ "last_hidden_state": np.random.rand(
+ batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size
+ ).astype(np.float32),
+ "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.int32),
+ }
+ self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run CLIP encoder inference and measure time
+ start_text_encoder_time = time.perf_counter()
+ aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+ end_text_encoder_time = time.perf_counter()
+ text_encoder_perf = end_text_encoder_time - start_text_encoder_time
+ # Extract pooled output (used for conditioning in Flux)
+ prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
+
+ # Duplicate embeddings for multiple images per prompt
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, text_encoder_perf
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ """
+ Encode text prompts using Flux's dual text encoder architecture.
+
+ Flux employs both CLIP and T5 encoders for comprehensive text understanding:
+ - CLIP provides pooled embeddings for global semantic conditioning
+ - T5 provides detailed sequence embeddings for fine-grained text control
+
+ Args:
+ prompt (str or List[str]): Primary prompt(s) for both encoders
+ prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. If None, uses primary prompt
+ num_images_per_prompt (int): Number of images to generate per prompt
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings
+ max_sequence_length (int): Maximum sequence length for T5 tokenization
+
+ Returns:
+ tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times)
+ - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096]
+ - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768]
+ - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3]
+ - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time]
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ # Use primary prompt for both encoders if secondary not provided
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # Encode with CLIP (returns pooled embeddings)
+ pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device_ids=self.text_encoder.device_ids,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ # Encode with T5 (returns sequence embeddings)
+ prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device_ids=self.text_encoder_2.device_ids,
+ )
+
+ # Create text position IDs (required by Flux transformer)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]
+
+ def __call__(
+ self,
+ height: int = 512,
+ width: int = 512,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ use_onnx_subfunctions: bool = False,
+ ):
+ """
+ Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware.
+
+ This is the main entry point for text-to-image generation. It orchestrates the complete Flux
+ diffusion pipeline optimized for Qualcomm AI Cloud devices.
+
+ Args:
+ height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512.
+ width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512.
+ prompt (str or List[str]): Primary text prompt(s) describing the desired image(s).
+ Required unless `prompt_embeds` is provided.
+ prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. If None, uses `prompt`.
+ negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid.
+ Only used when `true_cfg_scale > 1.0`.
+ negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`.
+ true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable
+ negative prompting. Default: 1.0 (disabled).
+ num_inference_steps (int, optional): Number of denoising steps. Default: 28.
+ timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`.
+ guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5.
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1.
+ generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility.
+ latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated.
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096].
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768].
+ negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings.
+ negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings.
+ output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent".
+ callback_on_step_end (Callable, optional): Callback function executed after each denoising step.
+ callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"].
+ max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512.
+ custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings.
+ parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False.
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False.
+
+ Returns:
+ QEffPipelineOutput: A dataclass containing:
+ - images: Generated image(s) in the format specified by `output_type`
+ - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE)
+
+ Raises:
+ ValueError: If input validation fails or parameters are incompatible.
+ RuntimeError: If compilation fails or QAIC devices are unavailable.
+ FileNotFoundError: If custom config file is specified but not found.
+
+ Example:
+            >>> from QEfficient import QEffFluxPipeline
+ >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ >>> result = pipeline(
+ ... prompt="A serene mountain landscape at sunset",
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=28,
+ ... guidance_scale=7.5
+ ... )
+ >>> result.images[0].save("mountain_sunset.png")
+ >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s")
+ """
+ device = self.model._execution_device
+
+        if height is None or width is None:
+            logger.warning("Height or width is None. Defaulting missing dimensions to 512.")
+            height = height or 512
+            width = width or 512
+
+ self.compile(
+ compile_config=custom_config_path,
+ parallel=parallel_compile,
+ height=height,
+ width=width,
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ )
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(self)
+
+ # Validate all inputs
+ self.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # Step 2: Determine batch size from inputs
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Step 3: Encode prompts with both text encoders
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Encode negative prompts if using true classifier-free guidance
+ if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                negative_text_ids,
+                _,
+            ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = self.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = self.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Calculate compressed latent dimension for transformer buffer allocation
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor)
+
+ # Initialize transformer inference session
+ if self.transformer.qpc_session is None:
+ self.transformer.qpc_session = QAICInferenceSession(
+ str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
+ )
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32),
+ }
+ self.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ self.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with self.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds)
+
+ # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.transformer_blocks)):
+ block = self.transformer.model.transformer_blocks[block_idx]
+ # Process through norm1 and norm1_context
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.single_transformer_blocks)):
+ block = self.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = self.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": prompt_embeds.detach().numpy(),
+ "pooled_projections": pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.perf_counter()
+ outputs = self.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.perf_counter()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Update latents using scheduler (x_t -> x_t-1)
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch (workaround for MPS backend bug)
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Execute callback if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images (unless output_type is "latent")
+ if output_type == "latent":
+ image = latents
+ else:
+ # Unpack and denormalize latents
+ latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor)
+ latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
+
+ # Initialize VAE decoder inference session
+ if self.vae_decode.qpc_session is None:
+ self.vae_decode.qpc_session = QAICInferenceSession(
+ str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)}
+ self.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.perf_counter()
+ image = self.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.perf_counter()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = self.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
new file mode 100644
index 000000000..41a3d29f7
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -0,0 +1,481 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import (
+ AttentionTransform,
+ CustomOpsTransform,
+ NormalizationTransform,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformerBlock,
+)
+from QEfficient.transformers.models.pytorch_transforms import (
+ T5ModelTransform,
+)
+from QEfficient.utils import constants
+
+
+class QEffTextEncoder(QEFFBaseModel):
+ """
+ Wrapper for text encoder models with ONNX export and QAIC compilation capabilities.
+
+ This class handles text encoder models (CLIP, T5) with specific transformations and
+ optimizations for efficient inference on Qualcomm AI hardware. It applies custom
+ PyTorch and ONNX transformations to prepare models for deployment.
+
+ Attributes:
+ model (nn.Module): The wrapped text encoder model (deep copy of original)
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying text encoder model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the text encoder wrapper.
+
+ Args:
+ model (nn.Module): The text encoder model to wrap (CLIP or T5)
+ """
+ super().__init__(model)
+ self.model = model
+
+ def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the text encoder.
+
+ Creates example inputs, dynamic axes specifications, and output names
+ tailored to the specific text encoder type (CLIP vs T5).
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # Create example input with max sequence length
+ example_inputs = {
+ "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64),
+ }
+
+ # Define which dimensions can vary at runtime
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+ # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output
+ if self.model.__class__.__name__ == "T5EncoderModel":
+ output_names = ["last_hidden_state"]
+ else:
+ output_names = ["last_hidden_state", "pooler_output"]
+ example_inputs["output_hidden_states"] = False
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the text encoder model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
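+
+# Illustrative flow for driving a wrapper (an assumption; option names/values are taken
+# from examples/diffusers/flux/flux_config.json and the README, not a prescribed API):
+#   te = QEffTextEncoder(text_encoder_module)  # any CLIP/T5 encoder nn.Module
+#   inputs, dynamic_axes, output_names = te.get_onnx_params()
+#   te.export(inputs, output_names, dynamic_axes)
+#   te.compile(specializations=[{"batch_size": 1, "seq_len": 77}], aic_num_cores=16, convert_to_fp16=True)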
+
+
+class QEffUNet(QEFFBaseModel):
+ """
+ Wrapper for UNet models with ONNX export and QAIC compilation capabilities.
+
+ This class handles UNet models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. UNet is commonly used in
+ diffusion models for image generation tasks.
+
+ Attributes:
+ model (nn.Module): The wrapped UNet model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying UNet model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the UNet wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the UNet
+ """
+ super().__init__(model.unet)
+ self.model = model.unet
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the UNet model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffVAE(QEFFBaseModel):
+ """
+ Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation.
+
+ This class handles VAE models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion
+ pipelines for encoding images to latent space and decoding latents back to images.
+
+ Attributes:
+ model (nn.Module): The wrapped VAE model (deep copy of original)
+ type (str): VAE operation type ("encoder" or "decoder")
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying VAE model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module, type: str) -> None:
+ """
+ Initialize the VAE wrapper.
+
+ Args:
+ model (nn.Module): The VAE model to wrap
+ type (str): VAE operation type ("encoder" or "decoder")
+ """
+ super().__init__(model)
+ self.model = model
+
+ # To have different hashing for encoder/decoder
+ self.model.config["type"] = type
+
+ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the VAE decoder.
+
+ Args:
+ latent_height (int): Height of latent representation (default: 32)
+ latent_width (int): Width of latent representation (default: 32)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # VAE decoder takes latent representation as input
+ example_inputs = {
+ "latent_sample": torch.randn(bs, 16, latent_height, latent_width),
+ "return_dict": False,
+ }
+
+ output_names = ["sample"]
+
+ # Batch size, channels, and spatial dimensions can all be dynamic at runtime
+ dynamic_axes = {
+ "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the VAE model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffFluxTransformerModel(QEFFBaseModel):
+ """
+ Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities.
+
+ This class handles Flux Transformer2D models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion
+ architecture instead of traditional UNet, with dual transformer blocks and adaptive layer
+ normalization (AdaLN) for conditioning.
+
+ Attributes:
+ model (nn.Module): The wrapped Flux transformer model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying Flux transformer model
+ """
+ return self.model.config.__dict__
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the Flux transformer wrapper.
+
+ Args:
+ model (nn.Module): The Flux transformer model to wrap
+ """
+ super().__init__(model)
+
+ def get_onnx_params(
+ self,
+ batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+ seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH,
+ cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM,
+ ) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the Flux transformer.
+
+ Creates example inputs for all Flux-specific inputs including hidden states,
+ text embeddings, timestep conditioning, and AdaLN embeddings.
+
+ Args:
+ batch_size (int): Batch size for example inputs (default: ONNX_EXPORT_EXAMPLE_BATCH_SIZE)
+ seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH)
+ cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ example_inputs = {
+ # Latent representation of the image
+ "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32),
+ "encoder_hidden_states": torch.randn(
+ batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32
+ ),
+ "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32),
+ "timestep": torch.tensor([1.0], dtype=torch.float32),
+ "img_ids": torch.randn(cl, 3, dtype=torch.float32),
+ "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32),
+ # AdaLN embeddings for dual transformer blocks
+ # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_emb": torch.randn(
+ self.model.config["num_layers"],
+ constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # AdaLN embeddings for single transformer blocks
+ # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM]
+ "adaln_single_emb": torch.randn(
+ self.model.config["num_single_layers"],
+ constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS,
+ constants.FLUX_ADALN_HIDDEN_DIM,
+ dtype=torch.float32,
+ ),
+ # Output AdaLN embedding
+ # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection
+ "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32),
+ }
+
+ output_names = ["output"]
+
+ # Define dynamic dimensions for runtime flexibility
+ dynamic_axes = {
+ "hidden_states": {0: "batch_size", 1: "cl"},
+ "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+ "pooled_projections": {0: "batch_size"},
+ "timestep": {0: "steps"},
+ "img_ids": {0: "cl"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ use_onnx_subfunctions: bool = False,
+ ) -> str:
+ """
+ Export the Flux transformer model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
+ use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False.
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+
+ if use_onnx_subfunctions:
+ export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}
+
+ # Sort _use_default_values in config to ensure consistent hash generation during export
+ self.model.config["_use_default_values"].sort()
+
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ offload_pt_weights=False, # As weights are needed with AdaLN changes
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..24eb36f53
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,218 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import logger
+
+
+def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> Tuple[int, int, int]:
+ """
+ Calculate the compressed latent dimension for a target image size.
+
+ Args:
+ height (int): Target image height in pixels
+ width (int): Target image width in pixels
+ vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux)
+
+ Returns:
+ Tuple[int, int, int]: Compressed latent dimension (cl) for transformer buffer allocation, plus the latent height and latent width
+ """
+ latent_height = height // vae_scale_factor
+ latent_width = width // vae_scale_factor
+ # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
+ cl = (latent_height * latent_width) // 4
+ return cl, latent_height, latent_width
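+
+# Illustrative check: for a 1024x1024 image with vae_scale_factor=8, latent_height =
+# latent_width = 128 and cl = (128 * 128) // 4 = 4096, which matches
+# constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM.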
+
+
+def config_manager(cls, config_source: Optional[str] = None):
+ """
+ Load the JSON-based compilation configuration for a diffusion pipeline.
+
+ Only JSON configuration files are supported. If no path is provided, the
+ pipeline's default configuration is used; the loaded configuration is stored
+ on ``cls.custom_config``.
+
+ Args:
+ config_source: Path to a JSON configuration file. If None, uses the default config.
+ """
+ if config_source is None:
+ config_source = cls.get_default_config_path()
+
+ if not isinstance(config_source, str):
+ raise ValueError("config_source must be a path to JSON configuration file")
+
+ # Direct use of load_json utility - no wrapper needed
+ if not os.path.exists(config_source):
+ raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+ cls.custom_config = load_json(config_source)
+
+
+def set_module_device_ids(cls):
+ """
+ Set device IDs for each module based on the custom configuration.
+
+ Iterates through all modules in the pipeline and assigns device IDs
+ from the configuration file to each module's device_ids attribute.
+ """
+ config_modules = cls.custom_config["modules"]
+ for module_name, module_obj in cls.modules.items():
+ module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
+
+
+def compile_modules_parallel(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+ """
+
+ def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
+ """Prepare specializations and compile a single module."""
+ specializations = config["modules"][module_name]["specializations"].copy()
+ compile_kwargs = config["modules"][module_name]["compilation"]
+
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+ # Execute compilations in parallel
+ with ThreadPoolExecutor(max_workers=len(modules)) as executor:
+ futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()}
+
+ with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar:
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error(f"Compilation failed for {futures[future]}: {e}")
+ raise
+ pbar.update(1)
+
+
+def compile_modules_sequential(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules sequentially.
+
+ This function provides a generic way to compile diffusion pipeline modules
+ sequentially, which is the default behavior for backward compatibility.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+
+ """
+ for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"):
+ module_config = config["modules"]
+ specializations = module_config[module_name]["specializations"].copy()
+ compile_kwargs = module_config[module_name]["compilation"]
+
+ # Apply dynamic specialization updates if provided
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ # Compile the module to QPC format
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
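+
+# Illustrative shape of the config dict consumed by compile_modules_parallel/sequential
+# (see examples/diffusers/flux/flux_config.json for a complete example):
+#   {"modules": {"transformer": {"specializations": {...}, "compilation": {...}, "execute": {...}}}}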
+
+
+@dataclass(frozen=True)
+class ModulePerf:
+ """
+ Data class to store performance metrics for a pipeline module.
+
+ Attributes:
+ module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder')
+ perf: Performance metric in seconds. Can be a single float for modules that run once,
+ or a list of floats for modules that run multiple times (e.g., transformer steps)
+ """
+
+ module_name: str
+ perf: Union[float, List[float]]
+
+
+@dataclass(frozen=True)
+class QEffPipelineOutput:
+ """
+ Data class to store the output of a QEfficient diffusion pipeline.
+
+ Attributes:
+ pipeline_module: List of ModulePerf objects containing performance metrics for each module
+ images: Generated images as either a list of PIL Images or numpy array
+ """
+
+ pipeline_module: List[ModulePerf]
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+ def __repr__(self):
+ output_str = "=" * 60 + "\n"
+ output_str += "QEfficient Diffusers Pipeline Inference Report\n"
+ output_str += "=" * 60 + "\n\n"
+
+ # Module-wise inference times
+ output_str += "Module-wise Inference Times:\n"
+ output_str += "-" * 60 + "\n"
+
+ # Calculate E2E time while iterating
+ e2e_time = 0
+ for module_perf in self.pipeline_module:
+ module_name = module_perf.module_name
+ inference_time = module_perf.perf
+
+ # Add to E2E time
+ e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time
+
+ # Format module name for display
+ display_name = module_name.replace("_", " ").title()
+
+ # Handle transformer specially as it has a list of times
+ if isinstance(inference_time, list) and len(inference_time) > 0:
+ total_time = sum(inference_time)
+ avg_time = total_time / len(inference_time)
+ output_str += f" {display_name:25s} {total_time:.4f} s\n"
+ output_str += f" - Total steps: {len(inference_time)}\n"
+ output_str += f" - Average per step: {avg_time:.4f} s\n"
+ output_str += f" - Min step time: {min(inference_time):.4f} s\n"
+ output_str += f" - Max step time: {max(inference_time):.4f} s\n"
+ else:
+ # Single inference time value
+ output_str += f" {display_name:25s} {inference_time:.4f} s\n"
+
+ output_str += "-" * 60 + "\n\n"
+
+ # Print E2E time after all modules
+ output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n"
+ output_str += "=" * 60 + "\n"
+
+ return output_str
+
+
+# List of module names that require special handling during export
+# when use_onnx_subfunctions is enabled
+ONNX_SUBFUNCTION_MODULE = ["transformer"]
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 8edc1f3f0..16a809c96 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -124,21 +124,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying HuggingFace model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
class MultimodalUtilityMixin:
"""
@@ -701,21 +686,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying vision encoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -869,21 +839,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying language decoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -946,21 +901,6 @@ def __init__(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
"""
@@ -2131,21 +2071,6 @@ def cloud_ai_100_generate(
),
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -2437,21 +2362,6 @@ def __init__(
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying Causal Language Model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 21a867eb5..07b9fe7e1 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -197,6 +197,10 @@
Starcoder2ForCausalLM,
Starcoder2Model,
)
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
@@ -417,6 +421,10 @@
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
+from QEfficient.transformers.models.t5.modeling_t5 import (
+ QEffT5Attention,
+ QEffT5LayerNorm,
+)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
@@ -808,6 +816,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
+class T5ModelTransform(ModuleMappingTransform):
+ # supported architectures
+ _module_mapping = {
+ T5Attention: QEffT5Attention,
+ T5LayerNorm: QEffT5LayerNorm,
+ }
+
+
class PoolingTransform:
"""
Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/transformers/models/t5/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py
new file mode 100644
index 000000000..f54201465
--- /dev/null
+++ b/QEfficient/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,145 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from transformers import EncoderDecoderCache
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerNorm,
+)
+
+
+class QEffT5LayerNorm(T5LayerNorm):
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
+ div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+ variance = div_first.pow(2).sum(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class QEffT5Attention(T5Attention):
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = key_value_states is not None
+
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ # Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
+ if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
+ value_states = curr_past_key_value.layers[self.layer_idx].values
+ else:
+ key_states = self.k(current_states)
+ value_states = self.v(current_states)
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ if position_bias is None:
+ key_length = key_states.shape[-2]
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
+ )
+ if past_key_value is not None:
+ # QEff patch: with a KV cache present, only the bias for the current (last) query position is needed
+ position_bias = position_bias[:, :, -1:, :]
+
+ if mask is not None:
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
+ position_bias = position_bias + causal_mask
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+
+ # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index 49f0ad30b..3d6583f85 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -16,7 +16,6 @@
create_model_params,
custom_format_warning,
dump_qconfig,
- export_wrapper,
generate_mdp_partition_config,
get_num_layers_from_config,
get_num_layers_vlm,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index 131a7fc26..26bae7a34 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -12,7 +12,6 @@
import subprocess
import xml.etree.ElementTree as ET
from dataclasses import dataclass
-from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import requests
@@ -27,9 +26,8 @@
PreTrainedTokenizerFast,
)
-from QEfficient.utils.cache import QEFF_HOME
from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants
-from QEfficient.utils.hash_utils import create_export_hash, json_serializable
+from QEfficient.utils.hash_utils import json_serializable
from QEfficient.utils.logging_utils import logger
@@ -532,61 +530,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
"""
model_params = copy.deepcopy(kwargs)
model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
- model_params["config"] = qeff_model.model.config.to_diff_dict()
model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
model_params["applied_transform_names"] = qeff_model._transform_names()
return model_params
-def export_wrapper(func):
- def wrapper(self, *args, **kwargs):
- export_dir = kwargs.get("export_dir", None)
- parent_dir = self.model_architecture or self.model_name
- export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name))
-
- # PREPROCESSING OF PARAMETERS
-
- # Get the original signature
- original_sig = inspect.signature(func)
-
- # Remove 'self' from parameters
- params = list(original_sig.parameters.values())[1:] # skip 'self'
- new_sig = inspect.Signature(params)
-
- # Bind args and kwargs to the new signature
- bound_args = new_sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
-
- # Get arguments as a dictionary
- all_args = bound_args.arguments
-
- export_hash, filtered_hash_params = create_export_hash(
- model_params=self.hash_params,
- output_names=all_args.get("output_names"),
- dynamic_axes=all_args.get("dynamic_axes"),
- export_kwargs=all_args.get("export_kwargs", None),
- onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
- use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False),
- )
-
- export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
- kwargs["export_dir"] = export_dir
- self.export_hash = export_hash
-
- # _EXPORT CALL
- onnx_path = func(self, *args, **kwargs)
-
- # POST-PROCESSING
- # Dump JSON file with hashed parameters
- hashed_params_export_path = export_dir / "hashed_export_params.json"
- create_json(hashed_params_export_path, filtered_hash_params)
- logger.info("Hashed parameters exported successfully.")
-
- return onnx_path
-
- return wrapper
-
-
def execute_command(process: str, command: str, output_file_path: Optional[str] = None):
"""
Executes the give command using subprocess.
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index e0b003422..613d7049a 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -144,6 +144,13 @@ def get_models_dir():
# Molmo Constants
MOLMO_IMAGE_HEIGHT = 536
MOLMO_IMAGE_WIDTH = 354
+# Flux Transformer Constants
+FLUX_ONNX_EXPORT_SEQ_LENGTH = 256
+FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096
+FLUX_ADALN_HIDDEN_DIM = 3072
+FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context
+FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3  # 3 chunks for the single-block norm
+FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM
class Constants:
diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py
new file mode 100644
index 000000000..eea92a490
--- /dev/null
+++ b/QEfficient/utils/export_utils.py
@@ -0,0 +1,235 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import inspect
+import re
+import warnings
+from pathlib import Path
+from typing import Dict
+
+from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform
+from QEfficient.transformers.cache_utils import InvalidIndexProvider
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
+from QEfficient.utils.cache import QEFF_HOME
+from QEfficient.utils.hash_utils import create_export_hash
+from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
+
+
+def export_wrapper(func):
+ """
+ Decorator for export methods that orchestrates the complete export lifecycle.
+
+ Responsibilities:
+ 1. Prepare export directory structure
+ 2. Generate reproducible hash for export configuration
+ 3. Setup ONNX subfunction environment (if enabled)
+ 4. Execute the wrapped export function
+ 5. Save export metadata
+ 6. Cleanup subfunction environment (if enabled)
+
+ Args:
+ func: The export method to wrap (typically _export)
+
+ Returns:
+ Wrapped function with complete export lifecycle management
+ """
+
+ def wrapper(self, *args, **kwargs):
+ # 1. Prepare export directory
+ export_dir = _prepare_export_directory(self, kwargs)
+
+ # 2. Generate hash and finalize export directory path
+ export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func)
+ export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
+ kwargs["export_dir"] = export_dir
+ self.export_hash = export_hash
+
+ # 3. Setup ONNX subfunctions if requested
+ # TODO: This variable is not needed if export_kwargs already contains the classes (refer to the diffusers implementation)
+ if use_onnx_subfunctions := kwargs.get("use_onnx_subfunctions", False):
+ _setup_onnx_subfunctions(self, kwargs)
+
+ # 4. Execute the actual export
+ onnx_path = func(self, *args, **kwargs)
+
+ # 5. Save export metadata
+ _save_export_metadata(export_dir, filtered_hash_params)
+
+ # 6. Always cleanup subfunctions if they were setup
+ if use_onnx_subfunctions:
+ _cleanup_onnx_subfunctions(self)
+
+ return onnx_path
+
+ return wrapper
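+
+# Illustrative usage (an assumption, mirroring how modeling_qeff.QEFFBaseModel applies it):
+#
+#     class QEFFBaseModel(ABC):
+#         @export_wrapper
+#         def _export(self, example_inputs, output_names, dynamic_axes, ...): ...
+#
+# The wrapper rewrites kwargs["export_dir"] to the hash-suffixed directory before the
+# decorated export runs and writes hashed_export_params.json into that directory.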
+
+
+def _prepare_export_directory(qeff_model, kwargs) -> Path:
+ """
+ Prepare and return the base export directory path.
+
+ Args:
+ qeff_model: The QEff model instance
+ kwargs: Keyword arguments containing optional export_dir
+
+ Returns:
+ Path object for the base export directory
+ """
+ export_dir = kwargs.get("export_dir", None)
+ parent_dir = qeff_model.model_architecture or qeff_model.model_name
+ return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name))
+
+
+def _generate_export_hash(qeff_model, args, kwargs, func):
+ """
+ Generate export hash from model parameters and export arguments.
+
+ The hash ensures reproducibility and prevents conflicts between
+ different export configurations.
+
+ Args:
+ qeff_model: The QEff model instance
+ args: Positional arguments to the export function
+ kwargs: Keyword arguments to the export function
+ func: The export function being wrapped
+
+ Returns:
+ Tuple of (export_hash: str, filtered_hash_params: dict)
+ """
+ # Extract use_onnx_subfunctions before binding (it's used by wrapper, not _export)
+ use_onnx_subfunctions = kwargs.pop("use_onnx_subfunctions", False)
+
+ # Extract function signature
+ original_sig = inspect.signature(func)
+ params = list(original_sig.parameters.values())[1:] # Skip 'self'
+ new_sig = inspect.Signature(params)
+ # Bind all arguments
+ bound_args = new_sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ all_args = bound_args.arguments
+
+ # Use the model's current configuration for hashing to ensure any post-load modifications are captured
+ # TODO: Replace with get_model_config property of modeling classes and remove the if-else
+ # Determine the config dict to use, preferring .to_diff_dict() if available
+
+ if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.config.to_diff_dict()
+ elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"):
+ config_val = qeff_model.model.model.config.to_diff_dict()
+ else:
+ config_val = qeff_model.model.config
+
+ qeff_model.hash_params.update(
+ {
+ "config": config_val,
+ }
+ )
+
+ # Generate hash from relevant parameters
+ export_hash, filtered_hash_params = create_export_hash(
+ model_params=qeff_model.hash_params,
+ output_names=all_args.get("output_names"),
+ dynamic_axes=all_args.get("dynamic_axes"),
+ export_kwargs=all_args.get("export_kwargs", None),
+ onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
+ use_onnx_subfunctions=use_onnx_subfunctions,
+ )
+
+ return export_hash, filtered_hash_params
+
+
+def _setup_onnx_subfunctions(qeff_model, kwargs):
+ """
+ Setup ONNX subfunction export environment.
+
+ This function prepares the model and environment for exporting with
+ ONNX subfunctions enabled. It:
+ - Applies necessary torch patches
+ - Modifies output names for subfunction compatibility
+ - Adds subfunction-specific ONNX transforms
+ - Updates export kwargs with module classes
+
+ Args:
+ qeff_model: The QEff model instance
+ kwargs: Export keyword arguments (modified in-place).
+ """
+ warnings.warn(
+ "The subfunction feature is experimental. Please note that using compile "
+ "consecutively with and without subfunction may produce inconsistent results."
+ )
+
+ # Apply torch patches for subfunction support
+ apply_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = True
+
+ # Store original state for restoration during cleanup
+ qeff_model._original_onnx_transforms = qeff_model._onnx_transforms.copy()
+
+ # Transform output names for subfunction compatibility
+ if "output_names" in kwargs:
+ kwargs["output_names"] = [
+ re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"]
+ ]
+
+ # Add subfunction-specific ONNX transforms
+ qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
+ qeff_model._onnx_transforms.append(CustomOpTransform)
+
+ # Configure export to use modules as functions
+ export_kwargs = kwargs.get("export_kwargs", {})
+
+ # TODO: Handle this in the modelling class QEFFTransformersBase and remove it from here. Refer to the diffusers implementation.
+ export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
+ kwargs["export_kwargs"] = export_kwargs
+
+
+def _cleanup_onnx_subfunctions(qeff_model):
+ """
+ Cleanup ONNX subfunction export environment.
+
+ Restores the model and environment to pre-subfunction state by:
+ - Undoing torch patches
+ - Resetting InvalidIndexProvider flag
+ - Restoring original ONNX transforms list
+
+ Args:
+ qeff_model: The QEff model instance
+
+ Note:
+ This function runs after the wrapped export whenever subfunctions
+ were enabled. Errors during cleanup are logged but not re-raised
+ to avoid masking any earlier exception.
+ """
+ try:
+ # Undo torch patches
+ undo_torch_patches()
+ InvalidIndexProvider.SUBFUNC_ENABLED = False
+
+ # Restore original ONNX transforms
+ if hasattr(qeff_model, "_original_onnx_transforms"):
+ qeff_model._onnx_transforms = qeff_model._original_onnx_transforms
+ delattr(qeff_model, "_original_onnx_transforms")
+
+ except Exception as e:
+ logger.error(f"Error during subfunction cleanup: {e}")
+
+
+def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict):
+ """
+ Save export metadata to JSON file for reproducibility.
+
+ Args:
+ export_dir: Directory where the export was saved
+ filtered_hash_params: Dictionary of parameters used for hashing
+ """
+ # Import here to avoid circular dependency
+ from QEfficient.utils._utils import create_json
+
+ hashed_params_path = export_dir / "hashed_export_params.json"
+ create_json(hashed_params_path, filtered_hash_params)
+ logger.info("Hashed parameters exported successfully.")
diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py
index 948b72e6a..68ccab0d4 100644
--- a/QEfficient/utils/hash_utils.py
+++ b/QEfficient/utils/hash_utils.py
@@ -14,7 +14,8 @@
def json_serializable(obj):
if isinstance(obj, set):
- return sorted(obj)
+ # Convert set to a sorted list of strings for consistent hashing
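+ # e.g. (illustrative) an export_modules_as_functions set such as
+ # {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock} serializes as
+ # ["QEffFluxSingleTransformerBlock", "QEffFluxTransformerBlock"] regardless of set ordering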
+ return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj])
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..9e58da61d
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md
new file mode 100644
index 000000000..2a3c1605f
--- /dev/null
+++ b/examples/diffusers/flux/README.md
@@ -0,0 +1,243 @@
+# FLUX.1-schnell Image Generation Examples
+
+This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs.
+
+## Overview
+
+FLUX.1-schnell is a distilled version of the FLUX.1 text-to-image model, optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation.
+
+## Files
+
+- **`flux_1_schnell.py`** - Basic example showing simple image generation
+- **`flux_1_shnell_custom.py`** - Advanced example with customization options
+- **`flux_config.json`** - Configuration file for pipeline modules
+
+## Quick Start
+
+### Basic Usage
+
+The simplest way to generate images with FLUX.1-schnell:
+
+```python
+from QEfficient import QEffFluxPipeline
+import torch
+
+# Initialize pipeline
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate image
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Save image
+output.images[0].save("girl_laughing.png")
+```
+
+Run the basic example:
+```bash
+python flux_1_schnell.py
+```
+
+## Advanced Customization
+
+The `flux_1_shnell_custom.py` example demonstrates several advanced features:
+
+### 1. Custom Model Components
+
+You can provide custom text encoders, transformers, and tokenizers:
+
+```python
+pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ text_encoder=custom_text_encoder,
+ transformer=custom_transformer,
+ tokenizer=custom_tokenizer,
+)
+```
+
+### 2. Custom Scheduler
+
+Replace the default scheduler with your own:
+
+```python
+pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+```
+
+### 3. Reduce Model Layers for Faster Inference
+
+Trade quality for speed by reducing transformer blocks:
+
+```python
+original_blocks = pipeline.transformer.model.transformer_blocks
+org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+pipeline.transformer.model.config['num_layers'] = 1
+pipeline.transformer.model.config['num_single_layers'] = 1
+```
+
+### 4. Pre-compile with Custom Configuration
+
+Compile the model separately before generation:
+
+```python
+pipeline.compile(
+ compile_config="examples/diffusers/flux/flux_config.json",
+ height=512,
+ width=512,
+ use_onnx_subfunctions=False
+)
+```
+
+### 5. Runtime Configuration
+
+Use custom configuration during generation:
+
+```python
+output = pipeline(
+ prompt="A girl laughing",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+```
+
+Run the advanced example:
+```bash
+python flux_1_schnell_custom.py
+```
+
+## Configuration File
+
+The `flux_config.json` file controls compilation and execution settings for each pipeline module:
+
+### Module Structure
+
+The configuration includes four main modules:
+
+1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence)
+2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence)
+3. **transformer** - Core diffusion transformer model
+4. **vae_decoder** - Decodes latents to images
+
+### Configuration Parameters
+
+Each module has three sections:
+
+#### Specializations
+- `batch_size`: Batch size for inference
+- `seq_len`: Sequence length for text encoders
+- `steps`: Number of inference steps (transformer only)
+- `channels`: Number of channels (VAE decoder only)
+
+#### Compilation
+- `onnx_path`: Path to pre-exported ONNX model (null for auto-export)
+- `compile_dir`: Directory for compiled artifacts (null for auto-generation)
+- `mdp_ts_num_devices`: Number of devices for model data parallelism
+- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication
+- `convert_to_fp16`: Convert model to FP16 precision
+- `aic_num_cores`: Number of AI cores to use
+- `mos`: Multi-output streaming (transformer only)
+- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only)
+- `aic-enable-depth-first`: Enable depth-first compilation (VAE only)
+
+#### Execute
+- `device_ids`: List of device IDs to use (null for auto-selection)
+
+### Example Configuration Snippet
+
+```json
+{
+ "transformer": {
+ "specializations": {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation": {
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute": {
+ "device_ids": null
+ }
+ }
+}
+```
+
+## Key Parameters
+
+### Generation Parameters
+
+- **`prompt`** (str): Text description of the image to generate
+- **`height`** (int): Output image height in pixels (default: 1024)
+- **`width`** (int): Output image width in pixels (default: 1024)
+- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell)
+- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell)
+- **`max_sequence_length`** (int): Maximum text sequence length (256 recommended)
+- **`generator`** (torch.Generator): Random seed for reproducibility
+- **`parallel_compile`** (bool): Enable parallel compilation of modules
+- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental)
+
+### Performance Tuning
+
+- **Faster inference**: Reduce `num_inference_steps` or model layers (see the sketch below)
+- **Better quality**: Increase `num_inference_steps` or use full model
+- **Memory optimization**: Adjust `mdp_ts_num_devices` in config
+- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16`
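+
+For example (a minimal sketch; the exact speed/quality balance depends on your hardware and prompts), the step count is the simplest knob to turn:
+
+```python
+# Speed-biased run: the 4-step schedule recommended for FLUX.1-schnell.
+# Raising num_inference_steps (e.g. to 8) trades latency for quality.
+output = pipeline(
+    prompt="A laughing girl",
+    num_inference_steps=4,
+    guidance_scale=0.0,
+    max_sequence_length=256,
+    generator=torch.manual_seed(42),
+)
+```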
+
+## Output
+
+The pipeline returns an output object containing:
+- `images`: List of generated PIL Image objects
+- Performance metrics (timing information)
+
+Example output:
+```python
+print(output) # Displays performance information
+image = output.images[0] # Access the generated image
+image.save("output.png") # Save to disk
+```
+
+## Hardware Requirements
+
+- Qualcomm Cloud AI 100 accelerator
+- Sufficient memory for model compilation and execution
+- Multiple devices recommended for optimal transformer performance (see `mdp_ts_num_devices`)
+
+## Notes
+
+- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0`
+- The transformer module benefits most from multi-device parallelism
+- ONNX subfunction export (`use_onnx_subfunctions=True`) is experimental; it may improve compile time but is not recommended for production use
+- Custom configurations allow fine-tuning for specific hardware setups
+
+## Troubleshooting
+
+- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices`
+- **Slow compilation**: Enable `parallel_compile=True`
+- **Quality issues**: Ensure using recommended parameters (4 steps, guidance_scale=0.0)
+- **Device errors**: Check `device_ids` in config or set to `null` for auto-selection
+
+## References
+
+- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+- [QEfficient Documentation](../../../README.md)
+- [Diffusers Pipeline Guide](../../README.md)
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..46f26bb6b
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,45 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1-schnell Image Generation Example
+
+This example demonstrates how to use the QEffFluxPipeline to generate images
+using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a
+fast, distilled version of the FLUX.1 text-to-image model optimized for
+speed with minimal quality loss.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# Initialize the FLUX.1-schnell pipeline from pretrained weights
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate an image from a text prompt
+# use_onnx_subfunctions enables experimental modular ONNX export; it is left disabled here
+output = pipeline(
+ prompt="A laughing girl",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+# Extract the generated image from the output
+image = output.images[0]
+
+# Save the generated image to disk
+image.save("girl_laughing.png")
+
+# Print the output object (contains perf info)
+print(output)
diff --git a/examples/diffusers/flux/flux_1_schnell_custom.py b/examples/diffusers/flux/flux_1_schnell_custom.py
new file mode 100644
index 000000000..201ebe659
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell_custom.py
@@ -0,0 +1,113 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1 Schnell Custom Configuration Example
+
+This example demonstrates how to customize the FLUX.1 model with various options:
+1. Custom image dimensions (height/width)
+2. Custom transformer model and text encoder
+3. Custom scheduler configuration
+4. Reduced model layers for faster inference
+5. Custom compilation settings
+6. Custom runtime configuration via JSON config file
+
+Use this example to learn how to fine-tune FLUX.1 for your specific needs.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+
+# Option 1: Basic initialization with default parameters
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+# Option 2: Advanced initialization with custom modules
+# Uncomment and modify to use your own custom components:
+#
+# pipeline = QEffFluxPipeline.from_pretrained(
+# "black-forest-labs/FLUX.1-schnell",
+# text_encoder=custom_text_encoder, # Your custom CLIP text encoder
+# transformer=custom_transformer, # Your custom transformer model
+# tokenizer=custom_tokenizer, # Your custom tokenizer
+# )
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up image generation.
+#
+# Trade-off: Faster inference but potentially lower image quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only the first transformer block:
+#
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# pipeline.transformer.model.config['num_layers'] = 1
+# pipeline.transformer.model.config['num_single_layers'] = 1
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip image generation and only prepare the model
+#
+# NOTE-1: If compile_config is not specified, the default configuration from
+# QEfficient/diffusers/pipelines/flux/flux_config.json will be used
+#
+# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export (experimental, so not recommended for production).
+#         It breaks the model into smaller, more manageable ONNX functions, which can improve compile time.
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(
+# compile_config="examples/diffusers/flux/flux_config.json",
+# height=512,
+# width=512,
+# use_onnx_subfunctions=False
+# )
+
+# ============================================================================
+# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate an image using the configured pipeline.
+#
+# Note: custom_config_path lets you set device_ids for each module at runtime,
+# so the separate pipeline.compile() step can be skipped.
+
+output = pipeline(
+ prompt="A laughing girl",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ height=1024,
+ width=1024,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+ parallel_compile=True,
+ use_onnx_subfunctions=False,
+)
+
+image = output.images[0]
+# Save the generated image to disk
+image.save("laughing_girl.png")
+print(output)
diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json
new file mode 100644
index 000000000..73b92265f
--- /dev/null
+++ b/examples/diffusers/flux/flux_config.json
@@ -0,0 +1,99 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "compile_only": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "aic-enable-depth-first": true,
+ "compile_only":true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/pyproject.toml b/pyproject.toml
index 8e179ab4a..fe0c42ec2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,9 +20,10 @@ classifiers = [
requires-python = ">=3.8,<3.11"
dependencies = [
"transformers==4.55.0",
+    "diffusers==0.35.1",
"huggingface-hub==0.34.0",
"hf_transfer==0.1.9",
- "peft==0.13.2",
+ "peft==0.17.0",
"datasets==2.20.0",
"fsspec==2023.6.0",
"multidict==6.0.4",
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index 134770638..d878076fa 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -22,6 +22,7 @@ pipeline {
. preflight_qeff/bin/activate &&
pip install --upgrade pip setuptools &&
pip install .[test] &&
+ pip install .[diffusers] &&
pip install junitparser pytest-xdist &&
pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing
pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs
@@ -69,7 +70,7 @@ pipeline {
}
stage('QAIC MultiModal Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
@@ -86,7 +87,7 @@ pipeline {
}
stage('Inference Tests') {
steps {
- timeout(time: 60, unit: 'MINUTES') {
+ timeout(time: 120, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
#source /qnn_sdk/bin/envsetup.sh &&
@@ -162,7 +163,7 @@ pipeline {
// }
stage('Finetune CLI Tests') {
steps {
- timeout(time: 5, unit: 'MINUTES') {
+ timeout(time: 20, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py
index d1b7a4653..f63b18f1a 100644
--- a/tests/base/test_export_memory_offload.py
+++ b/tests/base/test_export_memory_offload.py
@@ -27,7 +27,7 @@
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py
new file mode 100644
index 000000000..305116c03
--- /dev/null
+++ b/tests/diffusers/diffusers_utils.py
@@ -0,0 +1,175 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+Common utilities for diffusion pipeline testing.
+Provides essential functions for MAD validation, image validation, and other testing utilities.
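+
+A minimal usage sketch (illustrative tensors and values):
+
+    validator = MADValidator(tolerances={"transformer": 2.0})
+    validator.validate_module_mad(pytorch_out, qaic_out, "transformer", "step 0")
+    DiffusersTestUtils.validate_image_generation(image, expected_size=(256, 256))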
+"""
+
+import os
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+
+class DiffusersTestUtils:
+ """Essential utilities for diffusion pipeline testing"""
+
+ @staticmethod
+ def validate_image_generation(
+ image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0
+ ) -> Dict[str, Any]:
+ """
+ Validate generated image properties.
+ Args:
+ image: Generated PIL Image
+ expected_size: Expected (width, height) tuple
+ min_variance: Minimum pixel variance to ensure image is not blank
+
+ Returns:
+ Dict containing validation results
+ Raises:
+ AssertionError: If image validation fails
+ """
+ # Basic image validation
+ assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}"
+ assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}"
+ assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}"
+
+ # Variance check (ensure image is not blank)
+ img_array = np.array(image)
+ image_variance = float(img_array.std())
+ assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})"
+
+ return {
+ "size": image.size,
+ "mode": image.mode,
+ "variance": image_variance,
+ "mean_pixel_value": float(img_array.mean()),
+ "min_pixel": int(img_array.min()),
+ "max_pixel": int(img_array.max()),
+ "valid": True,
+ }
+
+ @staticmethod
+ def check_file_exists(file_path: str, file_type: str = "file") -> bool:
+ """
+ Check if file exists and log result.
+ Args:
+ file_path: Path to check
+ file_type: Description of file type for logging
+ Returns:
+ bool: True if file exists
+ """
+ exists = os.path.exists(file_path)
+        status = "✅" if exists else "❌"
+ print(f"{status} {file_type}: {file_path}")
+ return exists
+
+ @staticmethod
+ def print_test_header(title: str, config: Dict[str, Any]) -> None:
+ """
+ Print formatted test header with configuration details.
+
+ Args:
+ title: Test title
+ config: Test configuration dictionary
+ """
+ print(f"\n{'=' * 80}")
+ print(f"{title}")
+ print(f"{'=' * 80}")
+
+ if "model_setup" in config:
+ setup = config["model_setup"]
+ for k, v in setup.items():
+ print(f"{k} : {v}")
+
+ if "functional_testing" in config:
+ func = config["functional_testing"]
+ print(f"Test Prompt: {func.get('test_prompt', 'N/A')}")
+ print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}")
+ print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}")
+
+ print(f"{'=' * 80}")
+
+
+class MADValidator:
+    """Specialized class for MAD validation - always enabled, always reports, and fails when tolerance is exceeded"""
+
+ def __init__(self, tolerances: Dict[str, float] = None):
+ """
+ Initialize MAD validator.
+ MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded.
+
+ Args:
+ tolerances: Dictionary of module_name -> tolerance mappings
+ """
+        self.tolerances = tolerances or {}  # fall back to an empty dict so per-module .get() lookups are safe
+ self.results = {}
+
+ def calculate_mad(
+ self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray]
+ ) -> float:
+ """
+ Calculate Max Absolute Deviation between two tensors.
+
+ Args:
+ tensor1: First tensor (PyTorch or NumPy)
+ tensor2: Second tensor (PyTorch or NumPy)
+
+ Returns:
+ float: Maximum absolute difference between tensors
+ """
+ if isinstance(tensor1, torch.Tensor):
+ tensor1 = tensor1.detach().numpy()
+ if isinstance(tensor2, torch.Tensor):
+ tensor2 = tensor2.detach().numpy()
+
+ return float(np.max(np.abs(tensor1 - tensor2)))
+
+ def validate_module_mad(
+ self,
+ pytorch_output: Union[torch.Tensor, np.ndarray],
+ qaic_output: Union[torch.Tensor, np.ndarray],
+ module_name: str,
+ step_info: str = "",
+ ) -> bool:
+ """
+ Validate MAD for a specific module.
+ Always validates, always reports, always fails if tolerance exceeded.
+
+ Args:
+ pytorch_output: PyTorch reference output
+ qaic_output: QAIC inference output
+ module_name: Name of the module
+ step_info: Additional step information for logging
+
+ Returns:
+ bool: True if validation passed
+
+ Raises:
+ AssertionError: If MAD exceeds tolerance
+ """
+ mad_value = self.calculate_mad(pytorch_output, qaic_output)
+
+ # Always report MAD value
+ step_str = f" {step_info}" if step_info else ""
+        print(f"{module_name.upper()} MAD{step_str}: {mad_value:.8f}")
+
+ # Always validate - fail if tolerance exceeded
+ tolerance = self.tolerances.get(module_name, 1e-2)
+ if mad_value > tolerance:
+ raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}")
+
+ # Store result
+ if module_name not in self.results:
+ self.results[module_name] = []
+ self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance})
+ return True
diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json
new file mode 100644
index 000000000..7d0c17d55
--- /dev/null
+++ b/tests/diffusers/flux_test_config.json
@@ -0,0 +1,123 @@
+{
+ "model_setup": {
+ "height": 256,
+ "width": 256,
+ "num_transformer_layers": 2,
+ "num_single_layers": 2,
+ "use_onnx_subfunctions": false
+ },
+ "mad_validation": {
+ "tolerances": {
+ "clip_text_encoder": 0.1,
+ "t5_text_encoder": 5.5,
+ "transformer": 2.0,
+ "vae_decoder": 1.0
+ }
+ },
+ "pipeline_params": {
+ "test_prompt": "A cat holding a sign that says hello world",
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "validate_gen_img": true,
+ "min_image_variance": 1.0,
+ "custom_config_path": null
+ },
+ "validation_checks": {
+ "image_generation": true,
+ "onnx_export": true,
+ "compilation": true
+ },
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1,
+ "aic-enable-depth-first": true
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+
+}
diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py
new file mode 100644
index 000000000..6f4396a20
--- /dev/null
+++ b/tests/diffusers/test_flux.py
@@ -0,0 +1,448 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import pytest
+import torch
+from diffusers import FluxPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+from QEfficient import QEffFluxPipeline
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ModulePerf,
+ QEffPipelineOutput,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils._utils import load_json
+from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator
+
+# Test configuration for 256x256 resolution with 2 transformer layers
+CONFIG_PATH = "tests/diffusers/flux_test_config.json"
+INITIAL_TEST_CONFIG = load_json(CONFIG_PATH)
+
+
+def flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height: int = 256,
+ width: int = 256,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ mad_tolerances: Dict[str, float] = None,
+):
+ """
+    Pipeline call that mirrors the flow of QEffFluxPipeline.__call__() (pipeline_flux.py)
+    while adding MAD validation hooks at each step, so the QAIC outputs of every module
+    can be compared against their PyTorch reference counterparts.
+ """
+ # Initialize MAD validator
+ mad_validator = MADValidator(tolerances=mad_tolerances)
+
+ device = "cpu"
+
+ # Step 1: Load configuration, compile models
+ pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
+
+ # Set device IDs for all modules based on configuration
+ set_module_device_ids(pipeline)
+
+ # Validate all inputs
+ pipeline.model.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Set pipeline attributes
+ pipeline._guidance_scale = guidance_scale
+ pipeline._interrupt = False
+ batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"]
+
+ # Step 3: Encode prompts with both text encoders
+ # Use pipeline's encode_prompt method
+ (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ # Deactivate text encoder qpc sessions
+ pipeline.text_encoder.qpc_session.deactivate()
+ pipeline.text_encoder_2.qpc_session.deactivate()
+
+ # MAD Validation for Text Encoders
+    print("Performing MAD validation for text encoders...")
+ mad_validator.validate_module_mad(
+ clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder"
+ )
+ mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder")
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
+ pipeline._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = pipeline.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = pipeline.model.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ t5_qaic_prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Initialize transformer inference session
+ if pipeline.transformer.qpc_session is None:
+ pipeline.transformer.qpc_session = QAICInferenceSession(
+ str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids
+ )
+
+ # Calculate compressed latent dimension (cl) for transformer buffer allocation
+ from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension
+
+ cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor)
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, cl, pipeline.transformer.model.config.in_channels).astype(np.float32),
+ }
+ pipeline.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ pipeline.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if pipeline._interrupt:
+ continue
+
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds)
+
+ # Compute AdaLN embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.transformer_blocks)):
+ block = pipeline.transformer.model.transformer_blocks[block_idx]
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)):
+ block = pipeline.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = pipeline.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(),
+ "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # MAD Validation for Transformer - PyTorch reference inference
+ noise_pred_torch = pytorch_pipeline.transformer(
+ hidden_states=latents,
+ encoder_hidden_states=t5_torch_prompt_embeds,
+ pooled_projections=clip_torch_pooled_prompt_embeds,
+ timestep=torch.tensor(timestep),
+ img_ids=latent_image_ids,
+ txt_ids=text_ids,
+ return_dict=False,
+ )[0]
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.time()
+ outputs = pipeline.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.time()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Transformer MAD validation
+ mad_validator.validate_module_mad(
+ noise_pred_torch.detach().cpu().numpy(),
+ outputs["output"],
+ "transformer",
+ f"step {i} (t={t.item():.1f})",
+ )
+
+ # Update latents using scheduler
+ latents_dtype = latents.dtype
+ latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images
+ if output_type == "latent":
+ image = latents
+ vae_decode_perf = 0.0 # No VAE decoding for latent output
+ else:
+ # Unpack and denormalize latents
+ latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor)
+
+ # Denormalize latents
+ latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor
+ # Initialize VAE decoder inference session
+ if pipeline.vae_decode.qpc_session is None:
+ pipeline.vae_decode.qpc_session = QAICInferenceSession(
+ str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)}
+ pipeline.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # MAD Validation for VAE
+ # PyTorch reference inference
+ image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0]
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.time()
+ image = pipeline.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.time()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # VAE MAD validation
+ mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder")
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
+
+
+@pytest.fixture(scope="session")
+def flux_pipeline():
+ """Setup compiled Flux pipeline for testing"""
+ config = INITIAL_TEST_CONFIG["model_setup"]
+
+ pipeline = QEffFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ use_onnx_subfunctions=config["use_onnx_subfunctions"],
+ )
+
+ # Reduce to 2 layers for testing
+ original_blocks = pipeline.transformer.model.transformer_blocks
+ org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+
+ pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"]
+ pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"]
+ pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
+
+ ### Pytorch pipeline
+ pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+ original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks
+ org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks
+ pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList(
+ [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
+ )
+ pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList(
+ [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
+ )
+ return pipeline, pytorch_pipeline
+
+
+@pytest.mark.diffusion_models
+@pytest.mark.on_qaic
+def test_flux_pipeline(flux_pipeline):
+ """
+ Comprehensive Flux pipeline test that follows the exact same flow as pipeline_flux.py:
+ - 256x256 resolution - 2 transformer layers
+ - MAD validation
+ - Functional image generation test
+ - Export/compilation checks
+ - Returns QEffPipelineOutput with performance metrics
+ """
+ pipeline, pytorch_pipeline = flux_pipeline
+ config = INITIAL_TEST_CONFIG
+
+ # Print test header
+ DiffusersTestUtils.print_test_header(
+ f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers",
+ config,
+ )
+
+ # Test parameters
+ test_prompt = config["pipeline_params"]["test_prompt"]
+ num_inference_steps = config["pipeline_params"]["num_inference_steps"]
+ guidance_scale = config["pipeline_params"]["guidance_scale"]
+ max_sequence_length = config["pipeline_params"]["max_sequence_length"]
+
+ # Generate with MAD validation
+ generator = torch.manual_seed(42)
+ start_time = time.time()
+
+ try:
+ # Run the pipeline with integrated MAD validation (follows exact pipeline flow)
+ result = flux_pipeline_call_with_mad_validation(
+ pipeline,
+ pytorch_pipeline,
+ height=config["model_setup"]["height"],
+ width=config["model_setup"]["width"],
+ prompt=test_prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ max_sequence_length=max_sequence_length,
+ custom_config_path=CONFIG_PATH,
+ generator=generator,
+ mad_tolerances=config["mad_validation"]["tolerances"],
+ parallel_compile=True,
+ return_dict=True,
+ )
+
+ execution_time = time.time() - start_time
+
+ # Validate image generation
+ if config["pipeline_params"]["validate_gen_img"]:
+ assert result is not None, "Pipeline returned None"
+ assert hasattr(result, "images"), "Result missing 'images' attribute"
+ assert len(result.images) > 0, "No images generated"
+
+ generated_image = result.images[0]
+            expected_size = (config["model_setup"]["width"], config["model_setup"]["height"])  # PIL size is (width, height)
+ # Validate image properties using utilities
+ image_validation = DiffusersTestUtils.validate_image_generation(
+ generated_image, expected_size, config["pipeline_params"]["min_image_variance"]
+ )
+
+            print("\n✅ IMAGE VALIDATION PASSED")
+ print(f" - Size: {image_validation['size']}")
+ print(f" - Mode: {image_validation['mode']}")
+ print(f" - Variance: {image_validation['variance']:.2f}")
+ print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}")
+ file_path = "test_flux_256x256_2layers.png"
+ # Save test image
+ generated_image.save(file_path)
+
+ if os.path.exists(file_path):
+ print(f"Image saved successfully at: {file_path}")
+ else:
+ print("Image was not saved.")
+
+ if config["validation_checks"]["onnx_export"]:
+ # Check if ONNX files exist (basic check)
+            print("\nONNX Export Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX")
+
+ if config["validation_checks"]["compilation"]:
+ # Check if QPC files exist (basic check)
+            print("\nCompilation Validation:")
+ for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]:
+ module_obj = getattr(pipeline, module_name, None)
+ if module_obj and hasattr(module_obj, "qpc_path") and module_obj.qpc_path:
+ DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC")
+
+ # Print test summary using utilities
+ print(f"\nTotal execution time: {execution_time:.4f}s")
+ except Exception as e:
+ print(f"\nTEST FAILED: {e}")
+ raise
+
+
+if __name__ == "__main__":
+    # This allows running the test file directly for debugging
+    pytest.main([__file__, "-v", "-s", "-m", "diffusion_models"])
+# pytest tests/diffusers/test_flux.py -m diffusion_models -v -s --tb=short
diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py
index 0810ac6ba..3eaaf0f69 100644
--- a/tests/transformers/test_causal_lm.py
+++ b/tests/transformers/test_causal_lm.py
@@ -211,7 +211,7 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py
index 59281b73b..bc53cb539 100644
--- a/tests/transformers/test_speech_seq2seq.py
+++ b/tests/transformers/test_speech_seq2seq.py
@@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path):
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
- monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+ monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
yield tmp_path
diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py
index fefa73973..b7a5495c6 100644
--- a/tests/utils/test_hash_utils.py
+++ b/tests/utils/test_hash_utils.py
@@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value):
def test_json_serializable():
# Test with a set
- assert json_serializable({1, 2, 3}) == [1, 2, 3]
+ assert json_serializable({1, 2, 3}) == ["1", "2", "3"]
# Test with an unsupported type
with pytest.raises(TypeError):
json_serializable({1, 2, 3, {4, 5}})