 
 import numpy as np
 import torch
-import torch.nn as nn
 from diffusers.models.attention_dispatch import dispatch_attention_fn
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_flux import (
     _get_qkv_projections,
 )
 
-from QEfficient.diffusers.models.normalization import (
-    QEffAdaLayerNormZero,
-    QEffAdaLayerNormZeroSingle,
-)
-
 
 def qeff_apply_rotary_emb(
     x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
@@ -120,24 +114,6 @@ def __qeff_init__(self):
 
 
 class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
-    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
-        super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio)
-        self.mlp_hidden_dim = int(dim * mlp_ratio)
-        self.norm = QEffAdaLayerNormZeroSingle(dim)
-        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
-        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
-        self.attn = QEffFluxAttention(
-            query_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            out_dim=dim,
-            bias=True,
-            processor=QEffFluxAttnProcessor(),
-            eps=1e-6,
-            pre_only=True,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -163,33 +139,12 @@ def forward(
         gate = gate.unsqueeze(1)
         hidden_states = gate * self.proj_out(hidden_states)
         hidden_states = residual + hidden_states
-        # if hidden_states.dtype == torch.float16:
-        hidden_states = hidden_states.clip(-65504, 65504)
 
         encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
         return encoder_hidden_states, hidden_states
 
 
 class QEffFluxTransformerBlock(FluxTransformerBlock):
-    def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
-    ):
-        super().__init__(dim, num_attention_heads, attention_head_dim)
-
-        self.norm1 = QEffAdaLayerNormZero(dim)
-        self.norm1_context = QEffAdaLayerNormZero(dim)
-        self.attn = QEffFluxAttention(
-            query_dim=dim,
-            added_kv_proj_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            out_dim=dim,
-            context_pre_only=False,
-            bias=True,
-            processor=QEffFluxAttnProcessor(),
-            eps=eps,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -395,31 +350,3 @@ def forward(
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
-
-
-class QEffFluxTransformer2DModelOF(QEffFluxTransformer2DModel):
-    def __qeff_init__(self):
-        self.transformer_blocks = nn.ModuleList()
-        self._block_classes = set()
-
-        for _ in range(self.config.num_layers):
-            BlockClass = QEffFluxTransformerBlock
-            block = BlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
-            )
-            self.transformer_blocks.append(block)
-            self._block_classes.add(BlockClass)
-
-        self.single_transformer_blocks = nn.ModuleList()
-
-        for _ in range(self.config.num_single_layers):
-            SingleBlockClass = QEffFluxSingleTransformerBlock
-            single_block = SingleBlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
-            )
-            self.single_transformer_blocks.append(single_block)
-            self._block_classes.add(SingleBlockClass)
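
For reference, a minimal standalone sketch (not part of this diff) of the block-rebuilding pattern the removed __qeff_init__ relied on: one block per configured layer is instantiated from config values and registered in an nn.ModuleList. The _ToyBlock class and rebuild_blocks helper below are illustrative stand-ins for QEffFluxTransformerBlock and the original method, not code from this repository.

import torch.nn as nn


class _ToyBlock(nn.Module):
    # Illustrative stand-in for QEffFluxTransformerBlock; only the constructor signature matters here.
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)


def rebuild_blocks(num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int) -> nn.ModuleList:
    # Mirror the removed loop: build one block per layer and collect them in a ModuleList
    # so their parameters are tracked by the parent nn.Module.
    blocks = nn.ModuleList()
    for _ in range(num_layers):
        blocks.append(
            _ToyBlock(dim=dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
        )
    return blocks


# Usage example with Flux-scale dimensions (24 heads x 128 head_dim = 3072 inner dim), 2 layers for brevity.
blocks = rebuild_blocks(num_layers=2, dim=3072, num_attention_heads=24, attention_head_dim=128)
assert len(blocks) == 2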