diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index b507363c3..4765e7bc8 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -29,6 +29,7 @@ ) from QEfficient.compile.compile_helper import compile from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline +from QEfficient.diffusers.pipelines.qwen_image.pipeline_qwenimage import QEFFQwenImagePipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -55,6 +56,7 @@ "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", "QEffFluxPipeline", + "QEFFQwenImagePipeline", ] diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py index d3c84ee63..89cd19542 100644 --- a/QEfficient/diffusers/models/pytorch_transforms.py +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from diffusers.models.attention_processor import Attention from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm from diffusers.models.transformers.transformer_flux import ( FluxAttention, @@ -13,6 +14,10 @@ FluxTransformer2DModel, FluxTransformerBlock, ) +from diffusers.models.transformers.transformer_qwenimage import ( + QwenDoubleStreamAttnProcessor2_0, + QwenImageTransformer2DModel, +) from torch import nn from QEfficient.base.pytorch_transforms import ModuleMappingTransform @@ -29,6 +34,11 @@ QEffFluxTransformer2DModel, QEffFluxTransformerBlock, ) +from QEfficient.diffusers.models.transformers.transformer_qwenimage import ( + QEffQwenDoubleStreamAttnProcessor2_0, + QEffQwenImageAttention, + QEffQwenImageTransformer2DModel, +) class CustomOpsTransform(ModuleMappingTransform): @@ -45,6 +55,9 @@ class AttentionTransform(ModuleMappingTransform): FluxTransformer2DModel: QEffFluxTransformer2DModel, FluxAttention: QEffFluxAttention, FluxAttnProcessor: QEffFluxAttnProcessor, + QwenImageTransformer2DModel: QEffQwenImageTransformer2DModel, + QwenDoubleStreamAttnProcessor2_0: QEffQwenDoubleStreamAttnProcessor2_0, + Attention: QEffQwenImageAttention, } diff --git a/QEfficient/diffusers/models/transformers/transformer_qwenimage.py b/QEfficient/diffusers/models/transformers/transformer_qwenimage.py new file mode 100644 index 000000000..2f6651007 --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_qwenimage.py @@ -0,0 +1,409 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + + +import functools +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.models.transformers.transformer_qwenimage import ( + QwenDoubleStreamAttnProcessor2_0, + QwenImageTransformer2DModel, +) +from diffusers.utils.constants import USE_PEFT_BACKEND +from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers + +logger = logging.getLogger(__name__) + + +def qeff_apply_rotary_emb_qwen(x, freqs_cos, freqs_sin): + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, D//2, 2] + x1 = x_reshaped[..., 0] # [B, S, H, D//2] + x2 = x_reshaped[..., 1] # [B, S, H, D//2] + + # Reshape for broadcasting: [S, D//2] -> [1, S, 1, D//2] + freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2) + freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2) + + # Apply rotation + x_out1 = x1 * freqs_cos - x2 * freqs_sin # Real part + x_out2 = x1 * freqs_sin + x2 * freqs_cos # Imaginary part + + # Stack and reshape back + x_out = torch.stack([x_out1, x_out2], dim=-1) # [B, S, H, D//2, 2] + x_out = x_out.flatten(-2) # [B, S, H, D] + return x_out.type_as(x) + + +class QEffQwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.scale_rope = scale_rope + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + + # Store cos and sin separately instead of complex numbers + pos_freqs_list = [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ] + self.pos_freqs_cos = torch.cat([f[0] for f in pos_freqs_list], dim=1) + self.pos_freqs_sin = torch.cat([f[1] for f in pos_freqs_list], dim=1) + + neg_freqs_list = [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ] + self.neg_freqs_cos = torch.cat([f[0] for f in neg_freqs_list], dim=1) + self.neg_freqs_sin = torch.cat([f[1] for f in neg_freqs_list], dim=1) + + self.rope_cache = {} + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos_cos = self.pos_freqs_cos.split([x // 2 for x in 
self.axes_dim], dim=1) + freqs_pos_sin = self.pos_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg_cos = self.neg_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg_sin = self.neg_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1) + + # Frame dimension + freqs_frame_cos = freqs_pos_cos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame_sin = freqs_pos_sin[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + + if self.scale_rope: + freqs_height_cos = torch.cat( + [freqs_neg_cos[1][-(height - height // 2) :], freqs_pos_cos[1][: height // 2]], dim=0 + ) + freqs_height_sin = torch.cat( + [freqs_neg_sin[1][-(height - height // 2) :], freqs_pos_sin[1][: height // 2]], dim=0 + ) + freqs_height_cos = freqs_height_cos.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = freqs_height_sin.view(1, height, 1, -1).expand(frame, height, width, -1) + + freqs_width_cos = torch.cat( + [freqs_neg_cos[2][-(width - width // 2) :], freqs_pos_cos[2][: width // 2]], dim=0 + ) + freqs_width_sin = torch.cat( + [freqs_neg_sin[2][-(width - width // 2) :], freqs_pos_sin[2][: width // 2]], dim=0 + ) + freqs_width_cos = freqs_width_cos.view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width_sin = freqs_width_sin.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height_cos = freqs_pos_cos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = freqs_pos_sin[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width_cos = freqs_pos_cos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width_sin = freqs_pos_sin[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs_cos = torch.cat([freqs_frame_cos, freqs_height_cos, freqs_width_cos], dim=-1).reshape(seq_lens, -1) + freqs_sin = torch.cat([freqs_frame_sin, freqs_height_sin, freqs_width_sin], dim=-1).reshape(seq_lens, -1) + + return freqs_cos.clone().contiguous(), freqs_sin.clone().contiguous() + + def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: + video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video + txt_length: [bs] a list of 1 integers representing the length of the text + Returns: + Tuple of (vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin) + """ + if self.pos_freqs_cos.device != device: + self.pos_freqs_cos = self.pos_freqs_cos.to(device) + self.pos_freqs_sin = self.pos_freqs_sin.to(device) + self.neg_freqs_cos = self.neg_freqs_cos.to(device) + self.neg_freqs_sin = self.neg_freqs_sin.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs_cos_list = [] + vid_freqs_sin_list = [] + max_vid_index = 0 + + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq_cos, video_freq_sin = self.rope_cache[rope_key] + else: + video_freq_cos, video_freq_sin = self._compute_video_freqs(frame, height, width, idx) + + video_freq_cos = video_freq_cos.to(device) + video_freq_sin = video_freq_sin.to(device) + vid_freqs_cos_list.append(video_freq_cos) + vid_freqs_sin_list.append(video_freq_sin) + + if self.scale_rope: + max_vid_index = max(height // 
2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs_cos = self.pos_freqs_cos[max_vid_index : max_vid_index + max_len, ...] + txt_freqs_sin = self.pos_freqs_sin[max_vid_index : max_vid_index + max_len, ...] + + vid_freqs_cos = torch.cat(vid_freqs_cos_list, dim=0) + vid_freqs_sin = torch.cat(vid_freqs_sin_list, dim=0) + + return vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + # Return cos and sin separately instead of complex tensor + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + +class QEffQwenImageTransformer2DModel(QwenImageTransformer2DModel): + def __qeff_init__(self): + self.pos_embed = QEffQwenEmbedRope(theta=10000, axes_dim=list(self.axes_dims_rope), scale_rope=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + frame: torch.Tensor = None, + height: torch.Tensor = None, + width: torch.Tensor = None, + txt_seq_lens: torch.Tensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + # Convert scalar tensors to Python integers and create img_shapes list + if isinstance(frame, torch.Tensor): + frame = frame.item() if frame.numel() == 1 else int(frame[0]) + if isinstance(height, torch.Tensor): + height = height.item() if height.numel() == 1 else int(height[0]) + if isinstance(width, torch.Tensor): + width = width.item() if width.numel() == 1 else int(width[0]) + + if not img_shapes: + img_shapes = [(frame, height, width)] + + # Convert txt_seq_lens to list if it's a tensor + if isinstance(txt_seq_lens, torch.Tensor): + txt_seq_lens = txt_seq_lens.tolist() + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.img_in(hidden_states) + + timestep = timestep.to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + encoder_hidden_states_mask, + temb, + image_rotary_emb, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class QEffQwenDoubleStreamAttnProcessor2_0(QwenDoubleStreamAttnProcessor2_0): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if encoder_hidden_states is None: + raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query = attn.to_q(hidden_states) + img_key = attn.to_k(hidden_states) + img_value = attn.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query = attn.add_q_proj(encoder_hidden_states) + txt_key = attn.add_k_proj(encoder_hidden_states) 
+ txt_value = attn.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (attn.heads, -1)) + img_key = img_key.unflatten(-1, (attn.heads, -1)) + img_value = img_value.unflatten(-1, (attn.heads, -1)) + + txt_query = txt_query.unflatten(-1, (attn.heads, -1)) + txt_key = txt_key.unflatten(-1, (attn.heads, -1)) + txt_value = txt_value.unflatten(-1, (attn.heads, -1)) + + # Apply QK normalization + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + if attn.norm_added_q is not None: + txt_query = attn.norm_added_q(txt_query) + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + # Unpack the 4 tensors (cos and sin for both img and txt) + img_freqs_cos, img_freqs_sin, txt_freqs_cos, txt_freqs_sin = image_rotary_emb + + img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs_cos, img_freqs_sin) + img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs_cos, img_freqs_sin) + txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs_cos, txt_freqs_sin) + txt_key = qeff_apply_rotary_emb_qwen(txt_key, txt_freqs_cos, txt_freqs_sin) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + + # Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = attn.to_out[0](img_attn_output) + if len(attn.to_out) > 1: + img_attn_output = attn.to_out[1](img_attn_output) # dropout + + txt_attn_output = attn.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QEffQwenImageAttention(Attention): + def __qeff_init__(self): + self.processor = QEffQwenDoubleStreamAttnProcessor2_0() diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 6d9243fdc..a0158625c 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -479,3 +479,116 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ self._compile(specializations=specializations, **compiler_options) + + +class QEffQwenImageTransformer2DModel(QEFFBaseModel): + _pytorch_transforms = [AttentionTransform, CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffQwenImageTransformer2DModel is a wrapper class for QwenImage Transformer2D models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle QwenImage Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. 
It is designed for the QwenImage architecture that uses + transformer-based diffusion models with unique latent packing and attention mechanisms. + """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = model + + def get_onnx_config(self): + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # For testing purpose I have set this to constant values from the original models + latent_seq_len = 6032 + text_seq_len = 126 + hidden_dim = 64 + encoder_hidden_dim = 3584 + example_inputs = { + "hidden_states": torch.randn(bs, latent_seq_len, hidden_dim, dtype=torch.float32), + "encoder_hidden_states": torch.randn(bs, text_seq_len, encoder_hidden_dim, dtype=torch.float32), + "encoder_hidden_states_mask": torch.ones(bs, text_seq_len, dtype=torch.int64), + "timestep": torch.tensor([1000.0], dtype=torch.float32), + "frame": torch.tensor([1], dtype=torch.int64), + "height": torch.tensor([58], dtype=torch.int64), + "width": torch.tensor([104], dtype=torch.int64), + "txt_seq_lens": torch.tensor([126], dtype=torch.int64), + } + + output_names = ["output"] + + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "latent_seq_len"}, + "encoder_hidden_states": {0: "batch_size", 1: "text_seq_len"}, + "encoder_hidden_states_mask": {0: "batch_size", 1: "text_seq_len"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + export_kwargs=None, + ): + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + export_kwargs=export_kwargs, + ) + + def get_specializations( + self, + batch_size: int, + latent_seq_len: int, + text_seq_len: int, + ): + specializations = [ + { + "batch_size": batch_size, + "latent_seq_len": latent_seq_len, + "text_seq_len": text_seq_len, + } + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.config.__dict__ diff --git a/QEfficient/diffusers/pipelines/qwen_image/__init__.py b/QEfficient/diffusers/pipelines/qwen_image/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/qwen_image/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/qwen_image/pipeline_qwenimage.py b/QEfficient/diffusers/pipelines/qwen_image/pipeline_qwenimage.py new file mode 100644 index 000000000..c1df0e012 --- /dev/null +++ b/QEfficient/diffusers/pipelines/qwen_image/pipeline_qwenimage.py @@ -0,0 +1,541 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import QwenImagePipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput +from diffusers.pipelines.qwenimage.pipeline_qwenimage import calculate_shift +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from QEfficient.diffusers.pipelines.pipeline_module import ( + QEffQwenImageTransformer2DModel, + QEffTextEncoder, + QEffVAE, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession + + +class QEFFQwenImagePipeline(QwenImagePipeline): + _hf_auto_class = QwenImagePipeline + """ + + A QEfficient-optimized QwenImage pipeline, inheriting from `diffusers.QwenImagePipeline`. + + This class integrates QEfficient components (e.g., optimized models for text encoder, + transformer, and VAE) to enhance performance, particularly for deployment on Qualcomm AI hardware. + It provides methods for text-to-image generation leveraging these optimized components. + """ + + def __init__(self, model, *args, **kwargs): + self.text_encoder = QEffTextEncoder(model.text_encoder) + self.text_encoder.tokenizer = model.tokenizer + self.transformer = QEffQwenImageTransformer2DModel(model.transformer) + self.vae_decode = QEffVAE(model.vae, "decoder") + self.vae_cpu = model.vae + self.tokenizer = model.tokenizer + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + self.prompt_template_encode = model.prompt_template_encode + self.prompt_template_encode_start_idx = model.prompt_template_encode_start_idx + + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + self.vae_scale_factor = 2 ** len(model.vae.temperal_downsample) if getattr(model, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = model.default_sample_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Instantiate a QEFFQwenImagePipeline from pretrained Diffusers models. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + The path to the pretrained model or its name. + **kwargs (additional keyword arguments): + Additional arguments that can be passed to the underlying `QwenImagePipeline.from_pretrained` + method. 
+        """
+        model = cls._hf_auto_class.from_pretrained(
+            pretrained_model_name_or_path,
+            torch_dtype=torch.float32,
+            **kwargs,
+        )
+        model.to("cpu")
+        return cls(
+            model=model,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+        )
+
+    def export(self, export_dir: Optional[str] = None) -> str:
+        """
+        Exports the transformer to ``ONNX`` format using ``torch.onnx.export``.
+
+        ``Optional`` Args:
+            :export_dir (str, optional): The directory path to store the ONNX graph.
+
+        Returns:
+            :str: Path of the generated ``ONNX`` graph.
+        """
+
+        # transformer
+        example_inputs_transformer, dynamic_axes_transformer, output_names_transformer = (
+            self.transformer.get_onnx_config()
+        )
+
+        transformer_onnx_path = self.transformer.export(
+            inputs=example_inputs_transformer,
+            output_names=output_names_transformer,
+            dynamic_axes=dynamic_axes_transformer,
+            export_dir=export_dir,
+        )
+        print("Exported transformer")
+        return transformer_onnx_path
+
+    def compile(
+        self,
+        onnx_path: Optional[str] = None,
+        compile_dir: Optional[str] = None,
+        *,
+        batch_size: int = 1,
+        num_devices_text_encoder: int = 1,
+        num_devices_transformer: int = 4,
+        num_devices_vae_decoder: int = 1,
+        num_cores: int = 16,
+        mxfp6_matmul: bool = False,
+        **compiler_options,
+    ) -> str:
+        """
+        Compiles the ONNX graphs of the pipeline components for deployment on Qualcomm AI hardware.
+
+        At present only the transformer is compiled into a QPC; the text encoder and VAE decoder
+        still run in PyTorch on the host, so `num_devices_text_encoder` and
+        `num_devices_vae_decoder` are accepted but not yet used.
+
+        Args:
+            onnx_path (`str`, *optional*):
+                The base directory where ONNX files were exported.
+            compile_dir (`str`, *optional*):
+                The directory path to store the compiled artifacts.
+            batch_size (`int`, *optional*, defaults to 1):
+                The batch size to use for compilation (currently fixed to 1).
+            num_devices_text_encoder (`int`, *optional*, defaults to 1):
+                The number of AI devices to deploy the text encoder model on.
+            num_devices_transformer (`int`, *optional*, defaults to 4):
+                The number of AI devices to deploy the transformer model on.
+            num_devices_vae_decoder (`int`, *optional*, defaults to 1):
+                The number of AI devices to deploy the VAE decoder model on.
+            num_cores (`int`, *optional*, defaults to 16):
+                The number of cores to use for compilation.
+            mxfp6_matmul (`bool`, *optional*, defaults to `False`):
+                If `True`, enables MXFP6 compression for MatMul weights during compilation.
+            **compiler_options:
+                Additional keyword arguments to pass to the underlying compiler.
+
+        Returns:
+            `str`: Path to the compiled transformer QPC package.
+        """
+        if any(
+            path is None
+            for path in [
+                self.text_encoder.onnx_path,
+                self.transformer.onnx_path,
+                self.vae_decode.onnx_path,
+            ]
+        ):
+            self.export()
+
+        # Fixed specialization shapes matching the example inputs in
+        # QEffQwenImageTransformer2DModel.get_onnx_config() (58x104 packed latent grid, 126 text tokens)
+        latent_seq_len = 6032
+        batch_size = 1
+        text_seq_len = 126
+
+        specializations = [
+            {
+                "batch_size": batch_size,
+                "latent_seq_len": latent_seq_len,
+                "text_seq_len": text_seq_len,
+            }
+        ]
+
+        # Default compiler options for the transformer; user-supplied **compiler_options take precedence
+        compiler_options_transformer = {"mos": 1, "ols": 2, "mdts-mos": 1, **compiler_options}
+        self.transformer_compile_path = self.transformer._compile(
+            onnx_path,
+            compile_dir,
+            compile_only=True,
+            specializations=specializations,
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices_transformer,
+            aic_num_cores=num_cores,
+            **compiler_options_transformer,
+        )
+        return self.transformer_compile_path
+
+    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+        """Extract hidden states based on the attention mask."""
+        bool_mask = mask.bool()
+        valid_lengths = bool_mask.sum(dim=1)
+        selected = hidden_states[bool_mask]
+        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+        return split_result
+
+    def _get_qwen_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        device_ids: List[int] = None,
+    ):
+        """
+        Get Qwen prompt embeddings for the given prompt(s). The text encoder currently runs in
+        PyTorch on the host; a QAICInferenceSession path is planned.
+
+        Args:
+            prompt (Union[str, List[str]], optional): The input prompt(s) to encode.
+            device (Optional[torch.device], optional): The device to place tensors on.
+            dtype (Optional[torch.dtype], optional): The data type for tensors.
+            device_ids (List[int], optional): Reserved for future on-device inference; currently unused.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The prompt embeddings and attention mask.
+ """ + device = device or "cpu" + dtype = dtype or torch.float32 + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ) + + # HACK: Currently working on Pytorch + encoder_hidden_states = self.text_encoder.model( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + device_ids: List[int] = None, + ): + """ + Encode the given prompts into text embeddings using the Qwen text encoder. + + Args: + prompt (Union[str, List[str]]): The prompt(s) to encode. + device (Optional[torch.device], optional): The device to place tensors on. + num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt. + prompt_embeds (Optional[torch.Tensor], optional): Pre-computed prompt embeddings. + prompt_embeds_mask (Optional[torch.Tensor], optional): Pre-computed prompt embeddings mask. + max_sequence_length (int, defaults to 1024): Maximum sequence length for tokenization. + device_ids (List[int], optional): List of device IDs to use for inference. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The prompt embeddings and attention mask. 
+ """ + device = device or "cpu" + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device, device_ids=device_ids) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Generate images from text prompts using the QEfficient-optimized QwenImage pipeline. + + This method performs text-to-image generation by encoding the input prompts through the + Qwen text encoder, running the diffusion process with the transformer model, and decoding + the final latents to images using the VAE decoder. All components are optimized for + Qualcomm AI hardware. + + Args: + prompt (Union[str, List[str]], optional): The text prompt(s) to guide image generation. + negative_prompt (Union[str, List[str]], optional): Negative prompt(s) for true CFG. + true_cfg_scale (float, defaults to 4.0): Scale for true classifier-free guidance. + height (Optional[int], optional): Height of the generated image in pixels. + width (Optional[int], optional): Width of the generated image in pixels. + num_inference_steps (int, defaults to 50): Number of denoising steps. + sigmas (Optional[List[float]], optional): Custom sigmas for the denoising process. + guidance_scale (float, defaults to 1.0): Guidance scale (for future guidance-distilled models). + num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt. + generator (Optional[Union[torch.Generator, List[torch.Generator]]], optional): Random generator(s). + latents (Optional[torch.Tensor], optional): Pre-generated noisy latents. + prompt_embeds (Optional[torch.Tensor], optional): Pre-generated text embeddings. + prompt_embeds_mask (Optional[torch.Tensor], optional): Pre-generated text embeddings mask. + negative_prompt_embeds (Optional[torch.Tensor], optional): Pre-generated negative text embeddings. + negative_prompt_embeds_mask (Optional[torch.Tensor], optional): Pre-generated negative text embeddings mask. 
+ output_type (Optional[str], defaults to "pil"): Output format ("pil", "np", "pt", or "latent"). + return_dict (bool, defaults to True): Whether to return a QwenImagePipelineOutput. + attention_kwargs (Optional[Dict[str, Any]], optional): Additional attention kwargs. + callback_on_step_end (Optional[Callable], optional): Callback function at end of each step. + callback_on_step_end_tensor_inputs (List[str], defaults to ["latents"]): Tensor inputs for callback. + max_sequence_length (int, defaults to 512): Maximum sequence length for text encoder. + + Returns: + Union[QwenImagePipelineOutput, Tuple]: Generated images. + + Examples: + ```python + from QEfficient import QEFFQwenImagePipeline + + pipeline = QEFFQwenImagePipeline.from_pretrained("Qwen/Qwen-Image") + pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1) + + image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50).images[0] + image.save("qwenimage.png") + ``` + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + device = "cpu" + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.model.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # Initialize transformer session + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession(str(self.transformer.qpc_path)) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + timestep = (t.expand(latents.shape[0]) / 1000).detach().numpy().astype(np.float32) + + # Conditional pass + transformer_inputs = { + "hidden_states": latents.detach().numpy().astype(np.float32), + "encoder_hidden_states": prompt_embeds.detach().numpy().astype(np.float32), + "timestep": timestep, + } + if guidance is not None: + transformer_inputs["guidance"] = guidance.numpy().astype(np.float32) + + noise_pred = self.transformer.qpc_session.run(transformer_inputs) + noise_pred = torch.tensor(noise_pred["output"]) + # if do_true_cfg: + # # Unconditional pass + # transformer_inputs_uncond = { + # "hidden_states": latents.detach().numpy().astype(np.float32), + # "encoder_hidden_states": negative_prompt_embeds.detach().numpy().astype(np.float32), + # "timestep": timestep, + # } + # if guidance is not None: + # transformer_inputs_uncond["guidance"] = guidance.numpy().astype(np.float32) + + # neg_noise_pred = self.transformer.qpc_session.run(transformer_inputs_uncond) + # neg_noise_pred = torch.tensor(neg_noise_pred["output"]) + + # comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + # cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + # noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + # noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if 
output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae_decode.model.dtype) + latents_mean = ( + torch.tensor(self.vae_decode.model.config.latents_mean) + .view(1, self.vae_decode.model.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_decode.model.config.latents_std).view( + 1, self.vae_decode.model.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + image = self.vae_cpu.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image.detach(), output_type=output_type) + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/examples/diffusers/qwen_image/qwen_image_example.py b/examples/diffusers/qwen_image/qwen_image_example.py new file mode 100644 index 000000000..f1741ad84 --- /dev/null +++ b/examples/diffusers/qwen_image/qwen_image_example.py @@ -0,0 +1,58 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch + +from QEfficient import QEFFQwenImagePipeline + +model_name = "Qwen/Qwen-Image" + +pipe = QEFFQwenImagePipeline.from_pretrained(model_name) +positive_magic = { + "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt +} + +# Generate image +prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". 
Ultra HD, 4K, cinematic composition"""

+negative_prompt = " "  # use an empty string if there is no specific concept to remove
+
+
+# Generate with different aspect ratios
+aspect_ratios = {
+    "1:1": (1328, 1328),
+    "16:9": (1664, 928),
+    "9:16": (928, 1664),
+    "4:3": (1472, 1140),
+    "3:4": (1140, 1472),
+    "3:2": (1584, 1056),
+    "2:3": (1056, 1584),
+}
+
+width, height = aspect_ratios["16:9"]
+
+# Truncate the transformer to its first two blocks for a quick functional test
+original_blocks = pipe.transformer.model.transformer_blocks
+pipe.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0], original_blocks[1]])
+pipe.transformer.model.config.num_layers = 2
+
+# Compile the pipeline
+pipe.compile()
+
+
+image = pipe(
+    prompt=prompt + positive_magic["en"],
+    negative_prompt=negative_prompt,
+    width=width,
+    height=height,
+    num_inference_steps=5,
+    true_cfg_scale=4.0,
+    generator=torch.Generator(device="cpu").manual_seed(42),
+).images[0]
+
+image.save("example.png")
diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
index ff62588f9..eb8b9ba37 100644
--- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py
+++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py
@@ -27,15 +27,7 @@
 from types import MethodType
 from typing import Callable, Optional, Tuple, Union
 
-from QEfficient.transformers.models.blueprint.modeling_blueprint import (
-    QEffBlueprintAttention,
-    QEffBlueprintDecoderLayer,
-    QEffBlueprintForCausalLM,
-    QEffBlueprintModel,
-)
 from torch import nn
-
-# Example imports for three representative models
 from transformers.models.blueprint.modeling_blueprint import (
     BlueprintAttention,
     BlueprintDecoderLayer,
@@ -62,6 +54,14 @@
 from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
 from QEfficient.customop import CustomRMSNormAIC
 from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
+
+# Example imports for three representative models
+from QEfficient.transformers.models.blueprint.modeling_blueprint import (
+    QEffBlueprintAttention,
+    QEffBlueprintDecoderLayer,
+    QEffBlueprintForCausalLM,
+    QEffBlueprintModel,
+)
 from QEfficient.transformers.models.llama.modeling_llama import (
     QEffLlamaAttention,
     QEffLlamaDecoderLayer,
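
Note on the RoPE rework in QEfficient/diffusers/models/transformers/transformer_qwenimage.py: the upstream
QwenImage rotary embedding is built from complex-valued freqs_cis, while this diff stores cos and sin
separately and rotates with real arithmetic (complex tensors are not ONNX-friendly). A minimal sketch,
assuming only the function added by this diff, checking that the real-valued rotation matches the
complex-number formulation it replaces:

import torch

from QEfficient.diffusers.models.transformers.transformer_qwenimage import qeff_apply_rotary_emb_qwen


def complex_rope_reference(x, freqs_cos, freqs_sin):
    # Reference rotation using complex numbers, as in the upstream diffusers
    # implementation (fine in eager PyTorch, but not exportable to ONNX).
    freqs_cis = torch.complex(freqs_cos, freqs_sin)  # [S, D//2]
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # [B, S, H, D//2]
    x_rotated = x_complex * freqs_cis[None, :, None, :]
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)


batch, seq, heads, head_dim = 1, 8, 2, 16
x = torch.randn(batch, seq, heads, head_dim)
angles = torch.rand(seq, head_dim // 2)
cos, sin = torch.cos(angles), torch.sin(angles)

assert torch.allclose(
    qeff_apply_rotary_emb_qwen(x, cos, sin),
    complex_rope_reference(x, cos, sin),
    atol=1e-5,
)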
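
The fixed shapes used for export and compilation (latent_seq_len=6032 with height=58 and width=104 in
get_onnx_config(), mirrored by the specialization in QEFFQwenImagePipeline.compile()) appear to correspond to
the 1664x928 ("16:9") example resolution after VAE downsampling and 2x2 latent packing. A small sketch of that
arithmetic; the helper name is illustrative and vae_scale_factor=8 is an assumption for Qwen-Image:

def packed_latent_grid(pixel_height: int, pixel_width: int, vae_scale_factor: int = 8):
    # The VAE downsamples by vae_scale_factor and the latents are then packed 2x2 into tokens,
    # matching the img_shapes computation in QEFFQwenImagePipeline.__call__.
    lat_h = pixel_height // vae_scale_factor // 2
    lat_w = pixel_width // vae_scale_factor // 2
    return lat_h, lat_w, lat_h * lat_w


print(packed_latent_grid(928, 1664))  # (58, 104, 6032) -> latent height, latent width, latent_seq_len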