Changes from all commits (42 commits)
e05f05a
Diffusers support (#604)
quic-amitraj Dec 9, 2025
a2d91cc
Stage-1 export debug
quic-amitraj Jul 10, 2025
b47884a
Stage-2 Export initial working version done
quic-amitraj Jul 16, 2025
f4a4784
Stage-3 compilation work is in progress
quic-amitraj Jul 17, 2025
1189dc9
Stage-4 Working pipeline with wrong output
quic-amitraj Jul 23, 2025
712f18b
Testing
quic-amitraj Jul 30, 2025
f6f4235
Testing
quic-amitraj Jul 30, 2025
a0d8b2c
Working sd3-turbo
quic-amitraj Aug 3, 2025
8e4fbd1
Working with cleaned code
quic-amitraj Aug 5, 2025
3f3c14a
Working with cleaned code
quic-amitraj Aug 5, 2025
beeddfa
Working with vae_included
quic-amitraj Aug 6, 2025
14db9b6
Fix-1
quic-amitraj Aug 8, 2025
096cb91
Fix-2
quic-amitraj Aug 10, 2025
c2c86d7
Fix-3
quic-amitraj Aug 13, 2025
e7a0f65
Added readme for diffusers
quic-amitraj Aug 14, 2025
89c250c
Code cleanup
quic-amitraj Aug 14, 2025
2d31d34
Code cleanup-2
quic-amitraj Aug 15, 2025
539f425
Minor fix
quic-amitraj Aug 20, 2025
3dd8f37
Added Support of flux
quic-amitraj Sep 19, 2025
4edcee6
Updated seq_len of flux transformers
tv-karthikeya Sep 24, 2025
759a1b7
Removing SD3, adding small fix for flux model hash
tv-karthikeya Oct 9, 2025
11c59be
adding device id support for flux for all stages
tv-karthikeya Oct 9, 2025
93890b6
[WIP] Adding support for custom height, width
tv-karthikeya Nov 3, 2025
7708282
Flux support with Custom config
quic-amitraj Nov 4, 2025
77e4527
Added OnnxfunctionTransform and code cleanup while modifying compile …
quic-amitraj Nov 4, 2025
8e2dc5c
Compile fix
quic-amitraj Nov 5, 2025
ad3bba7
Modification of Pipeline-1
Nov 6, 2025
f4fea00
Modification of Pipeline-2
Nov 7, 2025
6d8718e
Update readme for diffusers
Nov 10, 2025
6a733be
Added support of output dataclass
Nov 11, 2025
ddf5bc5
Replaced output dict with dataclass to make it more user friendly
Nov 11, 2025
e037478
Rebased with main and fixed some issues
Nov 12, 2025
318e0df
Code cleaning and removed redundant code
Nov 13, 2025
de503b8
Code cleaning and removed redundant code-2
Nov 13, 2025
2be94c6
Added tqdm for export and compile
Nov 13, 2025
24e0b26
Parallel compilation and onnx subfunction is added
Nov 14, 2025
e3e26f3
Onboarding Qwen Image
qcdipankar Nov 19, 2025
f2eb650
Onboarding Qwen Image
qcdipankar Nov 19, 2025
3b905c1
Moved Rotary Embed inside
qcdipankar Nov 19, 2025
e11a7fb
Cleaning and rebase with diffusers
qcdipankar Dec 10, 2025
a8d87d2
Subfunction fixes for KV cache transform (#655)
abhishek-singh591 Dec 10, 2025
76a0ccb
Merge branch 'config_support_diffusers' into qwen_image_pipeline
qcdipankar Dec 10, 2025
5 changes: 4 additions & 1 deletion QEfficient/__init__.py
@@ -19,6 +19,7 @@
)
from QEfficient.compile.compile_helper import compile
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.diffusers.pipelines.qwen_image.pipeline_qwenimage import QEFFQwenImagePipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -41,6 +42,8 @@
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEffFluxPipeline",
"QEFFQwenImagePipeline",

]
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
@@ -70,4 +73,4 @@ def check_qaic_sdk():


if not check_qaic_sdk():
logger.warning("QAIC SDK is not installed, eager mode features won't be available!")
logger.warning("QAIC SDK is not installed, eager mode features won't be available!")
4 changes: 4 additions & 0 deletions QEfficient/base/modeling_qeff.py
@@ -141,7 +141,11 @@
"""
Get the model configuration as a dictionary.

Returns:
Dict: The configuration dictionary of the underlying HuggingFace model
"""
return self.model.config.__dict__
This is an abstract property that must be implemented by all subclasses.

Check failure on line 148 in QEfficient/base/modeling_qeff.py (GitHub Actions / lint): Ruff invalid-syntax, reported at nine columns of line 148 (cols 20-70): Simple statements must be separated by newlines or semicolons.
Typically returns: self.model.config.__dict__

Check failure on line 149 in QEfficient/base/modeling_qeff.py (GitHub Actions / lint): Ruff invalid-syntax at 148:81: Expected an identifier.

Returns:
6 changes: 6 additions & 0 deletions QEfficient/base/onnx_transforms.py
@@ -19,16 +19,20 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGather,
CtxGather3D,
CtxGatherBlockedKV,
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxScatter,
CtxScatter3D,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherBlockedKVCB,
CtxGatherCB,
CtxGatherCB3D,
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterCB,
@@ -95,6 +99,8 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
"CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV),
"CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB),
}

@classmethod
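The two BlockedKV entries registered above add the blocked-KV gather variants to the same name-to-op table CustomOpTransform already uses for the other CtxGather/CtxScatter functions. As a rough illustration of the gather semantics such an op implements (a toy sketch of the indexing only, not the real ONNX custom-function implementation; all tensor names here are made up):

import torch

# Toy KV cache with layout (batch, heads, ctx_len, head_dim).
kv_cache = torch.randn(1, 2, 8, 4)
# Positions to read for the current decode step.
position_ids = torch.tensor([0, 1, 2, 3])

# A context gather picks cached rows by position id along the context axis.
gathered = kv_cache[:, :, position_ids, :]
assert gathered.shape == (1, 2, 4, 4)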
75 changes: 75 additions & 0 deletions QEfficient/diffusers/models/attention.py
@@ -0,0 +1,75 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import torch
from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward


class QEffJointTransformerBlock(JointTransformerBlock):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)

# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)

# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
# Blocked feed-forward over the sequence (upstream call: self.ff(norm_hidden_states))
ff_output = self.ff(norm_hidden_states, block_size=4096)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output

# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = _chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
)
else:
# Blocked feed-forward for the context stream (upstream call: self.ff_context(norm_encoder_hidden_states))
context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

return encoder_hidden_states, hidden_states
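The block_size arguments above route the MLPs through a blocked feed-forward that processes the sequence in fixed-size chunks to bound peak activation memory; because the feed-forward is applied position-wise, chunking along the sequence axis does not change the result. A minimal equivalence check with a toy MLP (assumption: the QEff FeedForward's block_size chunks the sequence dimension in this same way):

import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(2, 10, 16)  # (batch, seq_len, dim)

full = ff(x)  # one-shot feed-forward

# Blocked: run each sequence chunk separately and concatenate.
blocked = torch.cat([ff(chunk) for chunk in x.split(4, dim=1)], dim=1)

assert torch.allclose(full, blocked, atol=1e-5)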
158 changes: 158 additions & 0 deletions QEfficient/diffusers/models/attention_processor.py
@@ -0,0 +1,158 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from typing import Optional

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

from QEfficient.diffusers.models.transformers.transformer_qwenimage import QEffQwenDoubleStreamAttnProcessor2_0


class QEffAttention(Attention):
def __qeff_init__(self):
# Install the QEff double-stream attention processor and set the query
# tile size used by the blocked-attention path.
processor = QEffQwenDoubleStreamAttnProcessor2_0()
self.processor = processor
processor.query_block_size = 64

def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()

if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1

attention_scores = torch.baddbmm(
baddbmm_input,
query,
key,
beta=beta,
alpha=self.scale,
)
del baddbmm_input

if self.upcast_softmax:
attention_scores = attention_scores.float()

attention_probs = attention_scores.softmax(dim=-1)
del attention_scores

attention_probs = attention_probs.to(dtype)

return attention_probs


class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
def __call__(
self,
attn: QEffAttention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states

batch_size = hidden_states.shape[0]

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

query = query.reshape(-1, query.shape[-2], query.shape[-1])
key = key.reshape(-1, key.shape[-2], key.shape[-1])
value = value.reshape(-1, value.shape[-2], value.shape[-1])

# pre-transpose the key
key = key.transpose(-1, -2)
if query.size(-2) != value.size(-2): # cross-attention, use regular attention
# QKV done in single block
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
else: # self-attention, use blocked attention
# QKV done with block-attention (a la FlashAttentionV2)
query_block_size = self.query_block_size
query_seq_len = query.size(-2)
num_blocks = (query_seq_len + query_block_size - 1) // query_block_size
for qidx in range(num_blocks):
query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :]
attention_probs = attn.get_attention_scores(query_block, key, attention_mask)
hidden_states_block = torch.bmm(attention_probs, value)
if qidx == 0:
hidden_states = hidden_states_block
else:
hidden_states = torch.cat((hidden_states, hidden_states_block), -2)
hidden_states = attn.batch_to_head_dim(hidden_states)

if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
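The self-attention branch above tiles the queries into blocks of query_block_size, FlashAttention-style, so each score matrix stays within a fixed working-set size; since the softmax is taken row-wise over the full key axis, the per-block outputs concatenate to exactly the full-attention result. A minimal numerical check of that equivalence (plain PyTorch, independent of the QEff classes):

import torch

torch.manual_seed(0)
q = torch.randn(3, 10, 8)  # (batch*heads, q_len, head_dim)
k = torch.randn(3, 12, 8)
v = torch.randn(3, 12, 8)
scale = 8 ** -0.5

full = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v

# Tile the queries; keys/values stay whole, so each row's softmax is unchanged.
outs = []
for i in range(0, q.size(1), 4):
    qb = q[:, i : i + 4, :]
    outs.append(torch.softmax(qb @ k.transpose(-1, -2) * scale, dim=-1) @ v)
blocked = torch.cat(outs, dim=1)

assert torch.allclose(full, blocked, atol=1e-6)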
6 changes: 6 additions & 0 deletions QEfficient/diffusers/models/autoencoders/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
31 changes: 31 additions & 0 deletions QEfficient/diffusers/models/autoencoders/autoencoder_kl.py
@@ -0,0 +1,31 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import torch
from diffusers import AutoencoderKL


class QEffAutoencoderKL(AutoencoderKL):
def encode(self, x: torch.Tensor, return_dict: bool = True):
"""
Encode a batch of images into latents.

Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
return h
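use_slicing trades throughput for peak memory: the batch is encoded one image at a time and the latents concatenated, which is equivalent because the encoder treats batch items independently. A toy check with a stand-in encoder (the real work happens in AutoencoderKL._encode):

import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # stand-in for _encode
x = torch.randn(4, 3, 32, 32)

h_full = encoder(x)  # whole batch at once
h_sliced = torch.cat([encoder(x_slice) for x_slice in x.split(1)])  # sliced

assert torch.allclose(h_full, h_sliced, atol=1e-6)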