Merged

21 commits
409da24
Extend on-device sampling support for dual QPC VLMs
quic-xiyushi Oct 23, 2025
e06e175
Fix random_numbers shape
quic-xiyushi Oct 30, 2025
3e242ce
Update example with new random sampling logic
quic-xiyushi Oct 30, 2025
1a01d57
Update to align with recent VLM CB changes
quic-xiyushi Nov 11, 2025
30d6061
Update tests with new random sampling logic
Nov 11, 2025
d02d04d
Merge remote-tracking branch 'origin/main' into HEAD
quic-xiyushi Nov 19, 2025
7cf106e
Refactor
quic-xiyushi Nov 19, 2025
45aed11
Add unit tests
quic-xiyushi Nov 20, 2025
6273ab5
Clean up
quic-xiyushi Nov 20, 2025
ef9ae14
Merge remote-tracking branch 'origin/main' into HEAD
quic-xiyushi Nov 20, 2025
3789d5a
Update test_sampler.py
quic-xiyushi Nov 20, 2025
5e2afb7
Fix hash for VLM's language decoder to include qaic_config
quic-xiyushi Nov 21, 2025
df06617
Merge remote-tracking branch 'origin/main' into HEAD
quic-xiyushi Nov 25, 2025
10990a9
Fix bug in getting vocab_size and missing ccl in forward
quic-xiyushi Nov 25, 2025
98cfadf
Merge branch 'main' into on-device-sampling-vlm
quic-mamta Dec 10, 2025
a60e7ce
Merge branch 'main' into on-device-sampling-vlm
quic-xiyushi Dec 16, 2025
b22af54
Support prefix-caching with on-device sampling
quic-xiyushi Dec 16, 2025
2533262
Modify tests to use internvl 1b for quicker CI
quic-xiyushi Dec 16, 2025
8698651
Merge branch 'main' into on-device-sampling-vlm
quic-xiyushi Dec 16, 2025
86aaad2
Fix compilation error on Llama3.1 8B due to changes in presence penalty
quic-xiyushi Dec 16, 2025
a2d4fb4
Update tests
quic-xiyushi Dec 16, 2025
13 changes: 13 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -36,6 +36,7 @@
write_io_files,
)
from QEfficient.utils import LRUCache
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger


@@ -313,6 +314,13 @@ def _execute_chunked_prefill(
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

if self.include_sampler:
for op in Constants.SAMPLER_OPS:
if decode_batch_id is not None:
lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
lang_inputs[op] = self.sampling_params[op]

for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -338,6 +346,11 @@

chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]

if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
chunk_inputs[op] = lang_inputs[op]

outputs = self._session.run(chunk_inputs)

if "image_idx_output" in outputs:
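For orientation, a minimal usage sketch of the feature this PR wires up: on-device sampling for a dual-QPC (kv_offload) VLM. The qaic_config keys come from this diff; the model ID, compile arguments, and the max_top_k_ids value are illustrative assumptions, not taken from the PR.

# Sketch only: enable on-device sampling on the language decoder of a dual-QPC VLM.
from QEfficient import QEFFAutoModelForImageTextToText

qaic_config = {
    "include_sampler": True,  # ask SamplerTransform to append sampler nodes
    "max_top_k_ids": 512,     # assumed cap for the top-k buffer
}

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "OpenGVLab/InternVL2_5-1B",  # assumed model ID; this PR's tests use a 1B InternVL
    kv_offload=True,             # dual QPC: separate vision and language graphs
    qaic_config=qaic_config,
)
model.compile(prefill_seq_len=128, ctx_len=4096, num_devices=1, num_cores=16)  # illustrative values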
134 changes: 43 additions & 91 deletions QEfficient/transformers/models/modeling_auto.py
@@ -9,7 +9,7 @@
import warnings
from pathlib import Path
from time import perf_counter
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

import numpy as np
import torch
@@ -70,6 +70,7 @@
)
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs


class QEFFTransformersBase(QEFFBaseModel):
Expand Down Expand Up @@ -719,7 +720,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, qaic_config, **kwargs):
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.

@@ -733,7 +734,7 @@ def __init__(self, model, qaic_config, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
super().__init__(model, **kwargs)
super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model = model.get_qeff_language_decoder()
self.model.qaic_config = qaic_config
self.hash_params["qeff_auto_class"] = self.__class__.__name__
@@ -871,16 +872,16 @@ def __init__(
----------
model : nn.Module
The full HuggingFace multimodal model.
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
**kwargs :
Additional keyword arguments. `full_batch_size` is not supported here.

Raises
------
NotImplementedError
If `full_batch_size` is provided.
Additional keyword arguments.
"""
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
continuous_batching = True
warnings.warn(
"full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
)
self.model = model
self.config = model.config

@@ -892,6 +893,11 @@
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
# previous transform function.
self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
@@ -1002,6 +1008,19 @@ def export(
kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
)
output_names = self.model.get_output_names(kv_offload=True)
if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
"include_sampler", False
):
logits_index = output_names["lang"].index("logits")
output_names["lang"][logits_index] = "next_tokens"
inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
example_inputs=inputs["lang"],
output_names=output_names["lang"],
dynamic_axes=dynamic_axes["lang"],
continuous_batching=self.continuous_batching,
vocab_size=self.model.language_model.config.vocab_size,
qaic_config=self.lang_model.model.qaic_config,
)

self.vision_model.export(
inputs["vision"],
@@ -1234,6 +1253,7 @@ def generate(
generation_len: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
**kwargs,
) -> Union[torch.Tensor, np.ndarray]:
"""
Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1294,6 +1314,7 @@
comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
image_height=image_height,
image_width=image_width,
**kwargs,
)

# Call generate method
@@ -1576,10 +1597,15 @@ def __init__(
Raises
------
NotImplementedError
If `full_batch_size` is provided.
If `full_batch_size` is provided or `include_sampler` is True.
"""
if kwargs.pop("full_batch_size", None):
warnings.warn(
"full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
)
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
if qaic_config is not None and qaic_config.pop("include_sampler", False):
raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)

self.model.qaic_config = qaic_config
@@ -2196,6 +2222,8 @@ def from_pretrained(
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.

@@ -2659,10 +2687,13 @@ def export(
dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}

if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
example_inputs=example_inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
continuous_batching=self.continuous_batching,
vocab_size=self.model.config.vocab_size,
qaic_config=self.model.qaic_config,
)
return self._export(
example_inputs,
@@ -2674,85 +2705,6 @@
prefill_only=prefill_only,
)

def get_sampling_inputs_and_outputs(
self,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
):
"""
Updates the example inputs, output names, and dynamic axes to include
parameters relevant for on-device sampling during ONNX export.

Parameters
----------
example_inputs : Dict[str, torch.Tensor]
Current dictionary of example inputs.
output_names : List[str]
Current list of output names.
dynamic_axes : Dict[str, Dict[int, str]]
Current dictionary of dynamic axes configurations.

Returns
-------
Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
Updated example inputs, output names, and dynamic axes including
sampling-related parameters.
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
)
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
)
dynamic_axes["temperatures"] = {0: "batch_size"}

max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
dynamic_axes["top_ps"] = {0: "batch_size"}

example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes

def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
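For reference, a sketch of how per-request sampling tensors might be supplied at generation time. The tensor names mirror the sampler inputs listed in the removed get_sampling_inputs_and_outputs above (now produced by QEfficient.utils.sampler_utils); the exact generate() call shape beyond the **kwargs added in this diff, and the sampling_params keyword itself, are assumptions.

import numpy as np

batch_size = 1  # one row of sampling parameters per sequence in the batch
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.2, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.5, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.7, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 50, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.9, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.05, dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}

# Assumed call: the new **kwargs on generate() is expected to carry these through to
# VisionLanguageGeneration, which slices self.sampling_params per decode batch id during prefill.
outputs = model.generate(inputs=processor_inputs, generation_len=32, sampling_params=sampling_params)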
4 changes: 4 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -301,6 +301,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -404,6 +405,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -757,10 +759,12 @@ class SamplerTransform:
QEffGPTJForCausalLM,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
QEffInternDecoderWrapper,
Contributor

Does this mean we are enabling sampling only for intern model?
Will other VLMs also be supported?

Contributor Author

Other VLMs are also supposed to be supported. But currently only InternVL and Qwen VL 2.5 have been tested.

QEffLlamaForCausalLM,
QEffMptForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen_2_5_vl_DecoderWrapper,
}

@classmethod
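As a quick sanity check that SamplerTransform took effect on the exported language decoder, here is a sketch under assumptions: the artifact path is hypothetical, while the output names are the ones used in this diff.

import onnx

# Hypothetical path to the exported language-decoder ONNX produced by export().
decoder = onnx.load("qeff_cache/<model>/onnx/<model>_lang.onnx")
output_names = [o.name for o in decoder.graph.output]

# With include_sampler=True, "logits" is replaced by "next_tokens", and the penalty
# buffers are exposed as retained state so they persist across invocations.
assert "next_tokens" in output_names and "logits" not in output_names
assert "past_repetition_penalty_buffer_RetainedState" in output_names
assert "past_presence_penalty_buffer_RetainedState" in output_names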