diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 2d8f72e0a..b507363c3 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -6,7 +6,17 @@ # ----------------------------------------------------------------------------- import os -import warnings + +# ----------------------------------------------------------------------------- # +# For faster downloads via hf_transfer +# This code is put above import statements as this needs to be executed before +# hf_transfer is imported (will happen on line 15 via leading imports) +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# DO NOT ADD ANY CODE ABOVE THIS LINE +# Please contact maintainers if you must edit this file above this line. +# ----------------------------------------------------------------------------- # +# Placeholder for all non-transformer models registered in QEfficient +import warnings # noqa: I001 import QEfficient.utils.model_registery # noqa: F401 from QEfficient.base import ( @@ -26,6 +36,10 @@ from QEfficient.utils import custom_format_warning from QEfficient.utils.logging_utils import logger +# custom warning for the better logging experience +warnings.formatwarning = custom_format_warning + + # Users can use QEfficient.export for exporting models to ONNX export = qualcomm_efficient_converter __all__ = [ @@ -42,14 +56,7 @@ "QEFFCommonLoader", "QEffFluxPipeline", ] -# For faster downloads via hf_transfer -# This code is put above import statements as this needs to be executed before -# hf_transfer is imported (will happen on line 15 via leading imports) -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -# Placeholder for all non-transformer models registered in QEfficient -# custom warning for the better logging experience -warnings.formatwarning = custom_format_warning # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ea347016b..2c98a83f3 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -60,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -204,10 +205,11 @@ def _export( example_inputs: Dict[str, torch.Tensor], output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]], - export_kwargs: Optional[Dict[str, any]] = None, onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + prefill_only: Optional[bool] = False, + **export_kwargs, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -232,11 +234,16 @@ def _export( instance using from_pretrained() for re-export. 
""" + # TODO: Hack for retain_full_kv, handle this outside + export_kwargs.pop("retain_full_kv", None) onnx_path = export_dir / f"{self.model_name}.onnx" # Return early if ONNX already exists if onnx_path.is_file(): - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path # check if the model is in meta state or weights are offloaded @@ -272,9 +279,6 @@ def _export( input_names.append(param) try: - # Export to ONNX - export_kwargs = {} if export_kwargs is None else export_kwargs - torch.onnx.export( self.model, (example_inputs,), @@ -318,9 +322,42 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path + def get_onnx_path( + self, + prefill_only: Optional[bool] = False, + enable_chunking: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + use_onnx_subfunctions: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + kwargs = { + "offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions, + "retain_full_kv": retain_full_kv, + } + if prefill_only: + if self.prefill_onnx_path is None: + kwargs.update( + { + "prefill_only": prefill_only, + "prefill_seq_len": specializations[0].get("seq_len"), + "enable_chunking": enable_chunking, + } + ) + self.export(**kwargs) + return self.prefill_onnx_path + else: + if self.onnx_path is None: + self.export(**kwargs) + return self.onnx_path + @dump_qconfig def _compile( self, @@ -335,6 +372,10 @@ def _compile( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, + prefill_only: Optional[str] = None, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -360,11 +401,18 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ - - if onnx_path is None and self.onnx_path is None: - self.export(use_onnx_subfunctions=use_onnx_subfunctions) - - onnx_path = Path(onnx_path or self.onnx_path) + onnx_path = Path( + onnx_path + if onnx_path + else self.get_onnx_path( + prefill_only, + enable_chunking, + specializations, + offload_pt_weights, + use_onnx_subfunctions, + retain_full_kv, + ) + ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): @@ -426,6 +474,7 @@ def _compile( "mdp_ts_num_devices": mdp_ts_num_devices, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, + "prefill_only": prefill_only, } compile_hash = hash_dict_params(compile_hash_params) @@ -465,6 +514,16 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + if use_onnx_subfunctions: + + class FeatureNotAvailableError(Exception): + pass + + exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}' + raise FeatureNotAvailableError( + "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model." 
+ + f"\nRun following command manually with assert compiler:\n{exec_command}" + ) try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: @@ -485,5 +544,4 @@ def _compile( logger.info("Hashed parameters exported successfully.") self.qpc_path = qpc_path - return qpc_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index bdf7bf677..16697cec9 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -95,12 +95,12 @@ class CustomOpTransform(BaseOnnxTransform): "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), "CtxGatherFunc": (CtxGatherFunc, CtxGather), "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), - "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), - "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), + "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), + "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), } @classmethod diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c7dc8639a..7b15effe7 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function): def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 8a06bc2b1..c15b60810 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function): def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 41a3d29f7..6d9243fdc 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -102,7 +102,7 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - export_kwargs: Dict = None, + export_kwargs: Dict = {}, ) -> str: """ Export the text encoder model to ONNX format. @@ -122,7 +122,7 @@ def export( output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, - export_kwargs=export_kwargs, + **export_kwargs, ) def compile(self, specializations: List[Dict], **compiler_options) -> None: @@ -179,7 +179,7 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - export_kwargs: Dict = None, + export_kwargs: Dict = {}, ) -> str: """ Export the UNet model to ONNX format. 
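Note on the CtxGather changes above: during ONNX export, masked cache positions carry an INT32-max sentinel (or can exceed the cache length in the continuous-batching variant), so the new torch.where guard remaps them to index 0 before the advanced-indexing gather, and the gathered values are zeroed out afterwards (see the v_out masking in cache_utils later in this diff). A minimal standalone sketch of that pattern, using toy shapes and names that are not part of the codebase:

import torch

INVALID = torch.iinfo(torch.int32).max  # sentinel used for masked positions during ONNX export

data = torch.arange(2 * 1 * 4 * 3, dtype=torch.float32).view(2, 1, 4, 3)  # [batch, heads, ctx_len, head_dim]
ctx_indices = torch.tensor([[[0, 1, INVALID]], [[2, INVALID, INVALID]]])  # [batch, 1, k]

invalid_mask = ctx_indices == INVALID
safe_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)  # keep the gather in range

batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
gathered = data[batch_indices, head_indices, safe_indices]  # [batch, heads, k, head_dim]

# Positions that pointed at the sentinel are zeroed afterwards, mirroring the v_out masking in cache_utils.
gathered = torch.where(invalid_mask.unsqueeze(-1), torch.zeros((), dtype=gathered.dtype), gathered)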
@@ -199,7 +199,7 @@ def export( output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, - export_kwargs=export_kwargs, + **export_kwargs, ) def compile(self, specializations: List[Dict], **compiler_options) -> None: @@ -292,7 +292,7 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - export_kwargs: Dict = None, + export_kwargs: Dict = {}, ) -> str: """ Export the VAE model to ONNX format. @@ -312,7 +312,7 @@ def export( output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, - export_kwargs=export_kwargs, + **export_kwargs, ) def compile(self, specializations: List[Dict], **compiler_options) -> None: @@ -438,7 +438,7 @@ def export( output_names: List[str], dynamic_axes: Dict, export_dir: str = None, - export_kwargs: Dict = None, + export_kwargs: Dict = {}, use_onnx_subfunctions: bool = False, ) -> str: """ @@ -466,8 +466,8 @@ def export( output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, - export_kwargs=export_kwargs, offload_pt_weights=False, # As weights are needed with AdaLN changes + **export_kwargs, ) def compile(self, specializations: List[Dict], **compiler_options) -> None: diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index e69aebb2b..6c7173072 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with the active adapter to ONNX format. @@ -291,10 +291,10 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = example_inputs, output_names, dynamic_axes, - export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights + do_constant_folding=False, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def compile( diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 64fa3f61c..8ff8335f5 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
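The QEfficient/peft/auto.py hunk above shows the practical effect of dropping the export_kwargs dict in favor of **export_kwargs on _export: torch.onnx.export options such as do_constant_folding=False now travel as plain keyword arguments. A small sketch of that forwarding pattern, with simplified, hypothetical names:

from typing import Dict, List, Optional

def _export_sketch(
    example_inputs: Dict,
    output_names: List[str],
    dynamic_axes: Dict,
    export_dir: Optional[str] = None,
    **export_kwargs,
) -> str:
    # Anything beyond the named parameters (e.g. do_constant_folding=False) collects in
    # export_kwargs and would be splatted into torch.onnx.export(..., **export_kwargs).
    return f"would call torch.onnx.export with {export_kwargs}"

# Old call style: _export(..., export_kwargs={"do_constant_folding": False})
# New call style:
_export_sketch({}, [], {}, do_constant_folding=False)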
@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def generate( diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 62cc71a4c..faadaba6b 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls): """ if torch.onnx.is_in_onnx_export(): if cls.SUBFUNC_ENABLED: + # TODO: should not return 0 remove this if condition, it can hurt perf return 0 else: return torch.iinfo(torch.int32).max @@ -681,6 +682,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache + def write_only( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + _, _, ctx_len, _ = self.key_cache[layer_idx].shape + if is_sliding_layer: + kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1) + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + return k_out, v_out + def update( self, key_states: torch.Tensor, @@ -747,3 +779,92 @@ def update( v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + + def full_cache_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Gather + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out + + def sliding_window_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + sliding_window_len = cache_kwargs.get("sliding_window") + + # Gather + ctx_len = position_ids.shape[1] + sliding_window_len + ctx_indices = torch.arange(ctx_len)[None, None, ...] + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0) + ctx_indices += add_idx + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 5337b44f5..47059d8dc 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -188,6 +188,9 @@ # This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} +# This is for supporting different modelling classes specially written for prefill-only model +SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"} + # Define a transformers layers to QEff layers dictionary # While onboarding new models make sure to add the new layer maps to this dictionary. 
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = { diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 84552aff4..3efe890b8 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +import math +import os from typing import Callable, Optional, Union import torch @@ -30,8 +32,8 @@ from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -from QEfficient.utils import constants from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger class QEffGptOssExperts(GptOssExperts): @@ -42,8 +44,8 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) -class QEffGptOssMLP(GptOssMLP): - def alt_forward(self, hidden: torch.Tensor): +class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -78,7 +80,62 @@ def alt_forward(self, hidden: torch.Tensor): up = (hidden @ W_u) + b_u # [T, I] # Apply GptOss activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +class QEffPrefillOnlyGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + if os.environ.get("NUM_FFN_BLOCKS", None) is not None: + return self.blocked_ffn_forward(hidden) + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, 
max=self.experts.limit) # GLU activation @@ -88,6 +145,165 @@ def alt_forward(self, hidden: torch.Tensor): # Down projection down_out = (intermediate @ W_d) + b_d # [T, H] + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + # Gate and Up projections + gate = (tgb @ W_g) + b_g # [T, I] + up = (tgb @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out_block = (intermediate @ W_d) + b_d # [T, H] + + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in 
range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + + wg_col_shape = W_g.shape[1] + wg_num_blocks = math.ceil(wg_col_shape / 128) + last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128 + + intermediates = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:] + cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:] + else: + cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] + cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] + + cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) + cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) + cur_intermediate = (cur_up + 1) * cur_glu + intermediates.append(cur_intermediate) + + intermediate = torch.cat(intermediates, dim=-1) + + downs = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:]) + else: + downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128]) + + down_out_block = torch.cat(downs, dim=1) + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + # Apply routing weights and accumulate masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) expert_out += masked_down @@ -95,6 +311,8 @@ def alt_forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + +class QEffGptOssMLP(GptOssMLP): # ------------------- Gather based, weights as activation approach --------------- def forward_weights_as_activation(self, hidden_states): bs, seq_len, _ = hidden_states.shape @@ -142,7 +360,6 @@ def forward_weights_as_activation(self, hidden_states): # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- def forward(self, hidden_states): - # print("Seperate Split, Up, Gate Projections") bs, seq_len, _ = hidden_states.shape hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) @@ -172,7 +389,7 @@ def forward(self, hidden_states): up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) # Apply activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, max=self.experts.limit) # GLU activation @@ -404,6 +621,283 @@ def eager_attention_forward( return attn_output, 
attn_weights +def eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + q_block = query[:, :, qi : qi + real_q_len, :] + scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, value_states) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +def opt_eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + # Calculate block size (last block should be handled with remainder) + + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + if block_idx == 0: + kv_start_idx = 0 + else: + kv_start_idx = qi - 128 + + q_block = query[:, :, qi : qi + real_q_len, :] + if kwargs.get("sliding_window"): + k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :] + v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :] + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len] + else: + k_block = key_states + v_block = value_states + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + 
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, v_block) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": self.sliding_window, + } + if self.sliding_window is not None: + key_states, value_states = past_key_value.sliding_window_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + key_states, value_states = past_key_value.full_cache_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if self.sliding_window is not None: + attention_mask = sliding_mask + # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) + ctx_len = position_ids.shape[1] + self.sliding_window + ctx_indices = torch.arange(ctx_len) + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0) + # start_idx = torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0) + # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window) + ctx_indices += add_idx + attention_mask = attention_mask[:, :, :, ctx_indices] + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) 
+ + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffPrefillOnlyGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if self.sliding_window is not None: + sliding_window_len = past_key_value.sliding_window_len + short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2]) + read_idx = short_read_idx + torch.where( + position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 + ) + # This is a trick to export with seq_len position_ids.max(), 0, read_idx) + k_cache = key_states[:, :, read_idx, :] + v_cache = value_states[:, :, read_idx, :] + else: + k_cache, v_cache = key_states, value_states + _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + if os.environ.get("ENABLE_OPT_SWA", "0") == "1": + attention_interface: Callable = opt_eager_attention_forward_blocked + else: + attention_interface: Callable = eager_attention_forward_blocked + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -429,8 +923,9 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -511,7 +1006,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores - # alth, _ = self.mlp.alt_forward(hidden_states) hidden_states = hidden_states.reshape(residual.shape) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -525,6 +1019,97 @@ def forward( return outputs +class QEffPrefillOnlyGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.max_cache_len, + sliding_window=self.config.sliding_window, + ) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + 
batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + class QEffGptOssModel(GptOssModel): def forward( self, @@ -578,7 +1163,6 @@ def forward( ) hidden_states = inputs_embeds - # position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -708,15 +1292,15 @@ def forward( router_logits=outputs.router_logits, ) - def get_pkv_dynamic_axes( - self, - ): + def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False): pkv_dynamic_axes = [] for layer_type in self.config.layer_types: - if layer_type == "sliding_attention": - pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) - elif layer_type == "full_attention": - pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + if layer_type == "sliding_attention" and not retain_full_kv: + pkv_dynamic_axes.append( + {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} + ) + else: + pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}) return pkv_dynamic_axes def get_specializations( @@ -724,10 +1308,14 @@ def get_specializations( batch_size: int, prefill_seq_len: int, ctx_len: int, + **kwargs, ): batch_size = batch_size if batch_size else 1 - prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN - ctx_len = ctx_len if ctx_len else constants.CTX_LEN + if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len: + ctx_len = prefill_seq_len + logger.warning( + f"overriding ctx_len={prefill_seq_len}, currently we don't support ctx_len different than prefill_seq_len for prefill_only model" + ) specializations = [ { diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 16a809c96..008147c03 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import os import warnings from pathlib import Path from time import perf_counter @@ -37,13 +38,20 @@ get_compilation_dims, ) from QEfficient.generation.vlm_generation import VisionLanguageGeneration -from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, +) from QEfficient.transformers.models.pytorch_transforms import ( BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyChunkedTransform, + PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -301,7 +309,7 @@ def get_model_config(self) -> 
dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -338,7 +346,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -588,15 +596,7 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the vision encoder component to ONNX format. @@ -626,7 +626,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -741,15 +741,7 @@ def __init__(self, model, qaic_config, **kwargs): if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the language decoder component to ONNX format. 
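For reference, the get_pkv_dynamic_axes change in the GPT-OSS modeling code earlier in this diff now also picks the batch axis name and lets retain_full_kv keep the ctx_len axis on sliding-attention layers. Restated as a standalone sketch (helper name is illustrative only):

from typing import Dict, List

def pkv_dynamic_axes_sketch(
    layer_types: List[str], retain_full_kv: bool = False, continuous_batching: bool = False
) -> List[Dict[int, str]]:
    batch_axis = "full_batch_size" if continuous_batching else "batch_size"
    axes = []
    for layer_type in layer_types:
        if layer_type == "sliding_attention" and not retain_full_kv:
            axes.append({0: batch_axis, 2: "sliding_window"})
        else:
            axes.append({0: batch_axis, 2: "ctx_len"})
    return axes

# e.g. alternating sliding/full layers with the full KV retained:
print(pkv_dynamic_axes_sketch(["sliding_attention", "full_attention"], retain_full_kv=True))
# [{0: 'batch_size', 2: 'ctx_len'}, {0: 'batch_size', 2: 'ctx_len'}]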
@@ -779,7 +771,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -2284,11 +2276,30 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + def __init__( self, model: nn.Module, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, **kwargs, ): """ @@ -2336,6 +2347,7 @@ def __init__( # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + setattr(model.config, "max_seq_len_cached", max_seq_len_cached) super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2348,6 +2360,7 @@ def __init__( if qaic_config: self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.hash_params["max_seq_len_cached"] = max_seq_len_cached # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -2372,6 +2385,7 @@ def from_pretrained( pretrained_model_name_or_path, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, *args, **kwargs, ): @@ -2435,7 +2449,6 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class - if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( model, @@ -2450,6 +2463,7 @@ def from_pretrained( continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + max_seq_len_cached=max_seq_len_cached, **kwargs, ) @@ -2465,7 +2479,56 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: + def get_seq_len_and_handle_specialized_prefill_model( + self, prefill_seq_len: Optional[int] = None, enable_chunking=False + ) -> int: + self.hash_params["prefill_only"] = True + if enable_chunking: + self.hash_params["chunking"] = True + return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) + if num_q_blocks is None: + block_size = 128 + if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. 
" + f"Or set `NUM_Q_BLOCKS` ENV variable" + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + num_q_blocks = prefill_seq_len // block_size + logger.warning( + f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override" + ) + os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) + num_q_blocks = int(num_q_blocks) + + num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) + num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks + min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks + if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0: + raise ValueError( + f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but," + "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." + ) + + self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks + self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks + self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0") + return ( + min_seq_len + if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + ) + + def export( + self, + export_dir: Optional[str] = None, + prefill_only: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + **kwargs, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2491,6 +2554,33 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) + enable_chunking = kwargs.get("enable_chunking", False) + if prefill_only: + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" 
+ ) + self.prefill(enable=True, enable_chunking=enable_chunking) + self.hash_params.pop("retain_full_kv", None) + seq_len = ( + self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + ) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH + else seq_len + ) + kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len + else: + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("NUM_Q_BLOCKS", None) + self.hash_params.pop("NUM_FFN_BLOCKS", None) + self.hash_params.pop("ENABLE_OPT_SWA", None) + self.hash_params.pop("chunking", None) + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + self.model.config.sliding_window + self.hash_params["retain_full_kv"] = True + example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), @@ -2539,7 +2629,13 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = else: # HACK: create common function for this including above if condition code pkv_dynamic_axes = ( - self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + self.model.get_pkv_dynamic_axes( + retain_full_kv=kwargs.get("retain_full_kv", False) + or (prefill_only and kwargs.get("enable_chunking", False)), + continuous_batching=self.continuous_batching, + ) + if hasattr(self.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes ) pkv_dynamic_axes = ( [pkv_dynamic_axes] * self.model.config.num_hidden_layers @@ -2548,7 +2644,6 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = ) for i in range(self.num_layers): - pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size" for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] @@ -2569,14 +2664,14 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names=output_names, dynamic_axes=dynamic_axes, ) - return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), offload_pt_weights=kwargs.get("offload_pt_weights", True), + prefill_only=prefill_only, ) def get_sampling_inputs_and_outputs( @@ -2666,6 +2761,7 @@ def build_prefill_specialization( batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the prefill phase. @@ -2688,11 +2784,17 @@ def build_prefill_specialization( Dict[str, Union[int, str]] A dictionary defining the prefill specialization. 
""" + if prefill_seq_len == 1 and self.continuous_batching: + exec_batch_size = full_batch_size + else: + exec_batch_size = 1 if self.continuous_batching else batch_size + if hasattr(self.model, "get_specializations"): spec = self.model.get_specializations( - batch_size=1 if self.continuous_batching else batch_size, + batch_size=exec_batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + **kwargs, )[0] else: spec = { @@ -2720,6 +2822,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the decode phase. @@ -2790,6 +2893,9 @@ def compile( num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -2870,6 +2976,20 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + if prefill_only is None or not prefill_only: + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) + else: + if self.continuous_batching: + if not isinstance(kv_cache_batch_size, int): + raise ValueError( + "Please pass valid integer for kv_cache_batch_size as continuous_batching is enabled for prefill-only model" + ) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled: @@ -2907,15 +3027,6 @@ def compile( if self.is_tlm: num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) - if self.continuous_batching and full_batch_size is None: - raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") - - if kv_cache_batch_size and not full_batch_size: - raise ValueError( - "KV caching requires continuous batching. Please set `full_batch_size` and " - "enable `continuous_batching=True` in `from_pretrained`." 
- ) - if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) @@ -2924,15 +3035,23 @@ def compile( ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") + if kv_cache_batch_size and prefill_only is not None and prefill_only: + logger.warning( + "kv_cache_batch_size will be ignored as prefill_only is set to True unless this is GPTOSS model" + ) + # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: + # TODO: we are handling decode-only case inside prefill call which is utterly mis-leading if self.comp_ctx_lengths_prefill is not None: # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization for i in range(0, len(self.comp_ctx_lengths_prefill)): + if prefill_only or enable_chunking: + raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") specializations.append( self.build_prefill_specialization( prefill_seq_len=prefill_seq_len, @@ -2952,6 +3071,8 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + prefill_only=prefill_only, + enable_chunking=enable_chunking, ) ) @@ -2979,6 +3100,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, ) if decode_spec: specializations.append(decode_spec) @@ -2991,7 +3113,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -3006,6 +3127,10 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, + retain_full_kv=retain_full_kv, **compiler_options, ) @@ -3197,7 +3322,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -3225,7 +3350,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -3573,7 +3698,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. 
@@ -3601,7 +3726,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 07b9fe7e1..4ba6641cf 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -265,6 +265,11 @@ QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, + QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyChunkedGptOssMLP, + QEffPrefillOnlyGptOssAttention, + QEffPrefillOnlyGptOssMLP, + QEffPrefillOnlyGptOssModel, ) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, @@ -642,6 +647,39 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class PrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + } + + +class PrefillOnlyChunkedTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, + } + + +class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, + QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + } + + +class RevertPrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()}, + **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, + } + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index dfadc00ef..dc2308e99 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -5,6 +5,6 @@ # # ----------------------------------------------------------------------------- -from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers -__all__ = ["replace_transformers_quantizers"] +__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"] diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index eea92a490..638f55921 100644 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import copy import inspect import re import warnings @@ -40,20 +41,19 @@ def export_wrapper(func): """ def wrapper(self, *args, **kwargs): - # 1. Prepare export directory + # 1. 
Setup ONNX subfunctions if requested + if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False): + args, kwargs = _setup_onnx_subfunctions(self, args, kwargs) + + # 2. Prepare export directory export_dir = _prepare_export_directory(self, kwargs) - # 2. Generate hash and finalize export directory path + # 3. Generate hash and finalize export directory path export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func) export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir self.export_hash = export_hash - # 3. Setup ONNX subfunctions if requested - # TODO: No need of this variable, if export_kwargs contains classes (refer diffusers) - if use_onnx_subfunctions := kwargs.get("use_onnx_subfunctions", False): - _setup_onnx_subfunctions(self, kwargs) - # 4. Execute the actual export onnx_path = func(self, *args, **kwargs) @@ -101,9 +101,6 @@ def _generate_export_hash(qeff_model, args, kwargs, func): Returns: Tuple of (export_hash: str, filtered_hash_params: dict) """ - # Extract use_onnx_subfunctions before binding (it's used by wrapper, not _export) - use_onnx_subfunctions = kwargs.pop("use_onnx_subfunctions", False) - # Extract function signature original_sig = inspect.signature(func) params = list(original_sig.parameters.values())[1:] # Skip 'self' @@ -116,7 +113,6 @@ def _generate_export_hash(qeff_model, args, kwargs, func): # Use the model's current configuration for hashing to ensure any post-load modifications are captured # TODO: Replace with get_model_config property of modeling classes and remove the if-else # Determine the config dict to use, preferring .to_diff_dict() if available - if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"): config_val = qeff_model.model.config.to_diff_dict() elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"): @@ -124,26 +120,25 @@ def _generate_export_hash(qeff_model, args, kwargs, func): else: config_val = qeff_model.model.config - qeff_model.hash_params.update( + copy_of_hash_params = copy.deepcopy(qeff_model.hash_params) + copy_of_hash_params.update( { "config": config_val, } ) - # Generate hash from relevant parameters export_hash, filtered_hash_params = create_export_hash( - model_params=qeff_model.hash_params, + model_params=copy_of_hash_params, output_names=all_args.get("output_names"), dynamic_axes=all_args.get("dynamic_axes"), export_kwargs=all_args.get("export_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), - use_onnx_subfunctions=use_onnx_subfunctions, ) return export_hash, filtered_hash_params -def _setup_onnx_subfunctions(qeff_model, kwargs): +def _setup_onnx_subfunctions(qeff_model, args, kwargs): """ Setup ONNX subfunction export environment. 
@@ -166,26 +161,22 @@ def _setup_onnx_subfunctions(qeff_model, kwargs): # Apply torch patches for subfunction support apply_torch_patches() InvalidIndexProvider.SUBFUNC_ENABLED = True - - # Store original state for restoration during cleanup - qeff_model._original_onnx_transforms = qeff_model._onnx_transforms.copy() - # Transform output names for subfunction compatibility if "output_names" in kwargs: kwargs["output_names"] = [ re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"] ] - + else: + args = list(args) + args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]] + args = tuple(args) # Add subfunction-specific ONNX transforms qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform) qeff_model._onnx_transforms.append(CustomOpTransform) - # Configure export to use modules as functions - export_kwargs = kwargs.get("export_kwargs", {}) - # TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation - export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) - kwargs["export_kwargs"] = export_kwargs + kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + return args, kwargs def _cleanup_onnx_subfunctions(qeff_model): @@ -205,18 +196,11 @@ def _cleanup_onnx_subfunctions(qeff_model): even if export fails. Errors during cleanup are logged but not re-raised to avoid masking the original exception. """ - try: - # Undo torch patches - undo_torch_patches() - InvalidIndexProvider.SUBFUNC_ENABLED = False - - # Restore original ONNX transforms - if hasattr(qeff_model, "_original_onnx_transforms"): - qeff_model._onnx_transforms = qeff_model._original_onnx_transforms - delattr(qeff_model, "_original_onnx_transforms") - - except Exception as e: - logger.error(f"Error during subfunction cleanup: {e}") + # Undo torch patches + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + qeff_model._onnx_transforms.remove(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.remove(CustomOpTransform) def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict): diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 68ccab0d4..10e6686d0 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -56,8 +56,6 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - if kwargs.get("use_onnx_subfunctions"): - export_params["use_onnx_subfunctions"] = True export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") @@ -69,5 +67,4 @@ def create_export_hash(**kwargs): export_hash_params.update(onnx_transform_kwargs) if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict): export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict() - return hash_dict_params(export_hash_params), export_hash_params diff --git a/examples/disagg_serving/README.md b/examples/disagg_serving/README.md new file mode 100644 index 000000000..fcf665357 --- /dev/null +++ b/examples/disagg_serving/README.md @@ -0,0 +1,31 @@ +# We should be using disaggragate serving for GPTOSS model for best performance + - GPT-OSS model has 128/4 for 120b and 32/4 ratio of total_experts/experts_per_tok + - We use read all experts only once 
always strategy in the prefill-only model + - And we treat weights as activations, meaning we read only the chosen experts, for the decode-only model + +# Prefill-only model +## Blocking: default behaviour when `prefill_only=True` in compile API + - NUM_Q_BLOCKS= sets the number of Q blocks in attention + - NUM_FFN_BLOCKS= sets the number of blocks in the FFN + - ENABLE_OPT_SWA="0" or "1" to enable/disable optimized SWA. When enabled, only the valid KVs for a given block are used in attention, reducing MACs + - prefix_caching is not supported with this mode + +## Chunking: pass `enable_chunking=True` and `prefill_only=True` in compile API + - Optimized SWA, i.e. reading only the valid KV as per the diagonal attention mask, is enabled by default for this version + - This model can be used for prefix_caching by passing `kv_cache_batch_size=` in the compile API + +# Decode-only model +## Retain sliding-window length of KV for sliding window layers: default behaviour when `prefill_seq_len=1` in compile API + - This reduces the amount of DDR used by the model + - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed +## Full KV for sliding window layers: pass `retain_full_kv=True` along with `prefill_seq_len=1` in compile API + - This uses more DDR, as we retain ctx_len KV even for sliding window layers, but only sliding-window-length KV is read in attention + - CB is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, strictly pass `full_batch_size=`, and optionally `kv_cache_batch_size=` if needed + - This is enabled for the multi-turn chat use case, where we run prefill -> decode and then use the combined prefill and decode cache to run prefill again, so we want to retain the full KV for sliding window layers + + +NOTE: +* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it +* The 120B model needs an NPI; there are two versions of the NPI, one with and one without subfunctions, and both are uploaded here. Pass it as `node_precision_info=` +* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise compilation times are too high. With this option the model is expected to export and then fail during compile, as it needs the assert SDK, so the user should run the compilation manually by pasting the command printed in the error + diff --git a/examples/disagg_serving/gpt_oss_disagg_mode.py b/examples/disagg_serving/gpt_oss_disagg_mode.py new file mode 100644 index 000000000..fd0d5b045 --- /dev/null +++ b/examples/disagg_serving/gpt_oss_disagg_mode.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves.
The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +all_outputs = [] +# Run prefill +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 256 +CTX_LEN = 256 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. +max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, +) +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + use_onnx_subfunctions=True, +) + +prefill_session = QAICInferenceSession(prefill_qpc_path) + +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() +qpc_out = prefill_session.run(inputs) +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = 
config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) diff --git a/examples/disagg_serving/subfunction_120b_npi.yaml b/examples/disagg_serving/subfunction_120b_npi.yaml new file mode 100644 index 000000000..762703d58 --- /dev/null +++ b/examples/disagg_serving/subfunction_120b_npi.yaml @@ -0,0 +1,27 @@ +FP32NodeInstanceNames: + - CustomRMSNorm_58 + - onnx::Shape_1033777 + - CustomRMSNorm_349 + - hidden.127 + - CustomRMSNorm_27448 + - onnx::Shape_1066066 + - CustomRMSNorm_27709 + - hidden.131 + - CustomRMSNorm_54808 + - onnx::Shape_878 + - CustomRMSNorm_55105 + - hidden + - hidden_states.259 + - Add_348 + - Add_347 + - onnx::Add_1034099 + - hidden_states.267 + - Add_27708 + - onnx::Add_1066358 + - Add_27707 + - hidden_states.3 + - Add_55104 + - onnx::Add_1209 + - Add_55103 + - /model/norm/CustomRMSNorm + - /model/norm/CustomRMSNorm_output_0 \ No newline at end of file diff --git a/examples/disagg_serving/without_subfunc_npi_120b.yaml b/examples/disagg_serving/without_subfunc_npi_120b.yaml new file mode 100644 index 000000000..ec6cf034f --- /dev/null +++ b/examples/disagg_serving/without_subfunc_npi_120b.yaml @@ -0,0 +1,148 @@ +FP32NodeInstanceNames: + - /model/layers.0/Add_1_output_0 + - /model/layers.0/Add_output_0 + - /model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/Add_1_output_0 + - /model/layers.1/Add_output_0 + - /model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/Add_1_output_0 + - /model/layers.10/Add_output_0 + - /model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/Add_1_output_0 + - /model/layers.11/Add_output_0 + - /model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/Add_1_output_0 + - /model/layers.12/Add_output_0 + - /model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/Add_1_output_0 + - /model/layers.13/Add_output_0 + - /model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - 
/model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/Add_1_output_0 + - /model/layers.14/Add_output_0 + - /model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/Add_1_output_0 + - /model/layers.15/Add_output_0 + - /model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/Add_1_output_0 + - /model/layers.16/Add_output_0 + - /model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/Add_1_output_0 + - /model/layers.17/Add_output_0 + - /model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/Add_1_output_0 + - /model/layers.18/Add_output_0 + - /model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/Add_1_output_0 + - /model/layers.19/Add_output_0 + - /model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/Add_1_output_0 + - /model/layers.2/Add_output_0 + - /model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/Add_1_output_0 + - /model/layers.20/Add_output_0 + - /model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/Add_1_output_0 + - /model/layers.21/Add_output_0 + - /model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/Add_1_output_0 + - /model/layers.22/Add_output_0 + - /model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/Add_1_output_0 + - /model/layers.23/Add_output_0 + - /model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/Add_1_output_0 + - /model/layers.24/Add_output_0 + - /model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/Add_1_output_0 + - /model/layers.25/Add_output_0 + - /model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/Add_1_output_0 + - /model/layers.26/Add_output_0 + - /model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/Add_1_output_0 + - /model/layers.27/Add_output_0 + - /model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/Add_1_output_0 + - /model/layers.28/Add_output_0 + - /model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/Add_1_output_0 + - /model/layers.29/Add_output_0 + - /model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/Add_1_output_0 + - /model/layers.3/Add_output_0 + - 
/model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/Add_1_output_0 + - /model/layers.30/Add_output_0 + - /model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/Add_1_output_0 + - /model/layers.31/Add_output_0 + - /model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/Add_1_output_0 + - /model/layers.32/Add_output_0 + - /model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/Add_1_output_0 + - /model/layers.33/Add_output_0 + - /model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/Add_1_output_0 + - /model/layers.34/Add_output_0 + - /model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/Add_1_output_0 + - /model/layers.35/Add_output_0 + - /model/norm/Add_output_0 + - /model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/Add_1_output_0 + - /model/layers.4/Add_output_0 + - /model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/Add_1_output_0 + - /model/layers.5/Add_output_0 + - /model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/Add_1_output_0 + - /model/layers.6/Add_output_0 + - /model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/Add_1_output_0 + - /model/layers.7/Add_output_0 + - /model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/Add_1_output_0 + - /model/layers.8/Add_output_0 + - /model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/Add_1_output_0 + - /model/layers.9/Add_output_0 + - /model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/norm/CustomRMSNorm_output_0 + \ No newline at end of file diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py new file mode 100644 index 000000000..363e2806c --- /dev/null +++ b/examples/gpt_oss_disagg_mode_with_chunking.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. 
Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +# Run prefill +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 128 * 3 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, +) + + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "provide path here" +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +decode_session = QAICInferenceSession(decode_qpc_path) +prefill_session = QAICInferenceSession(prefill_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, 
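+    # Seed decode with the token predicted by the final prefill chunk; position_ids continues one past the last prompt position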
+} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index d878076fa..8f95c1d98 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -42,7 +42,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml && junitparser merge tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 00a4216b7..46b33c60b 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( # export start = perf_counter() - qeff_model.export(export_dir=tmp_path) + onnx_path = qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash) @@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( assert export_time_1 < export_time_0 # test compile - qeff_model.compile(prefill_seq_len=32, ctx_len=64) + qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py index cc94467db..c3bb2f140 100644 --- a/tests/peft/test_peft_model.py +++ b/tests/peft/test_peft_model.py @@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path): _, lora_model = create_peft_model(base_config, adapter_config) qeff_model = QEffAutoPeftModelForCausalLM(lora_model) - qeff_model.export(tmp_path) 
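+    # Capture the exported ONNX path so the compile() call below reuses it instead of triggering a re-export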
+ onnx_path = qeff_model.export(tmp_path) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_0 = end - start @@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con ) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_1 = end - start assert compile_time_1 < 0.01 * compile_time_0 diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py new file mode 100644 index 000000000..6358940df --- /dev/null +++ b/tests/transformers/models/test_disagg_mode.py @@ -0,0 +1,192 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers + +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 + +prompt2 = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. 
+""" +prompt1 = "Once upon a time" + +prompts = [prompt1, prompt2] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + qeff_out = qeff_model.model(**inputs) + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + qpc_out = prefill_session.run(inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 + + +@pytest.mark.skip(reason="no way of currently testing this without the assert sdk") +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill_chunked(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = 
inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True, enable_chunking=True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = CTX_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + + qeff_out = qeff_model.model(**chunk_inputs) + inputs["past_key_values"] = qeff_out["past_key_values"] + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + ) + prefill_session = QAICInferenceSession(prefill_qpc_path) + prefill_session.skip_buffers( + [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] + ) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2 diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index 3eaaf0f69..72477d56a 100644 --- 
a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -14,10 +14,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.hash_utils import hash_dict_params -configs = [ +test_configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), @@ -36,30 +37,43 @@ ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] -configs = [ - AutoConfig.for_model( - model_name, - max_position_embeddings=max_position_embeddings, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - vocab_size=vocab_size, - **additional_params, - ) - for ( - model_name, - max_position_embeddings, - num_hidden_layers, - num_attention_heads, - hidden_size, - intermediate_size, - vocab_size, - additional_params, - ) in configs +test_prefill_only_specialized_models_configs = [ + ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}), ] + + +def get_auto_config_from_test_config(configs): + auto_configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs + ] + return auto_configs + + +configs = get_auto_config_from_test_config(test_configs) config_ids = [x.model_type for x in configs] +prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs) +prefill_only_config_ids = [x.model_type for x in prefill_only_configs] + model_kwargs = {"attn_implementation": "eager"} @@ -144,20 +158,21 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path): @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash_creation(config, cb, tmp_path): +def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) qeff_model = QEFFAutoModelForCausalLM(model, cb) - qeff_model.export(tmp_path) + qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc) hash_params = {} hash_params["config"] = qeff_model.model.config.to_diff_dict() hash_params["peft_config"] = None hash_params["applied_transform_names"] = qeff_model._transform_names() hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["max_seq_len_cached"] = None hash_params["qaic_config"] = None # Create parameters separately for hash creation - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -190,12 +205,12 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): ) 
output_names = [] output_names.append("logits") - + onnx_out_name_suffix = "InternalRetainedState" if subfunc else "RetainedState" for i in range(qeff_model.num_layers): pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size" for kv in ["key", "value"]: dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] - output_names.append(f"past_{kv}.{i}_RetainedState") + output_names.append(f"past_{kv}.{i}_{onnx_out_name_suffix}") if qeff_model.continuous_batching: dynamic_axes["batch_index"] = {0: "batch_size"} @@ -204,11 +219,32 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): export_params["output_names"] = output_names export_params["dynamic_axes"] = dynamic_axes hash_params["export_params"] = export_params + if subfunc: + hash_params["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + manual_hash = hash_dict_params(hash_params) assert manual_hash == qeff_model.export_hash +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids) +def test_prefill_only_specialized_models(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + if cb: + with pytest.raises(NotImplementedError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + else: + with pytest.raises(ValueError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False) + first_export_hash = qeff_model.export_hash + qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False) + second_export_hash = qeff_model.export_hash + assert first_export_hash != second_export_hash + + @pytest.fixture def tmp_cache(tmp_path, monkeypatch): monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path)