6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llava/modeling_llava.py
@@ -30,6 +30,9 @@ def __init__(self, model):
self.model = model
self.model.vision_model = self.model.vision_tower

def get_repeated_layers(self):
return self.model.vision_tower.vision_model.encoder.layers[0].__class__

def forward(self, pixel_values):
# Image features
image_outputs = self.model.vision_tower(pixel_values, output_hidden_states=True)
@@ -54,6 +57,9 @@ def __init__(self, model):
self.language_model = self.model.language_model
self.lm_head = self.model.lm_head

def get_repeated_layers(self):
return self.model.language_model.layers[0].__class__

def forward(
self,
input_ids,
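The new get_repeated_layers hook gives every wrapper a uniform way to report which block class repeats in its stack, so the export code no longer has to infer it from class names. A minimal sketch of the contract with stand-in modules (ToyBlock and ToyWrapper are illustrative, not QEfficient classes):

import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a repeated decoder/vision block."""

    def forward(self, x):
        return x


class ToyWrapper(nn.Module):
    """Stand-in wrapper following the same contract as the QEff wrappers."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(4)])

    def get_repeated_layers(self):
        # Report the class of the repeated block; the exporter later emits
        # every instance of this class as one reusable ONNX function.
        return self.layers[0].__class__


assert ToyWrapper().get_repeated_layers() is ToyBlock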
19 changes: 11 additions & 8 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -890,23 +890,26 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
Dynamically determine which DecoderLayer classes should be exported as functions
based on the model's architecture using the existing KVCacheTransform mapping.
"""
# Define patterns that identify decoder layer classes
DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]

# Get all QEff classes that are decoder layers from the existing mapping
DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
decoder_layer_classes = set()

for original_class, qeff_class in KVCacheTransform._module_mapping.items():
# Check if the QEff class name contains decoder layer patterns
qeff_class_name = qeff_class.__name__
if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS):
decoder_layer_classes.add(qeff_class)

# Filter to only include classes that are actually used in the current model
model_decoder_classes = set()
for module in model.modules():
if module.__class__ in decoder_layer_classes:
model_decoder_classes.add(module.__class__)
model_class_name = model.__class__.__name__
if "EncoderWrapper" in model_class_name:
model_decoder_classes.update(
module.__class__ for module in model.modules() if "Qwen2_5_VLVisionBlock" in module.__class__.__name__
)
return model_decoder_classes

model_decoder_classes.update(
module.__class__ for module in model.modules() if module.__class__ in decoder_layer_classes
)

return model_decoder_classes

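For intuition, the name-pattern filter above can be exercised in isolation. A toy sketch with stand-in classes (the real code iterates KVCacheTransform._module_mapping, intersects with the classes actually instantiated in the model, and special-cases EncoderWrapper models to include Qwen2_5_VLVisionBlock):

# Stand-ins for mapped QEff classes; names chosen to mirror the real ones.
class QEffLlamaDecoderLayer: ...
class QEffLlamaRMSNorm: ...
class QEffQwen2_5_VLVisionBlock: ...


DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
qeff_classes = [QEffLlamaDecoderLayer, QEffLlamaRMSNorm, QEffQwen2_5_VLVisionBlock]

decoder_layer_classes = {
    cls for cls in qeff_classes if any(p in cls.__name__ for p in DECODER_LAYER_PATTERNS)
}

# "RMSNorm" matches none of the patterns, so only the two layer classes survive.
assert decoder_layer_classes == {QEffLlamaDecoderLayer, QEffQwen2_5_VLVisionBlock}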
@@ -73,14 +73,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim):
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""

mrope_section = mrope_section * 2
cos = cos[position_ids]
sin = sin[position_ids]

cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)

cos = torch.cat([cos[0, ..., 0:32], cos[0, ..., 32:80], cos[0, ..., 80:128]], dim=-1).unsqueeze(0)
sin = torch.cat([sin[0, ..., 0:32], sin[0, ..., 32:80], sin[0, ..., 80:128]], dim=-1).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
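One observation on the replacement lines: the three slices 0:32, 32:80, and 80:128 are contiguous, so their concatenation reproduces the full 0:128 range in order; under that reading the expression reduces to selecting index 0 and restoring the leading dim. A quick check with dummy shapes (assumed here purely for illustration, with head_dim = 128):

import torch

cos = torch.randn(3, 10, 128)  # (variants, seq_len, head_dim), shapes illustrative

manual = torch.cat([cos[0, ..., 0:32], cos[0, ..., 32:80], cos[0, ..., 80:128]], dim=-1).unsqueeze(0)
direct = cos[0].unsqueeze(0)
assert torch.equal(manual, direct)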

@@ -872,6 +868,9 @@ def __init__(self, model):
self.model = model
self.model.vision_model = self.model.visual

def get_repeated_layers(self):
return self.model.visual.blocks[0].__class__

def forward(self, pixel_values, image_grid_thw):
image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
bs = image_grid_thw.shape[0]
@@ -887,6 +886,9 @@ def __init__(self, model):
self.model = model
self.language_model = self.model.model.language_model

def get_repeated_layers(self):
return QEffQwen2_5_VLDecoderLayer

def forward(
self,
input_ids,
17 changes: 13 additions & 4 deletions QEfficient/utils/export_utils.py
@@ -14,7 +14,6 @@

from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform
from QEfficient.transformers.cache_utils import InvalidIndexProvider
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils.cache import QEFF_HOME
from QEfficient.utils.hash_utils import create_export_hash
from QEfficient.utils.logging_utils import logger
@@ -164,18 +163,28 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs):
# Transform output names for subfunction compatibility
if "output_names" in kwargs:
kwargs["output_names"] = [
re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"]
re.sub("_RetainedState", "_InternalRetainedState", name)
if name.endswith("_RetainedState") and ("key" in name or "value" in name)
else name
for name in kwargs["output_names"]
]
else:
args = list(args)
args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]]
args[1] = [
re.sub("_RetainedState", "_InternalRetainedState", name)
if name.endswith("_RetainedState") and ("key" in name or "value" in name)
else name
for name in args[1]
]
args = tuple(args)

# Add subfunction-specific ONNX transforms
qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
qeff_model._onnx_transforms.append(CustomOpTransform)

# TODO: Handle this in the modelling class QEFFTransformersBase and remove it from here; refer to the diffusers implementation.
kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model)
kwargs["export_modules_as_functions"] = {qeff_model.model.get_repeated_layers()}
return args, kwargs
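The tightened predicate renames only the KV-cache retained-state outputs, where the old code renamed every _RetainedState output. A quick check with illustrative names (not taken from a real model's output list):

import re

names = ["past_key.0_RetainedState", "past_value.0_RetainedState", "vision_embeds_RetainedState", "logits"]
renamed = [
    re.sub("_RetainedState", "_InternalRetainedState", n)
    if n.endswith("_RetainedState") and ("key" in n or "value" in n)
    else n
    for n in names
]

# Only the key/value outputs change; the vision output and logits pass through.
assert renamed == [
    "past_key.0_InternalRetainedState",
    "past_value.0_InternalRetainedState",
    "vision_embeds_RetainedState",
    "logits",
]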


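For context, export_modules_as_functions is a stock torch.onnx.export argument (TorchScript exporter, opset 15 or newer): every instance of a listed nn.Module subclass is emitted once as an ONNX local function and called repeatedly, which is what makes the repeated-layer class reported by get_repeated_layers useful here. A self-contained sketch, independent of QEfficient:

import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(3)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


torch.onnx.export(
    Stack(),
    torch.randn(1, 8),
    "stack.onnx",
    opset_version=17,
    export_modules_as_functions={Block},  # each Block call reuses one ONNX function
)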
10 changes: 8 additions & 2 deletions QEfficient/utils/torch_patches.py
@@ -7,6 +7,8 @@

"""Monkey patches for torch.onnx.utils to fix ONNX export issues."""

import warnings

import torch
import torch.onnx.utils as onnx_utils
from torch import _C
@@ -37,9 +39,13 @@ def _track_module_attributes_forward_hook(module, input, output):
if hasattr(module, attr_name):
onnx_attrs = getattr(module, attr_name)
delattr(module, attr_name)

# FIX: use empty dict to avoid type mismatch
onnx_attrs = {}
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
try:
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
except Exception as e:
warnings.warn(f"Failed to track ONNX scope attributes: {e}. Skipping this step.")

for m in model.modules():
m.register_forward_hook(_track_module_attributes_forward_hook)
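The patched hook now follows a guard-and-warn pattern around the private JIT pass: attempt the call and degrade gracefully if it rejects the attribute types, since scope attributes are an export nicety rather than a correctness requirement. The same pattern in isolation, with a stand-in for the private _C call:

import warnings


def _unstable_pass(graph, attrs):
    # Stand-in for _C._jit_pass_onnx_track_scope_attributes, which can raise
    # on attribute types it does not understand.
    raise RuntimeError("unsupported attribute type")


def track_scope_attributes(graph, attrs):
    try:
        _unstable_pass(graph, attrs)
    except Exception as e:
        # Warn and continue rather than abort the whole ONNX export.
        warnings.warn(f"Failed to track ONNX scope attributes: {e}. Skipping this step.")


track_scope_attributes(graph=None, attrs={})  # emits the warning instead of raising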