diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 7f63b34ca..2d8f72e0a 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -18,6 +18,7 @@ QEFFCommonLoader, ) from QEfficient.compile.compile_helper import compile +from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -39,6 +40,7 @@ "QEFFAutoModelForImageTextToText", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", + "QEffFluxPipeline", ] # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ef7e83adf..ea347016b 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,7 +8,6 @@ import gc import inspect import logging -import re import shutil import subprocess import warnings @@ -21,26 +20,21 @@ from QEfficient.base.onnx_transforms import ( BaseOnnxTransform, - CustomOpTransform, OnnxTransformPipeline, - RenameFunctionOutputsTransform, ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.transformers.cache_utils import InvalidIndexProvider -from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import ( constants, create_json, create_model_params, dump_qconfig, - export_wrapper, generate_mdp_partition_config, hash_dict_params, load_json, ) -from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches +from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -125,9 +119,35 @@ def _model_offloaded_check(self) -> None: logger.error(error_msg) raise RuntimeError(error_msg) + @property + def model_name(self) -> str: + """ + Get the model class name without QEff/QEFF prefix. + + This property extracts the underlying model's class name and removes + any QEff or QEFF prefix that may have been added during wrapping. + + Returns: + str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel") + """ + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + @property @abstractmethod - def model_name(self) -> str: ... + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + This is an abstract property that must be implemented by all subclasses. 
+ Typically returns: self.model.config.__dict__ + + Returns: + Dict: The configuration dictionary of the underlying model + """ + pass @abstractmethod def export(self, export_dir: Optional[str] = None) -> Path: @@ -188,7 +208,6 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, - use_onnx_subfunctions: bool = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -253,18 +272,8 @@ def _export( input_names.append(param) try: - # Initialize the registry with your custom ops + # Export to ONNX export_kwargs = {} if export_kwargs is None else export_kwargs - if use_onnx_subfunctions: - warnings.warn( - "The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results." - ) - apply_torch_patches() - InvalidIndexProvider.SUBFUNC_ENABLED = True - output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names] - export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model) - self._onnx_transforms.append(RenameFunctionOutputsTransform) - self._onnx_transforms.append(CustomOpTransform) torch.onnx.export( self.model, @@ -309,12 +318,6 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - if use_onnx_subfunctions: - undo_torch_patches() - InvalidIndexProvider.SUBFUNC_ENABLED = False - self._onnx_transforms.remove(CustomOpTransform) - self._onnx_transforms.remove(RenameFunctionOutputsTransform) - self.onnx_path = onnx_path return onnx_path diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md new file mode 100644 index 000000000..40d45e984 --- /dev/null +++ b/QEfficient/diffusers/README.md @@ -0,0 +1,95 @@ + +
+ + +# **Diffusion Models on Qualcomm Cloud AI 100** + + +
+ +### šŸŽØ **Experience the Future of AI Image Generation** + +*Optimized for Qualcomm Cloud AI 100* + +*(Sample output image)* + +**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale • ⚔ + + +
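The sample above can be reproduced with a short script along the lines of the following sketch (the output filename and 512Ɨ512 size are illustrative; on first use the pipeline exports each module to ONNX and compiles QPCs for Cloud AI 100 before running inference):

```python
from QEfficient import QEffFluxPipeline

# Load the FLUX.1-schnell weights and wrap them with QEfficient-optimized modules
pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# First call triggers export + compilation for Cloud AI 100, then runs the denoising loop
result = pipeline(
    prompt="A girl laughing",
    height=512,
    width=512,
    num_inference_steps=4,  # FLUX.1-schnell is tuned for few-step sampling
    guidance_scale=0.0,
)
result.images[0].save("sample_output.png")
```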
+ + + +[![Diffusers](https://img.shields.io/badge/Diffusers-0.35.1-orange.svg)](https://github.com/huggingface/diffusers) +
+ +--- + +## ✨ Overview + +QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware. + +## šŸ› ļø Installation + +### Prerequisites + +Ensure you have Python 3.8+ and the required dependencies: + +```bash +# Create Python virtual environment (Recommended Python 3.10) +sudo apt install python3.10-venv +python3.10 -m venv qeff_env +source qeff_env/bin/activate +pip install -U pip +``` + +### Install QEfficient + +```bash +# Install from GitHub (includes diffusers support) +pip install git+https://github.com/quic/efficient-transformers + +# Or build from source +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install build wheel +python -m build --wheel --outdir dist +pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl +``` + +--- + +## šŸŽÆ Supported Models +- āœ… [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) + +--- + + +## šŸ“š Examples + +Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory: + +--- + +## šŸ¤ Contributing + +We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details. + + + +--- + +## šŸ™ Acknowledgments + +- **HuggingFace Diffusers**: For the excellent foundation library +- **Stability AI**: For the amazing Stable Diffusion models +--- + +## šŸ“ž Support + +- šŸ“– **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/) +- šŸ› **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues) + +--- + diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/models/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py new file mode 100644 index 000000000..933832ed8 --- /dev/null +++ b/QEfficient/diffusers/models/normalization.py @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from typing import Optional, Tuple + +import torch +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +class QEffAdaLayerNormZero(AdaLayerNormZero): + def forward( + self, + x: torch.Tensor, + shift_msa: Optional[torch.Tensor] = None, + scale_msa: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + + +class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle): + def forward( + self, + x: torch.Tensor, + scale_msa: Optional[torch.Tensor] = None, + shift_msa: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + + +class QEffAdaLayerNormContinuous(AdaLayerNormContinuous): + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = conditioning_embedding + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py new file mode 100644 index 000000000..d3c84ee63 --- /dev/null +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -0,0 +1,56 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm +from diffusers.models.transformers.transformer_flux import ( + FluxAttention, + FluxAttnProcessor, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, +) +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.customop.rms_norm import CustomRMSNormAIC +from QEfficient.diffusers.models.normalization import ( + QEffAdaLayerNormContinuous, + QEffAdaLayerNormZero, + QEffAdaLayerNormZeroSingle, +) +from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxAttention, + QEffFluxAttnProcessor, + QEffFluxSingleTransformerBlock, + QEffFluxTransformer2DModel, + QEffFluxTransformerBlock, +) + + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + RMSNorm: CustomRMSNormAIC, + nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm + } + + +class AttentionTransform(ModuleMappingTransform): + _module_mapping = { + FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock, + FluxTransformerBlock: QEffFluxTransformerBlock, + FluxTransformer2DModel: QEffFluxTransformer2DModel, + FluxAttention: QEffFluxAttention, + FluxAttnProcessor: QEffFluxAttnProcessor, + } + + +class NormalizationTransform(ModuleMappingTransform): + _module_mapping = { + AdaLayerNormZero: QEffAdaLayerNormZero, + AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle, + AdaLayerNormContinuous: QEffAdaLayerNormContinuous, + } diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ 
b/QEfficient/diffusers/models/transformers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py new file mode 100644 index 000000000..5cb44af45 --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -0,0 +1,327 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_flux import ( + FluxAttention, + FluxAttnProcessor, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, + _get_qkv_projections, +) + +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_emb( + x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = cos.to(x.device), sin.to(x.device) + B, S, H, D = x.shape + x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +class QEffFluxAttnProcessor(FluxAttnProcessor): + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "QEffFluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = qeff_apply_rotary_emb(query, image_rotary_emb) + key = qeff_apply_rotary_emb(key, image_rotary_emb) + + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class QEffFluxAttention(FluxAttention): + def __qeff_init__(self): + processor = QEffFluxAttnProcessor() + self.processor = processor + + +class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + shift_msa, scale_msa, gate = torch.split(temb, 1) + residual = hidden_states + norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + 
hidden_states = residual + hidden_states + # if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformerBlock(FluxTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + temb1 = tuple(torch.split(temb[:6], 1)) + temb2 = tuple(torch.split(temb[6:], 1)) + norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1]) + gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:] + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1]) + + c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:] + + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + # if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformer2DModel(FluxTransformer2DModel): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + adaln_emb: torch.Tensor = None, + adaln_single_emb: torch.Tensor = None, + adaln_out: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/QEfficient/diffusers/pipelines/configs/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 000000000..511746469 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,854 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +# TODO: Pipeline Architecture Improvements +# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile, +# and inference APIs across all diffusion pipelines, promoting code reusability +# and consistent interface design. +# 2. Implement persistent QPC session management strategy to retain/drop compiled model +# sessions in memory across all pipeline modules. 
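The per-module configuration in `flux_config.json` above defines `specializations`, `compilation`, and `execute` settings that drive export and compilation. A minimal sketch of starting from that default, adjusting a couple of knobs, and handing the result to `compile()` follows; the tweaked values and the output path are illustrative, and the default config path is relative to the repository root:

```python
import json

from QEfficient import QEffFluxPipeline

pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# Load the bundled default configuration (relative path, so run from the repo checkout)
with open(QEffFluxPipeline.get_default_config_path()) as f:
    config = json.load(f)

# Illustrative overrides: spread the transformer across 8 devices, pin the VAE decoder to device 0
config["modules"]["transformer"]["compilation"]["mdp_ts_num_devices"] = 8
config["modules"]["vae_decoder"]["execute"]["device_ids"] = [0]

with open("custom_flux_config.json", "w") as f:
    json.dump(config, f, indent=2)

# Compile the text encoders, transformer, and VAE decoder with the customized settings
pipeline.compile(compile_config="custom_flux_config.json", parallel=True, height=512, width=512)
```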
+ +import os +import time +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from tqdm import tqdm + +from QEfficient.diffusers.pipelines.pipeline_module import ( + QEffFluxTransformerModel, + QEffTextEncoder, + QEffVAE, +) +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + calculate_compressed_latent_dimension, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class QEffFluxPipeline: + """ + QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware. + + This pipeline provides an optimized implementation of the Flux diffusion model specifically designed + for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX format and compiled + into Qualcomm Program Container (QPC) files for efficient inference. + + The pipeline supports the complete Flux workflow including: + - Dual text encoding with CLIP and T5 encoders + - Transformer-based denoising with adaptive layer normalization + - VAE decoding for final image generation + - Performance monitoring and optimization + + Attributes: + text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings + text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings + transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising + vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion + modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations + model (FluxPipeline): Original HuggingFace Flux model reference + tokenizer: CLIP tokenizer for text preprocessing + scheduler: Diffusion scheduler for timestep management + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> images = pipeline( + ... prompt="A beautiful sunset over mountains", + ... height=512, + ... width=512, + ... num_inference_steps=28 + ... ) + >>> images.images[0].save("generated_image.png") + """ + + _hf_auto_class = FluxPipeline + + def __init__(self, model, *args, **kwargs): + """ + Initialize the QEfficient Flux pipeline. + + This pipeline provides an optimized implementation of the Flux text-to-image model + for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX and compiled + for QAIC devices. 
+ + Args: + model: Pre-loaded FluxPipeline model + **kwargs: Additional arguments including height and width + """ + + # Wrap model components with QEfficient optimized versions + self.model = model + self.text_encoder = QEffTextEncoder(model.text_encoder) + self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2) + self.transformer = QEffFluxTransformerModel(model.transformer) + self.vae_decode = QEffVAE(model.vae, "decoder") + + # Store all modules in a dictionary for easy iteration during export/compile + self.modules = { + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "transformer": self.transformer, + "vae_decoder": self.vae_decode, + } + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.text_encoder_2.tokenizer = model.tokenizer_2 + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + # Override VAE forward method to use decode directly + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + # Sync max position embeddings between text encoders + self.text_encoder_2.model.config.max_position_embeddings = ( + self.text_encoder.model.config.max_position_embeddings + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **kwargs, + ): + """ + Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations. + + This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained + Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU + and wraps all components with QEfficient-optimized versions for QAIC deployment. + + Args: + pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier + (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory. + **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained(). + + Returns: + QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components + ready for export, compilation, and inference on QAIC devices. + + Raises: + ValueError: If the model path is invalid or model cannot be loaded + OSError: If there are issues accessing the model files + RuntimeError: If model initialization fails + + Example: + >>> # Load from HuggingFace Hub + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> + >>> # Load from local path + >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model") + >>> + >>> # Load with custom cache directory + >>> pipeline = QEffFluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... cache_dir="/custom/cache/dir" + ... ) + """ + # Load the base Flux model in float32 on CPU + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + device_map="cpu", + **kwargs, + ) + + return cls( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) + + def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + """ + Export all pipeline modules to ONNX format for deployment preparation. 
+ + This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder, + Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific + configuration including dynamic axes, input/output specifications, and optimization settings. + + The export process prepares the models for subsequent compilation to QPC format, enabling + efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules + to optimize memory usage and performance. + + Args: + export_dir (str, optional): Target directory for saving ONNX model files. If None, + uses the default export directory structure based on model name and configuration. + The directory will be created if it doesn't exist. + use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction + optimization for supported modules. This can optimize thegraph and + improve compilation efficiency for models like the transformer. + + Returns: + str: Absolute path to the export directory containing all ONNX model files. + Each module will have its own subdirectory with the exported ONNX file. + + Raises: + RuntimeError: If ONNX export fails for any module + OSError: If there are issues creating the export directory or writing files + ValueError: If module configurations are invalid + + Note: + - All models are exported in float32 precision for maximum compatibility + - Dynamic axes are configured to support variable batch sizes and sequence lengths + - The export process may take several minutes depending on model size + - Exported ONNX files can be large (several GB for complete pipeline) + + Example: + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> export_path = pipeline.export( + ... export_dir="/path/to/export", + ... use_onnx_subfunctions=True + ... ) + >>> print(f"Models exported to: {export_path}") + """ + for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"): + # Get ONNX export configuration for this module + example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params() + + export_params = { + "inputs": example_inputs, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "export_dir": export_dir, + } + + if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE: + export_params["use_onnx_subfunctions"] = True + + module_obj.export(**export_params) + + @staticmethod + def get_default_config_path() -> str: + """ + Get the absolute path to the default Flux pipeline configuration file. + + Returns: + str: Absolute path to the flux_config.json file containing default pipeline + configuration settings for compilation and device allocation. + """ + return "QEfficient/diffusers/pipelines/configs/flux_config.json" + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = 512, + width: int = 512, + use_onnx_subfunctions: bool = False, + ) -> None: + """ + Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware. + + Args: + compile_config (str, optional): Path to a JSON configuration file containing + compilation settings, device mappings, and optimization parameters. If None, + uses the default configuration from get_default_config_path(). 
+ parallel (bool, default=False): Compilation mode selection: + - True: Compile modules in parallel using ThreadPoolExecutor for faster processing + - False: Compile modules sequentially for lower resource usage + height (int, default=512): Target image height in pixels. + width (int, default=512): Target image width in pixels. + use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX + subfunctions before compilation. + + Raises: + RuntimeError: If compilation fails for any module or if QAIC compiler is not available + FileNotFoundError: If ONNX models haven't been exported or config file is missing + ValueError: If configuration parameters are invalid + OSError: If there are issues with file I/O during compilation + + Example: + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> # Sequential compilation with default config + >>> pipeline.compile(height=1024, width=1024) + >>> + >>> # Parallel compilation with custom config + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=512, + ... width=512 + ... ) + """ + # Ensure all modules are exported to ONNX before compilation + if any( + path is None + for path in [ + self.text_encoder.onnx_path, + self.text_encoder_2.onnx_path, + self.transformer.onnx_path, + self.vae_decode.onnx_path, + ] + ): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Load compilation configuration + config_manager(self, config_source=compile_config) + + # Calculate compressed latent dimension using utility function + cl, latent_height, latent_width = calculate_compressed_latent_dimension( + height, width, self.model.vae_scale_factor + ) + + # Prepare dynamic specialization updates based on image dimensions + specialization_updates = { + "transformer": {"cl": cl}, + "vae_decoder": { + "latent_height": latent_height, + "latent_width": latent_width, + }, + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device_ids: Optional[List[int]] = None, + ): + """ + Encode text prompts using the T5 text encoder for detailed semantic understanding. + + T5 provides rich sequence embeddings that capture fine-grained text details, + complementing CLIP's global representation in Flux's dual encoder setup. 
+ + Args: + prompt (str or List[str]): Input prompt(s) to encode + num_images_per_prompt (int): Number of images to generate per prompt + max_sequence_length (int): Maximum token sequence length (default: 512) + device_ids (List[int], optional): QAIC device IDs for inference + + Returns: + tuple: (prompt_embeds, inference_time) + - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096] + - inference_time (float): T5 encoder inference time in seconds + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + # Tokenize prompts with padding and truncation + text_inputs = self.text_encoder_2.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + # Check for truncation and warn user + untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.text_encoder_2.tokenizer.batch_decode( + untruncated_ids[:, self.text_encoder_2.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because `max_sequence_length` is set to " + f"{self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}" + ) + + # Initialize QAIC inference session if not already created + if self.text_encoder_2.qpc_session is None: + self.text_encoder_2.qpc_session = QAICInferenceSession( + str(self.text_encoder_2.qpc_path), device_ids=device_ids + ) + + # Allocate output buffers for QAIC inference + text_encoder_2_output = { + "last_hidden_state": np.random.rand( + batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model + ).astype(np.int32), + } + self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output) + + # Prepare input for QAIC inference + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + + # Run T5 encoder inference and measure time + start_t5_time = time.perf_counter() + prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"]) + end_t5_time = time.perf_counter() + text_encoder_2_perf = end_t5_time - start_t5_time + + # Duplicate embeddings for multiple images per prompt + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, text_encoder_2_perf + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device_ids: Optional[List[int]] = None, + ): + """ + Encode text prompts using the CLIP text encoder for global semantic representation. + + CLIP provides pooled embeddings that capture high-level semantic meaning, + working alongside T5's detailed sequence embeddings in Flux's dual encoder setup. 
+ + Args: + prompt (str or List[str]): Input prompt(s) to encode + num_images_per_prompt (int): Number of images to generate per prompt + device_ids (List[int], optional): QAIC device IDs for inference + + Returns: + tuple: (pooled_prompt_embeds, inference_time) + - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768] + - inference_time (float): CLIP encoder inference time in seconds + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + # Tokenize prompts + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + # Check for truncation and warn user + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + f"The following part of your input was truncated because CLIP can only handle sequences up to " + f"{self.tokenizer_max_length} tokens: {removed_text}" + ) + + # Initialize QAIC inference session if not already created + if self.text_encoder.qpc_session is None: + self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids) + + # Allocate output buffers for QAIC inference + text_encoder_output = { + "last_hidden_state": np.random.rand( + batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size + ).astype(np.float32), + "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.int32), + } + self.text_encoder.qpc_session.set_buffers(text_encoder_output) + + # Prepare input for QAIC inference + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + + # Run CLIP encoder inference and measure time + start_text_encoder_time = time.perf_counter() + aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input) + end_text_encoder_time = time.perf_counter() + text_encoder_perf = end_text_encoder_time - start_text_encoder_time + # Extract pooled output (used for conditioning in Flux) + prompt_embeds = torch.tensor(aic_embeddings["pooler_output"]) + + # Duplicate embeddings for multiple images per prompt + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, text_encoder_perf + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + """ + Encode text prompts using Flux's dual text encoder architecture. + + Flux employs both CLIP and T5 encoders for comprehensive text understanding: + - CLIP provides pooled embeddings for global semantic conditioning + - T5 provides detailed sequence embeddings for fine-grained text control + + Args: + prompt (str or List[str]): Primary prompt(s) for both encoders + prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. 
If None, uses primary prompt + num_images_per_prompt (int): Number of images to generate per prompt + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings + max_sequence_length (int): Maximum sequence length for T5 tokenization + + Returns: + tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times) + - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096] + - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768] + - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3] + - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time] + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + # Use primary prompt for both encoders if secondary not provided + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # Encode with CLIP (returns pooled embeddings) + pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds( + prompt=prompt, + device_ids=self.text_encoder.device_ids, + num_images_per_prompt=num_images_per_prompt, + ) + + # Encode with T5 (returns sequence embeddings) + prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids=self.text_encoder_2.device_ids, + ) + + # Create text position IDs (required by Flux transformer) + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + + return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf] + + def __call__( + self, + height: int = 512, + width: int = 512, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + use_onnx_subfunctions: bool = False, + ): + """ + Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware. + + This is the main entry point for text-to-image generation. It orchestrates the complete Flux + diffusion pipeline optimized for Qualcomm AI Cloud devices. + + Args: + height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512. + width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512. + prompt (str or List[str]): Primary text prompt(s) describing the desired image(s). + Required unless `prompt_embeds` is provided. + prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. 
If None, uses `prompt`. + negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid. + Only used when `true_cfg_scale > 1.0`. + negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`. + true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable + negative prompting. Default: 1.0 (disabled). + num_inference_steps (int, optional): Number of denoising steps. Default: 28. + timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`. + guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5. + num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1. + generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility. + latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated. + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096]. + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768]. + negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings. + negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings. + output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent". + callback_on_step_end (Callable, optional): Callback function executed after each denoising step. + callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"]. + max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512. + custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings. + parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False. + use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False. + + Returns: + QEffPipelineOutput: A dataclass containing: + - images: Generated image(s) in the format specified by `output_type` + - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE) + + Raises: + ValueError: If input validation fails or parameters are incompatible. + RuntimeError: If compilation fails or QAIC devices are unavailable. + FileNotFoundError: If custom config file is specified but not found. + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> result = pipeline( + ... prompt="A serene mountain landscape at sunset", + ... height=1024, + ... width=1024, + ... num_inference_steps=28, + ... guidance_scale=7.5 + ... ) + >>> result.images[0].save("mountain_sunset.png") + >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s") + """ + device = self.model._execution_device + + if height is None or width is None: + logger.warning("Height or width is None. 
Setting default values of 512 for both dimensions.") + + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # Set device IDs for all modules based on configuration + set_module_device_ids(self) + + # Validate all inputs + self.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # Step 2: Determine batch size from inputs + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 3: Encode prompts with both text encoders + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Encode negative prompts if using true classifier-free guidance + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents, latent_image_ids = self.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Calculate compressed latent dimension for transformer buffer allocation + cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor) + + # Initialize transformer inference session + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + self.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Prepare timestep embedding + 
timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds) + + # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.transformer_blocks)): + block = self.transformer.model.transformer_blocks[block_idx] + # Process through norm1 and norm1_context + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.single_transformer_blocks)): + block = self.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = self.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": prompt_embeds.detach().numpy(), + "pooled_projections": pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # Run transformer inference and measure time + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Update latents using scheduler (x_t -> x_t-1) + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch (workaround for MPS backend bug) + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Execute callback if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images (unless output_type is "latent") + if output_type == "latent": + image = latents + else: + # Unpack and denormalize latents + latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor) + latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor + + # Initialize VAE decoder inference session + if self.vae_decode.qpc_session is None: + self.vae_decode.qpc_session = QAICInferenceSession( + str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids + ) + + # Allocate output 
buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)} + self.vae_decode.qpc_session.set_buffers(output_buffer) + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.perf_counter() + image = self.vae_decode.qpc_session.run(inputs) + end_decode_time = time.perf_counter() + vae_decode_perf = end_decode_time - start_decode_time + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = self.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py new file mode 100644 index 000000000..41a3d29f7 --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -0,0 +1,481 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.diffusers.models.pytorch_transforms import ( + AttentionTransform, + CustomOpsTransform, + NormalizationTransform, +) +from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxSingleTransformerBlock, + QEffFluxTransformerBlock, +) +from QEfficient.transformers.models.pytorch_transforms import ( + T5ModelTransform, +) +from QEfficient.utils import constants + + +class QEffTextEncoder(QEFFBaseModel): + """ + Wrapper for text encoder models with ONNX export and QAIC compilation capabilities. + + This class handles text encoder models (CLIP, T5) with specific transformations and + optimizations for efficient inference on Qualcomm AI hardware. It applies custom + PyTorch and ONNX transformations to prepare models for deployment. + + Attributes: + model (nn.Module): The wrapped text encoder model (deep copy of original) + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform, T5ModelTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying text encoder model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the text encoder wrapper. + + Args: + model (nn.Module): The text encoder model to wrap (CLIP or T5) + """ + super().__init__(model) + self.model = model + + def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the text encoder. 
+ + Creates example inputs, dynamic axes specifications, and output names + tailored to the specific text encoder type (CLIP vs T5). + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # Create example input with max sequence length + example_inputs = { + "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64), + } + + # Define which dimensions can vary at runtime + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}} + + # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output + if self.model.__class__.__name__ == "T5EncoderModel": + output_names = ["last_hidden_state"] + else: + output_names = ["last_hidden_state", "pooler_output"] + example_inputs["output_hidden_states"] = False + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = None, + ) -> str: + """ + Export the text encoder model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + export_kwargs=export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffUNet(QEFFBaseModel): + """ + Wrapper for UNet models with ONNX export and QAIC compilation capabilities. + + This class handles UNet models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. UNet is commonly used in + diffusion models for image generation tasks. + + Attributes: + model (nn.Module): The wrapped UNet model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying UNet model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the UNet wrapper. + + Args: + model (nn.Module): The pipeline model containing the UNet + """ + super().__init__(model.unet) + self.model = model.unet + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = None, + ) -> str: + """ + Export the UNet model to ONNX format. 
+ + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + export_kwargs=export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffVAE(QEFFBaseModel): + """ + Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation. + + This class handles VAE models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion + pipelines for encoding images to latent space and decoding latents back to images. + + Attributes: + model (nn.Module): The wrapped VAE model (deep copy of original) + type (str): VAE operation type ("encoder" or "decoder") + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying VAE model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module, type: str) -> None: + """ + Initialize the VAE wrapper. + + Args: + model (nn.Module): The pipeline model containing the VAE + type (str): VAE operation type ("encoder" or "decoder") + """ + super().__init__(model) + self.model = model + + # To have different hashing for encoder/decoder + self.model.config["type"] = type + + def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the VAE decoder. + + Args: + latent_height (int): Height of latent representation (default: 32) + latent_width (int): Width of latent representation (default: 32) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # VAE decoder takes latent representation as input + example_inputs = { + "latent_sample": torch.randn(bs, 16, latent_height, latent_width), + "return_dict": False, + } + + output_names = ["sample"] + + # All dimensions except channels can be dynamic + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = None, + ) -> str: + """ + Export the VAE model to ONNX format. 
+ + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + export_kwargs=export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffFluxTransformerModel(QEFFBaseModel): + """ + Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities. + + This class handles Flux Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion + architecture instead of traditional UNet, with dual transformer blocks and adaptive layer + normalization (AdaLN) for conditioning. + + Attributes: + model (nn.Module): The wrapped Flux transformer model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying Flux transformer model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the Flux transformer wrapper. + + Args: + model (nn.Module): The Flux transformer model to wrap + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization + """ + super().__init__(model) + + def get_onnx_params( + self, + batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH, + cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM, + ) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the Flux transformer. + + Creates example inputs for all Flux-specific inputs including hidden states, + text embeddings, timestep conditioning, and AdaLN embeddings. 
+ + Args: + batch_size (int): Batch size for example inputs (default: FLUX_ONNX_EXPORT_BATCH_SIZE) + seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH) + cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + example_inputs = { + # Latent representation of the image + "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32), + "encoder_hidden_states": torch.randn( + batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32 + ), + "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32), + "timestep": torch.tensor([1.0], dtype=torch.float32), + "img_ids": torch.randn(cl, 3, dtype=torch.float32), + "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32), + # AdaLN embeddings for dual transformer blocks + # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_emb": torch.randn( + self.model.config["num_layers"], + constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # AdaLN embeddings for single transformer blocks + # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_single_emb": torch.randn( + self.model.config["num_single_layers"], + constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # Output AdaLN embedding + # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection + "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32), + } + + output_names = ["output"] + + # Define dynamic dimensions for runtime flexibility + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "cl"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_len"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "steps"}, + "img_ids": {0: "cl"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = None, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export the Flux transformer model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + + Returns: + str: Path to the exported ONNX model + """ + + if use_onnx_subfunctions: + export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}} + + # Sort _use_default_values in config to ensure consistent hash generation during export + self.model.config["_use_default_values"].sort() + + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + export_kwargs=export_kwargs, + offload_pt_weights=False, # As weights are needed with AdaLN changes + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. 
+ + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 000000000..24eb36f53 --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,218 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +from tqdm import tqdm + +from QEfficient.utils._utils import load_json +from QEfficient.utils.logging_utils import logger + + +def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> int: + """ + Calculate the compressed latent dimension. + Args: + height (int): Target image height in pixels + width (int): Target image width in pixels + vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux) + + Returns: + int: Compressed latent dimension (cl) for transformer input buffer allocation + """ + latent_height = height // vae_scale_factor + latent_width = width // vae_scale_factor + # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing) + cl = (latent_height * latent_width) // 4 + return cl, latent_height, latent_width + + +def config_manager(cls, config_source: Optional[str] = None): + """ + JSON-based compilation configuration manager for diffusion pipelines. + + Supports loading configuration from JSON files only. Automatically detects + model type and handles model-specific requirements. + Initialize the configuration manager. + + Args: + config_source: Path to JSON configuration file. If None, uses default config. + """ + if config_source is None: + config_source = cls.get_default_config_path() + + if not isinstance(config_source, str): + raise ValueError("config_source must be a path to JSON configuration file") + + # Direct use of load_json utility - no wrapper needed + if not os.path.exists(config_source): + raise FileNotFoundError(f"Configuration file not found: {config_source}") + + cls.custom_config = load_json(config_source) + + +def set_module_device_ids(cls): + """ + Set device IDs for each module based on the custom configuration. + + Iterates through all modules in the pipeline and assigns device IDs + from the configuration file to each module's device_ids attribute. + """ + config_modules = cls.custom_config["modules"] + for module_name, module_obj in cls.modules.items(): + module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"] + + +def compile_modules_parallel( + modules: Dict[str, Any], + config: Dict[str, Any], + specialization_updates: Dict[str, Dict[str, Any]] = None, +) -> None: + """ + Compile multiple pipeline modules in parallel using ThreadPoolExecutor. 
+ + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + """ + + def _prepare_and_compile(module_name: str, module_obj: Any) -> None: + """Prepare specializations and compile a single module.""" + specializations = config["modules"][module_name]["specializations"].copy() + compile_kwargs = config["modules"][module_name]["compilation"] + + if specialization_updates and module_name in specialization_updates: + specializations.update(specialization_updates[module_name]) + + module_obj.compile(specializations=[specializations], **compile_kwargs) + + # Execute compilations in parallel + with ThreadPoolExecutor(max_workers=len(modules)) as executor: + futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()} + + with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar: + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Compilation failed for {futures[future]}: {e}") + raise + pbar.update(1) + + +def compile_modules_sequential( + modules: Dict[str, Any], + config: Dict[str, Any], + specialization_updates: Dict[str, Dict[str, Any]] = None, +) -> None: + """ + Compile multiple pipeline modules sequentially. + + This function provides a generic way to compile diffusion pipeline modules + sequentially, which is the default behavior for backward compatibility. + + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + + """ + for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"): + module_config = config["modules"] + specializations = module_config[module_name]["specializations"].copy() + compile_kwargs = module_config[module_name]["compilation"] + + # Apply dynamic specialization updates if provided + if specialization_updates and module_name in specialization_updates: + specializations.update(specialization_updates[module_name]) + + # Compile the module to QPC format + module_obj.compile(specializations=[specializations], **compile_kwargs) + + +@dataclass(frozen=True) +class ModulePerf: + """ + Data class to store performance metrics for a pipeline module. + + Attributes: + module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder') + perf: Performance metric in seconds. Can be a single float for modules that run once, + or a list of floats for modules that run multiple times (e.g., transformer steps) + """ + + module_name: str + perf: int + + +@dataclass(frozen=True) +class QEffPipelineOutput: + """ + Data class to store the output of a QEfficient diffusion pipeline. 
+ + Attributes: + pipeline_module: List of ModulePerf objects containing performance metrics for each module + images: Generated images as either a list of PIL Images or numpy array + """ + + pipeline_module: list[ModulePerf] + images: Union[List[PIL.Image.Image], np.ndarray] + + def __repr__(self): + output_str = "=" * 60 + "\n" + output_str += "QEfficient Diffusers Pipeline Inference Report\n" + output_str += "=" * 60 + "\n\n" + + # Module-wise inference times + output_str += "Module-wise Inference Times:\n" + output_str += "-" * 60 + "\n" + + # Calculate E2E time while iterating + e2e_time = 0 + for module_perf in self.pipeline_module: + module_name = module_perf.module_name + inference_time = module_perf.perf + + # Add to E2E time + e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time + + # Format module name for display + display_name = module_name.replace("_", " ").title() + + # Handle transformer specially as it has a list of times + if isinstance(inference_time, list) and len(inference_time) > 0: + total_time = sum(inference_time) + avg_time = total_time / len(inference_time) + output_str += f" {display_name:25s} {total_time:.4f} s\n" + output_str += f" - Total steps: {len(inference_time)}\n" + output_str += f" - Average per step: {avg_time:.4f} s\n" + output_str += f" - Min step time: {min(inference_time):.4f} s\n" + output_str += f" - Max step time: {max(inference_time):.4f} s\n" + else: + # Single inference time value + output_str += f" {display_name:25s} {inference_time:.4f} s\n" + + output_str += "-" * 60 + "\n\n" + + # Print E2E time after all modules + output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n" + output_str += "=" * 60 + "\n" + + return output_str + + +# List of module name that require special handling during export +# when use_onnx_subfunctions is enabled +ONNX_SUBFUNCTION_MODULE = ["transformer"] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8edc1f3f0..16a809c96 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -124,21 +124,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path) - @property - def model_name(self) -> str: - """ - Get the name of the underlying HuggingFace model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - class MultimodalUtilityMixin: """ @@ -701,21 +686,6 @@ def compile( **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying vision encoder model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -869,21 +839,6 @@ def compile( **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying language decoder model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. 
- """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -946,21 +901,6 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ @@ -2131,21 +2071,6 @@ def cloud_ai_100_generate( ), ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -2437,21 +2362,6 @@ def __init__( if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - @property - def model_name(self) -> str: - """ - Get the name of the underlying Causal Language Model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 21a867eb5..07b9fe7e1 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -197,6 +197,10 @@ Starcoder2ForCausalLM, Starcoder2Model, ) +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) from transformers.models.whisper.modeling_whisper import ( WhisperAttention, WhisperDecoder, @@ -417,6 +421,10 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.t5.modeling_t5 import ( + QEffT5Attention, + QEffT5LayerNorm, +) from QEfficient.transformers.models.whisper.modeling_whisper import ( QEffWhisperAttention, QEffWhisperDecoder, @@ -808,6 +816,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} +class T5ModelTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + T5Attention: QEffT5Attention, + T5LayerNorm: QEffT5LayerNorm, + } + + class PoolingTransform: """ Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. 
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/t5/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py new file mode 100644 index 000000000..f54201465 --- /dev/null +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -0,0 +1,145 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +from transformers import EncoderDecoderCache +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) + + +class QEffT5LayerNorm(T5LayerNorm): + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class QEffT5Attention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. 
Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + if past_key_value is not None: # This block is where the patch applies + position_bias = position_bias[:, :, -1:, :] # Added by patch + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = 
outputs + (attn_weights,) + return outputs diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index 49f0ad30b..3d6583f85 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -16,7 +16,6 @@ create_model_params, custom_format_warning, dump_qconfig, - export_wrapper, generate_mdp_partition_config, get_num_layers_from_config, get_num_layers_vlm, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 131a7fc26..26bae7a34 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -12,7 +12,6 @@ import subprocess import xml.etree.ElementTree as ET from dataclasses import dataclass -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import requests @@ -27,9 +26,8 @@ PreTrainedTokenizerFast, ) -from QEfficient.utils.cache import QEFF_HOME from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants -from QEfficient.utils.hash_utils import create_export_hash, json_serializable +from QEfficient.utils.hash_utils import json_serializable from QEfficient.utils.logging_utils import logger @@ -532,61 +530,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict: """ model_params = copy.deepcopy(kwargs) model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST} - model_params["config"] = qeff_model.model.config.to_diff_dict() model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None) model_params["applied_transform_names"] = qeff_model._transform_names() return model_params -def export_wrapper(func): - def wrapper(self, *args, **kwargs): - export_dir = kwargs.get("export_dir", None) - parent_dir = self.model_architecture or self.model_name - export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name)) - - # PREPROCESSING OF PARAMETERS - - # Get the original signature - original_sig = inspect.signature(func) - - # Remove 'self' from parameters - params = list(original_sig.parameters.values())[1:] # skip 'self' - new_sig = inspect.Signature(params) - - # Bind args and kwargs to the new signature - bound_args = new_sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - # Get arguments as a dictionary - all_args = bound_args.arguments - - export_hash, filtered_hash_params = create_export_hash( - model_params=self.hash_params, - output_names=all_args.get("output_names"), - dynamic_axes=all_args.get("dynamic_axes"), - export_kwargs=all_args.get("export_kwargs", None), - onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), - use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False), - ) - - export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) - kwargs["export_dir"] = export_dir - self.export_hash = export_hash - - # _EXPORT CALL - onnx_path = func(self, *args, **kwargs) - - # POST-PROCESSING - # Dump JSON file with hashed parameters - hashed_params_export_path = export_dir / "hashed_export_params.json" - create_json(hashed_params_export_path, filtered_hash_params) - logger.info("Hashed parameters exported successfully.") - - return onnx_path - - return wrapper - - def execute_command(process: str, command: str, output_file_path: Optional[str] = None): """ Executes the give command using subprocess. 
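As a quick sanity check on the reordered RMS computation in `QEffT5LayerNorm` above (dividing by sqrt(hidden_dim) before squaring, and accumulating a sum rather than a mean), one might compare it against the stock transformers implementation; the hidden size and tolerances below are illustrative:

```python
import torch
from transformers.models.t5.modeling_t5 import T5LayerNorm

from QEfficient.transformers.models.t5.modeling_t5 import QEffT5LayerNorm

hidden = torch.randn(2, 8, 512)          # (batch, seq_len, hidden) in fp32
ref = T5LayerNorm(512)
qeff = QEffT5LayerNorm(512)
qeff.load_state_dict(ref.state_dict())   # share the learned scale

# sum((x / sqrt(d)) ** 2) == mean(x ** 2), so both norms agree up to rounding.
torch.testing.assert_close(qeff(hidden), ref(hidden), rtol=1e-4, atol=1e-5)
```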
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index e0b003422..613d7049a 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -144,6 +144,13 @@ def get_models_dir(): # Molmo Constants MOLMO_IMAGE_HEIGHT = 536 MOLMO_IMAGE_WIDTH = 354 +# Flux Transformer Constants +FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 +FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 +FLUX_ADALN_HIDDEN_DIM = 3072 +FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context +FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 +FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM class Constants: diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py new file mode 100644 index 000000000..eea92a490 --- /dev/null +++ b/QEfficient/utils/export_utils.py @@ -0,0 +1,235 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import inspect +import re +import warnings +from pathlib import Path +from typing import Dict + +from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform +from QEfficient.transformers.cache_utils import InvalidIndexProvider +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export +from QEfficient.utils.cache import QEFF_HOME +from QEfficient.utils.hash_utils import create_export_hash +from QEfficient.utils.logging_utils import logger +from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches + + +def export_wrapper(func): + """ + Decorator for export methods that orchestrates the complete export lifecycle. + + Responsibilities: + 1. Prepare export directory structure + 2. Generate reproducible hash for export configuration + 3. Setup ONNX subfunction environment (if enabled) + 4. Execute the wrapped export function + 5. Cleanup subfunction environment (if enabled) + 6. Save export metadata + + Args: + func: The export method to wrap (typically _export) + + Returns: + Wrapped function with complete export lifecycle management + """ + + def wrapper(self, *args, **kwargs): + # 1. Prepare export directory + export_dir = _prepare_export_directory(self, kwargs) + + # 2. Generate hash and finalize export directory path + export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func) + export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) + kwargs["export_dir"] = export_dir + self.export_hash = export_hash + + # 3. Setup ONNX subfunctions if requested + # TODO: No need of this variable, if export_kwargs contains classes (refer diffusers) + if use_onnx_subfunctions := kwargs.get("use_onnx_subfunctions", False): + _setup_onnx_subfunctions(self, kwargs) + + # 4. Execute the actual export + onnx_path = func(self, *args, **kwargs) + + # 5. Save export metadata + _save_export_metadata(export_dir, filtered_hash_params) + + # 6. Always cleanup subfunctions if they were setup + if use_onnx_subfunctions: + _cleanup_onnx_subfunctions(self) + + return onnx_path + + return wrapper + + +def _prepare_export_directory(qeff_model, kwargs) -> Path: + """ + Prepare and return the base export directory path. 
+ + Args: + qeff_model: The QEff model instance + kwargs: Keyword arguments containing optional export_dir + + Returns: + Path object for the base export directory + """ + export_dir = kwargs.get("export_dir", None) + parent_dir = qeff_model.model_architecture or qeff_model.model_name + return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name)) + + +def _generate_export_hash(qeff_model, args, kwargs, func): + """ + Generate export hash from model parameters and export arguments. + + The hash ensures reproducibility and prevents conflicts between + different export configurations. + + Args: + qeff_model: The QEff model instance + args: Positional arguments to the export function + kwargs: Keyword arguments to the export function + func: The export function being wrapped + + Returns: + Tuple of (export_hash: str, filtered_hash_params: dict) + """ + # Extract use_onnx_subfunctions before binding (it's used by wrapper, not _export) + use_onnx_subfunctions = kwargs.pop("use_onnx_subfunctions", False) + + # Extract function signature + original_sig = inspect.signature(func) + params = list(original_sig.parameters.values())[1:] # Skip 'self' + new_sig = inspect.Signature(params) + # Bind all arguments + bound_args = new_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + all_args = bound_args.arguments + + # Use the model's current configuration for hashing to ensure any post-load modifications are captured + # TODO: Replace with get_model_config property of modeling classes and remove the if-else + # Determine the config dict to use, preferring .to_diff_dict() if available + + if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"): + config_val = qeff_model.model.config.to_diff_dict() + elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"): + config_val = qeff_model.model.model.config.to_diff_dict() + else: + config_val = qeff_model.model.config + + qeff_model.hash_params.update( + { + "config": config_val, + } + ) + + # Generate hash from relevant parameters + export_hash, filtered_hash_params = create_export_hash( + model_params=qeff_model.hash_params, + output_names=all_args.get("output_names"), + dynamic_axes=all_args.get("dynamic_axes"), + export_kwargs=all_args.get("export_kwargs", None), + onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + return export_hash, filtered_hash_params + + +def _setup_onnx_subfunctions(qeff_model, kwargs): + """ + Setup ONNX subfunction export environment. + + This function prepares the model and environment for exporting with + ONNX subfunctions enabled. It: + - Applies necessary torch patches + - Modifies output names for subfunction compatibility + - Adds subfunction-specific ONNX transforms + - Updates export kwargs with module classes + + Args: + qeff_model: The QEff model instance + kwargs: Export keyword arguments (modified in-place). + """ + warnings.warn( + "The subfunction feature is experimental. Please note that using compile " + "consecutively with and without subfunction may produce inconsistent results." 
+ ) + + # Apply torch patches for subfunction support + apply_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = True + + # Store original state for restoration during cleanup + qeff_model._original_onnx_transforms = qeff_model._onnx_transforms.copy() + + # Transform output names for subfunction compatibility + if "output_names" in kwargs: + kwargs["output_names"] = [ + re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"] + ] + + # Add subfunction-specific ONNX transforms + qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.append(CustomOpTransform) + + # Configure export to use modules as functions + export_kwargs = kwargs.get("export_kwargs", {}) + + # TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation + export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + kwargs["export_kwargs"] = export_kwargs + + +def _cleanup_onnx_subfunctions(qeff_model): + """ + Cleanup ONNX subfunction export environment. + + Restores the model and environment to pre-subfunction state by: + - Undoing torch patches + - Resetting InvalidIndexProvider flag + - Restoring original ONNX transforms list + + Args: + qeff_model: The QEff model instance + + Note: + This function is called in a finally block to ensure cleanup + even if export fails. Errors during cleanup are logged but + not re-raised to avoid masking the original exception. + """ + try: + # Undo torch patches + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + + # Restore original ONNX transforms + if hasattr(qeff_model, "_original_onnx_transforms"): + qeff_model._onnx_transforms = qeff_model._original_onnx_transforms + delattr(qeff_model, "_original_onnx_transforms") + + except Exception as e: + logger.error(f"Error during subfunction cleanup: {e}") + + +def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict): + """ + Save export metadata to JSON file for reproducibility. 
+ + Args: + export_dir: Directory where the export was saved + filtered_hash_params: Dictionary of parameters used for hashing + """ + # Import here to avoid circular dependency + from QEfficient.utils._utils import create_json + + hashed_params_path = export_dir / "hashed_export_params.json" + create_json(hashed_params_path, filtered_hash_params) + logger.info("Hashed parameters exported successfully.") diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 948b72e6a..68ccab0d4 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -14,7 +14,8 @@ def json_serializable(obj): if isinstance(obj, set): - return sorted(obj) + # Convert set to a sorted list of strings for consistent hashing + return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png new file mode 100644 index 000000000..9e58da61d Binary files /dev/null and b/docs/image/girl_laughing.png differ diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md new file mode 100644 index 000000000..2a3c1605f --- /dev/null +++ b/examples/diffusers/flux/README.md @@ -0,0 +1,243 @@ +# FLUX.1-schnell Image Generation Examples + +This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs. + +## Overview + +FLUX.1-schnell is a fast, distilled version of the FLUX.1 text-to-image model optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation. + +## Files + +- **`flux_1_schnell.py`** - Basic example showing simple image generation +- **`flux_1_shnell_custom.py`** - Advanced example with customization options +- **`flux_config.json`** - Configuration file for pipeline modules + +## Quick Start + +### Basic Usage + +The simplest way to generate images with FLUX.1-schnell: + +```python +from QEfficient import QEffFluxPipeline +import torch + +# Initialize pipeline +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + +# Generate image +output = pipeline( + prompt="A laughing girl", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Save image +output.images[0].save("girl_laughing.png") +``` + +Run the basic example: +```bash +python flux_1_schnell.py +``` + +## Advanced Customization + +The `flux_1_shnell_custom.py` example demonstrates several advanced features: + +### 1. Custom Model Components + +You can provide custom text encoders, transformers, and tokenizers: + +```python +pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + text_encoder=custom_text_encoder, + transformer=custom_transformer, + tokenizer=custom_tokenizer, +) +``` + +### 2. Custom Scheduler + +Replace the default scheduler with your own: + +```python +pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) +``` + +### 3. 
Reduce Model Layers for Faster Inference + +Trade quality for speed by reducing transformer blocks: + +```python +original_blocks = pipeline.transformer.model.transformer_blocks +org_single_blocks = pipeline.transformer.model.single_transformer_blocks +pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +pipeline.transformer.model.config['num_layers'] = 1 +pipeline.transformer.model.config['num_single_layers'] = 1 +``` + +### 4. Pre-compile with Custom Configuration + +Compile the model separately before generation: + +```python +pipeline.compile( + compile_config="examples/diffusers/flux/flux_config.json", + height=512, + width=512, + use_onnx_subfunctions=False +) +``` + +### 5. Runtime Configuration + +Use custom configuration during generation: + +```python +output = pipeline( + prompt="A girl laughing", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) +``` + +Run the advanced example: +```bash +python flux_1_shnell_custom.py +``` + +## Configuration File + +The `flux_config.json` file controls compilation and execution settings for each pipeline module: + +### Module Structure + +The configuration includes four main modules: + +1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence) +2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence) +3. **transformer** - Core diffusion transformer model +4. **vae_decoder** - Decodes latents to images + +### Configuration Parameters + +Each module has three sections: + +#### Specializations +- `batch_size`: Batch size for inference +- `seq_len`: Sequence length for text encoders +- `steps`: Number of inference steps (transformer only) +- `channels`: Number of channels (VAE decoder only) + +#### Compilation +- `onnx_path`: Path to pre-exported ONNX model (null for auto-export) +- `compile_dir`: Directory for compiled artifacts (null for auto-generation) +- `mdp_ts_num_devices`: Number of devices for model data parallelism +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use +- `mos`: Multi-output streaming (transformer only) +- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only) +- `aic-enable-depth-first`: Enable depth-first compilation (VAE only) + +#### Execute +- `device_ids`: List of device IDs to use (null for auto-selection) + +### Example Configuration Snippet + +```json +{ + "transformer": { + "specializations": { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": { + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": { + "device_ids": null + } + } +} +``` + +## Key Parameters + +### Generation Parameters + +- **`prompt`** (str): Text description of the image to generate +- **`height`** (int): Output image height in pixels (default: 1024) +- **`width`** (int): Output image width in pixels (default: 1024) +- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell) +- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell) +- **`max_sequence_length`** (int): Maximum text sequence length (256 
recommended)
+- **`generator`** (torch.Generator): Random seed for reproducibility
+- **`parallel_compile`** (bool): Enable parallel compilation of modules
+- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental)
+
+### Performance Tuning
+
+- **Faster inference**: Reduce `num_inference_steps` or model layers
+- **Better quality**: Increase `num_inference_steps` or use the full model
+- **Memory optimization**: Adjust `mdp_ts_num_devices` in the config
+- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16`
+
+## Output
+
+The pipeline returns an output object containing:
+- `images`: List of generated PIL Image objects
+- Performance metrics (timing information)
+
+Example output:
+```python
+print(output)  # Displays performance information
+image = output.images[0]  # Access the generated image
+image.save("output.png")  # Save to disk
+```
+
+## Hardware Requirements
+
+- Qualcomm Cloud AI 100 accelerator
+- Sufficient memory for model compilation and execution
+- Multiple devices are recommended for optimal transformer performance (see `mdp_ts_num_devices`)
+
+## Notes
+
+- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0`
+- The transformer module benefits most from multi-device parallelism
+- ONNX subfunctions (`use_onnx_subfunctions=True`) are experimental; they may improve compile time but are not recommended for production use
+- Custom configurations allow fine-tuning for specific hardware setups
+
+## Troubleshooting
+
+- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices`
+- **Slow compilation**: Enable `parallel_compile=True`
+- **Quality issues**: Use the recommended parameters (4 steps, `guidance_scale=0.0`)
+- **Device errors**: Check `device_ids` in the config or set it to `null` for auto-selection
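+
+## Inspecting Per-Module Performance
+
+Beyond `print(output)`, the timing data can be read programmatically. The sketch below assumes the `QEffPipelineOutput` layout used by this repository's diffusers tests, where `output.pipeline_module` is a list of entries exposing `module_name` and `perf` (the transformer entry holds one latency per denoising step); treat these attribute names as assumptions and verify them against your installed version.
+
+```python
+# Minimal sketch: summarize per-module timings from a QEffFluxPipeline output.
+# Assumes each entry in output.pipeline_module exposes .module_name and .perf,
+# where .perf is either a float (seconds) or a list of per-step floats.
+for entry in output.pipeline_module:
+    perf = entry.perf
+    if isinstance(perf, list):  # e.g., transformer: one latency per denoising step
+        print(f"{entry.module_name}: {sum(perf):.3f}s over {len(perf)} steps")
+    else:
+        print(f"{entry.module_name}: {perf:.3f}s")
+```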
+
+## References
+
+- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+- [QEfficient Documentation](../../../README.md)
+- [Diffusers Pipeline Guide](../../README.md)
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..46f26bb6b
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,45 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1-schnell Image Generation Example
+
+This example demonstrates how to use the QEffFluxPipeline to generate images
+using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a
+fast, distilled version of the FLUX.1 text-to-image model optimized for
+speed with minimal quality loss.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# Initialize the FLUX.1-schnell pipeline from pretrained weights
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# Generate an image from a text prompt
+# use_onnx_subfunctions is kept False here; set it to True to try the experimental modular ONNX export
+output = pipeline(
+    prompt="A laughing girl",
+    height=1024,
+    width=1024,
+    guidance_scale=0.0,
+    num_inference_steps=4,
+    max_sequence_length=256,
+    generator=torch.manual_seed(42),
+    parallel_compile=True,
+    use_onnx_subfunctions=False,
+)
+
+# Extract the generated image from the output
+image = output.images[0]
+
+# Save the generated image to disk
+image.save("girl_laughing.png")
+
+# Print the output object (contains performance metrics)
+print(output)
diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py
new file mode 100644
index 000000000..201ebe659
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_shnell_custom.py
@@ -0,0 +1,113 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1 Schnell Custom Configuration Example
+
+This example demonstrates how to customize the FLUX.1 model with various options:
+1. Custom image dimensions (height/width)
+2. Custom transformer model and text encoder
+3. Custom scheduler configuration
+4. Reduced model layers for faster inference
+5. Custom compilation settings
+6. Custom runtime configuration via JSON config file
+
+Use this example to learn how to tailor FLUX.1 to your specific needs.
+"""
+
+import torch
+
+from QEfficient import QEffFluxPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+
+# Option 1: Basic initialization with default parameters
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+# Option 2: Advanced initialization with custom modules
+# Uncomment and modify to use your own custom components:
+#
+# pipeline = QEffFluxPipeline.from_pretrained(
+#     "black-forest-labs/FLUX.1-schnell",
+#     text_encoder=custom_text_encoder,  # Your custom CLIP text encoder
+#     transformer=custom_transformer,    # Your custom transformer model
+#     tokenizer=custom_tokenizer,        # Your custom tokenizer
+# )
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up image generation.
+#
+# Trade-off: Faster inference but potentially lower image quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only the first transformer block:
+#
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# pipeline.transformer.model.config['num_layers'] = 1
+# pipeline.transformer.model.config['num_single_layers'] = 1
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip image generation and only prepare the model
+#
+# NOTE-1: If compile_config is not specified, the default configuration from
+#         QEfficient/diffusers/pipelines/flux/flux_config.json will be used
+#
+# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export (experimental; not recommended for production).
+# This feature breaks the model down into smaller, more manageable ONNX functions,
+# which can lead to improved compile times.
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(
+#     compile_config="examples/diffusers/flux/flux_config.json",
+#     height=512,
+#     width=512,
+#     use_onnx_subfunctions=False
+# )
+
+# ============================================================================
+# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate an image using the configured pipeline.
+#
+# Note: Passing custom_config_path lets you set device_ids for each module in the
+# config JSON, so you can skip the separate pipeline.compile() step (see the
+# illustrative snippet below).
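+#
+# Illustrative only: the module names below mirror flux_config.json in this directory, but the
+# device_ids values are hypothetical and depend on which Cloud AI 100 devices are free on your
+# host (the transformer typically needs one ID per mdp_ts_num_devices). For example, a copy of
+# the config could pin the modules like this:
+#
+#   "transformer":  { ..., "execute": { "device_ids": [0, 1, 2, 3] } },
+#   "vae_decoder":  { ..., "execute": { "device_ids": [4] } }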
+ +output = pipeline( + prompt="A laughing girl", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +image = output.images[0] +# Save the generated image to disk +image.save("laughing_girl.png") +print(output) diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/examples/diffusers/flux/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/pyproject.toml b/pyproject.toml index 8e179ab4a..fe0c42ec2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,9 +20,10 @@ classifiers = [ requires-python = ">=3.8,<3.11" dependencies = [ "transformers==4.55.0", + "diffusers== 0.35.1", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft==0.13.2", + "peft==0.17.0", "datasets==2.20.0", "fsspec==2023.6.0", "multidict==6.0.4", diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 134770638..d878076fa 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -22,6 +22,7 @@ pipeline { . 
preflight_qeff/bin/activate && pip install --upgrade pip setuptools && pip install .[test] && + pip install .[diffusers] && pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs @@ -69,7 +70,7 @@ pipeline { } stage('QAIC MultiModal Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -86,7 +87,7 @@ pipeline { } stage('Inference Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " #source /qnn_sdk/bin/envsetup.sh && @@ -162,7 +163,7 @@ pipeline { // } stage('Finetune CLI Tests') { steps { - timeout(time: 5, unit: 'MINUTES') { + timeout(time: 20, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py index d1b7a4653..f63b18f1a 100644 --- a/tests/base/test_export_memory_offload.py +++ b/tests/base/test_export_memory_offload.py @@ -27,7 +27,7 @@ @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py new file mode 100644 index 000000000..305116c03 --- /dev/null +++ b/tests/diffusers/diffusers_utils.py @@ -0,0 +1,175 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Common utilities for diffusion pipeline testing. +Provides essential functions for MAD validation, image validation +hash verification, and other testing utilities. +""" + +import os +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +from PIL import Image + + +class DiffusersTestUtils: + """Essential utilities for diffusion pipeline testing""" + + @staticmethod + def validate_image_generation( + image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0 + ) -> Dict[str, Any]: + """ + Validate generated image properties. 
+ Args: + image: Generated PIL Image + expected_size: Expected (width, height) tuple + min_variance: Minimum pixel variance to ensure image is not blank + + Returns: + Dict containing validation results + Raises: + AssertionError: If image validation fails + """ + # Basic image validation + assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}" + assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}" + assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}" + + # Variance check (ensure image is not blank) + img_array = np.array(image) + image_variance = float(img_array.std()) + assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})" + + return { + "size": image.size, + "mode": image.mode, + "variance": image_variance, + "mean_pixel_value": float(img_array.mean()), + "min_pixel": int(img_array.min()), + "max_pixel": int(img_array.max()), + "valid": True, + } + + @staticmethod + def check_file_exists(file_path: str, file_type: str = "file") -> bool: + """ + Check if file exists and log result. + Args: + file_path: Path to check + file_type: Description of file type for logging + Returns: + bool: True if file exists + """ + exists = os.path.exists(file_path) + status = "āœ…" if exists else "āŒ" + print(f"{status} {file_type}: {file_path}") + return exists + + @staticmethod + def print_test_header(title: str, config: Dict[str, Any]) -> None: + """ + Print formatted test header with configuration details. + + Args: + title: Test title + config: Test configuration dictionary + """ + print(f"\n{'=' * 80}") + print(f"{title}") + print(f"{'=' * 80}") + + if "model_setup" in config: + setup = config["model_setup"] + for k, v in setup.items(): + print(f"{k} : {v}") + + if "functional_testing" in config: + func = config["functional_testing"] + print(f"Test Prompt: {func.get('test_prompt', 'N/A')}") + print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}") + print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}") + + print(f"{'=' * 80}") + + +class MADValidator: + """Specialized class for MAD validation - always enabled, always reports, always fails on exceed""" + + def __init__(self, tolerances: Dict[str, float] = None): + """ + Initialize MAD validator. + MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded. + + Args: + tolerances: Dictionary of module_name -> tolerance mappings + """ + self.tolerances = tolerances + self.results = {} + + def calculate_mad( + self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray] + ) -> float: + """ + Calculate Max Absolute Deviation between two tensors. + + Args: + tensor1: First tensor (PyTorch or NumPy) + tensor2: Second tensor (PyTorch or NumPy) + + Returns: + float: Maximum absolute difference between tensors + """ + if isinstance(tensor1, torch.Tensor): + tensor1 = tensor1.detach().numpy() + if isinstance(tensor2, torch.Tensor): + tensor2 = tensor2.detach().numpy() + + return float(np.max(np.abs(tensor1 - tensor2))) + + def validate_module_mad( + self, + pytorch_output: Union[torch.Tensor, np.ndarray], + qaic_output: Union[torch.Tensor, np.ndarray], + module_name: str, + step_info: str = "", + ) -> bool: + """ + Validate MAD for a specific module. + Always validates, always reports, always fails if tolerance exceeded. 
+ + Args: + pytorch_output: PyTorch reference output + qaic_output: QAIC inference output + module_name: Name of the module + step_info: Additional step information for logging + + Returns: + bool: True if validation passed + + Raises: + AssertionError: If MAD exceeds tolerance + """ + mad_value = self.calculate_mad(pytorch_output, qaic_output) + + # Always report MAD value + step_str = f" {step_info}" if step_info else "" + print(f"šŸ” {module_name.upper()} MAD{step_str}: {mad_value:.8f}") + + # Always validate - fail if tolerance exceeded + tolerance = self.tolerances.get(module_name, 1e-2) + if mad_value > tolerance: + raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}") + + # Store result + if module_name not in self.results: + self.results[module_name] = [] + self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance}) + return True diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json new file mode 100644 index 000000000..7d0c17d55 --- /dev/null +++ b/tests/diffusers/flux_test_config.json @@ -0,0 +1,123 @@ +{ + "model_setup": { + "height": 256, + "width": 256, + "num_transformer_layers": 2, + "num_single_layers": 2, + "use_onnx_subfunctions": false + }, + "mad_validation": { + "tolerances": { + "clip_text_encoder": 0.1, + "t5_text_encoder": 5.5, + "transformer": 2.0, + "vae_decoder": 1.0 + } + }, + "pipeline_params": { + "test_prompt": "A cat holding a sign that says hello world", + "num_inference_steps": 2, + "guidance_scale": 0.0, + "max_sequence_length": 256, + "validate_gen_img": true, + "min_image_variance": 1.0, + "custom_config_path": null + }, + "validation_checks": { + "image_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "aic-enable-depth-first": true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + } + } + +} diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py new file mode 100644 index 000000000..6f4396a20 --- /dev/null +++ b/tests/diffusers/test_flux.py @@ -0,0 +1,448 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import pytest +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from QEfficient import QEffFluxPipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ModulePerf, + QEffPipelineOutput, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +# Test Configuration for 256x256 resolution with 2 layers # update mad tolerance +CONFIG_PATH = "tests/diffusers/flux_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +def flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 256, + width: int = 256, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + mad_tolerances: Dict[str, float] = None, +): + """ + Pipeline call function that replicates the exact flow of pipeline_flux.py.__call__() + while adding comprehensive MAD validation at each step. + + This function follows the EXACT same structure as QEffFluxPipeline.__call__() + but adds MAD validation hooks throughout the process. 
+ """ + # Initialize MAD validator + mad_validator = MADValidator(tolerances=mad_tolerances) + + device = "cpu" + + # Step 1: Load configuration, compile models + pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width) + + # Set device IDs for all modules based on configuration + set_module_device_ids(pipeline) + + # Validate all inputs + pipeline.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + # Set pipeline attributes + pipeline._guidance_scale = guidance_scale + pipeline._interrupt = False + batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"] + + # Step 3: Encode prompts with both text encoders + # Use pipeline's encode_prompt method + (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + # Deactivate text encoder qpc sessions + pipeline.text_encoder.qpc_session.deactivate() + pipeline.text_encoder_2.qpc_session.deactivate() + + # MAD Validation for Text Encoders + print("šŸ” Performing MAD validation for text encoders...") + mad_validator.validate_module_mad( + clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder" + ) + mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder") + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + pipeline._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = pipeline.transformer.model.config.in_channels // 4 + latents, latent_image_ids = pipeline.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + t5_qaic_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Initialize transformer inference session + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession( + str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids + ) + + # Calculate compressed latent dimension (cl) for transformer buffer allocation + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, 
pipeline.transformer.model.config.in_channels).astype(np.float32), + } + pipeline.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + pipeline.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + # Prepare timestep embedding + timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds) + + # Compute AdaLN embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.transformer_blocks)): + block = pipeline.transformer.model.transformer_blocks[block_idx] + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)): + block = pipeline.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = pipeline.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(), + "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # MAD Validation for Transformer - PyTorch reference inference + noise_pred_torch = pytorch_pipeline.transformer( + hidden_states=latents, + encoder_hidden_states=t5_torch_prompt_embeds, + pooled_projections=clip_torch_pooled_prompt_embeds, + timestep=torch.tensor(timestep), + img_ids=latent_image_ids, + txt_ids=text_ids, + return_dict=False, + )[0] + + # Run transformer inference and measure time + start_transformer_step_time = time.time() + outputs = pipeline.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.time() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Transformer MAD validation + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + outputs["output"], + "transformer", + f"step {i} (t={t.item():.1f})", + ) + + # Update latents using scheduler + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images + if output_type == "latent": + image = latents 
+ vae_decode_perf = 0.0 # No VAE decoding for latent output + else: + # Unpack and denormalize latents + latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor) + + # Denormalize latents + latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor + # Initialize VAE decoder inference session + if pipeline.vae_decode.qpc_session is None: + pipeline.vae_decode.qpc_session = QAICInferenceSession( + str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids + ) + + # Allocate output buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)} + pipeline.vae_decode.qpc_session.set_buffers(output_buffer) + + # MAD Validation for VAE + # PyTorch reference inference + image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.time() + image = pipeline.vae_decode.qpc_session.run(inputs) + end_decode_time = time.time() + vae_decode_perf = end_decode_time - start_decode_time + + # VAE MAD validation + mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder") + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) + + +@pytest.fixture(scope="session") +def flux_pipeline(): + """Setup compiled Flux pipeline for testing""" + config = INITIAL_TEST_CONFIG["model_setup"] + + pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + use_onnx_subfunctions=config["use_onnx_subfunctions"], + ) + + # Reduce to 2 layers for testing + original_blocks = pipeline.transformer.model.transformer_blocks + org_single_blocks = pipeline.transformer.model.single_transformer_blocks + + pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"] + pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"] + pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + + ### Pytorch pipeline + pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks + org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks + pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList( + [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + return pipeline, pytorch_pipeline + + 
+@pytest.mark.diffusion_models +@pytest.mark.on_qaic +def test_flux_pipeline(flux_pipeline): + """ + Comprehensive Flux pipeline test that follows the exact same flow as pipeline_flux.py: + - 256x256 resolution - 2 transformer layers + - MAD validation + - Functional image generation test + - Export/compilation checks + - Returns QEffPipelineOutput with performance metrics + """ + pipeline, pytorch_pipeline = flux_pipeline + config = INITIAL_TEST_CONFIG + + # Print test header + DiffusersTestUtils.print_test_header( + f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers", + config, + ) + + # Test parameters + test_prompt = config["pipeline_params"]["test_prompt"] + num_inference_steps = config["pipeline_params"]["num_inference_steps"] + guidance_scale = config["pipeline_params"]["guidance_scale"] + max_sequence_length = config["pipeline_params"]["max_sequence_length"] + + # Generate with MAD validation + generator = torch.manual_seed(42) + start_time = time.time() + + try: + # Run the pipeline with integrated MAD validation (follows exact pipeline flow) + result = flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height=config["model_setup"]["height"], + width=config["model_setup"]["width"], + prompt=test_prompt, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + max_sequence_length=max_sequence_length, + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + parallel_compile=True, + return_dict=True, + ) + + execution_time = time.time() - start_time + + # Validate image generation + if config["pipeline_params"]["validate_gen_img"]: + assert result is not None, "Pipeline returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No images generated" + + generated_image = result.images[0] + expected_size = (config["model_setup"]["height"], config["model_setup"]["width"]) + # Validate image properties using utilities + image_validation = DiffusersTestUtils.validate_image_generation( + generated_image, expected_size, config["pipeline_params"]["min_image_variance"] + ) + + print("\nāœ… IMAGE VALIDATION PASSED") + print(f" - Size: {image_validation['size']}") + print(f" - Mode: {image_validation['mode']}") + print(f" - Variance: {image_validation['variance']:.2f}") + print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}") + file_path = "test_flux_256x256_2layers.png" + # Save test image + generated_image.save(file_path) + + if os.path.exists(file_path): + print(f"Image saved successfully at: {file_path}") + else: + print("Image was not saved.") + + if config["validation_checks"]["onnx_export"]: + # Check if ONNX files exist (basic check) + print("\nšŸ” ONNX Export Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path: + DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX") + + if config["validation_checks"]["compilation"]: + # Check if QPC files exist (basic check) + print("\nšŸ” Compilation Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "qpc_path") and 
module_obj.qpc_path:
+                    DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC")
+
+        # Print test summary using utilities
+        print(f"\nTotal execution time: {execution_time:.4f}s")
+    except Exception as e:
+        print(f"\nTEST FAILED: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    # This allows running the test file directly for debugging
+    pytest.main([__file__, "-v", "-s", "-m", "diffusion_models"])
+# pytest tests/diffusers/test_flux.py -m diffusion_models -v -s --tb=short
diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py
index 0810ac6ba..3eaaf0f69 100644
--- a/tests/transformers/test_causal_lm.py
+++ b/tests/transformers/test_causal_lm.py
@@ -211,7 +211,7 @@ def test_causal_lm_hash_creation(config, cb, tmp_path):

 @pytest.fixture
 def tmp_cache(tmp_path, monkeypatch):
-    monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+    monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
     yield tmp_path

diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py
index 59281b73b..bc53cb539 100644
--- a/tests/transformers/test_speech_seq2seq.py
+++ b/tests/transformers/test_speech_seq2seq.py
@@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path):

 @pytest.fixture
 def tmp_cache(tmp_path, monkeypatch):
-    monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+    monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)
     yield tmp_path

diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py
index fefa73973..b7a5495c6 100644
--- a/tests/utils/test_hash_utils.py
+++ b/tests/utils/test_hash_utils.py
@@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value):

 def test_json_serializable():
     # Test with a set
-    assert json_serializable({1, 2, 3}) == [1, 2, 3]
+    assert json_serializable({1, 2, 3}) == ["1", "2", "3"]
     # Test with an unsupported type
     with pytest.raises(TypeError):
         json_serializable({1, 2, 3, {4, 5}})