diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 7da2300d6..4fb77f272 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -329,6 +329,7 @@ def cloud_ai_100_exec_kv( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -356,6 +357,8 @@ def cloud_ai_100_exec_kv( next tokens. For Speculative Decoding Target Language Model, `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. + :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering + during decoding. Only works when `include_sampler`=True. sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend. The dictionary should contain the following keys: `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, @@ -394,6 +397,7 @@ def cloud_ai_100_exec_kv( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -442,6 +446,7 @@ def __init__( is_tlm: Optional[int] = None, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, activate: bool = True, ) -> None: @@ -451,6 +456,7 @@ def __init__( self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs + self.include_guided_decoding = include_guided_decoding self.sampling_params = sampling_params self._qpc_path = qpc_path # Store qpc_path for later use @@ -461,7 +467,9 @@ def __init__( # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._session.input_names), + include_sampler=include_sampler, + include_guided_decoding=include_guided_decoding, ) # Fetch the variables from the QPC @@ -628,7 +636,7 @@ def prepare_decode_inputs(self): decode_inputs["batch_index"] = self.batch_index if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if self.batch_index is not None: decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] else: @@ -795,7 +803,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: inputs["last_accepted_output_tokens"] = inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -1067,6 +1075,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( @@ -1082,6 +1091,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index c603a60d0..b3e03f253 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -94,6 +94,7 @@ def __init__( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -115,6 +116,7 @@ def __init__( is_tlm: Target language model flag include_sampler: Enable on-device sampling (new feature) return_pdfs: Return probability distributions + include_guided_decoding: Enable guided decoding in on-device sampling sampling_params: Sampling parameters for on-device sampling """ # Validate required parameters @@ -138,6 +140,7 @@ def __init__( is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, activate=False, # vision components need to be initialized first ) @@ -315,7 +318,7 @@ def _execute_chunked_prefill( lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] if self.include_sampler: - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -348,7 +351,7 @@ def _execute_chunked_prefill( if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): chunk_inputs[op] = lang_inputs[op] outputs = self._session.run(chunk_inputs) @@ -803,6 +806,7 @@ def generate_stream_tokens( is_tlm=self.is_tlm, include_sampler=self.include_sampler, return_pdfs=self.return_pdfs, + include_guided_decoding=self.include_guided_decoding, sampling_params=self.sampling_params, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index dc03ba82f..88f2f29b1 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2251,7 +2251,6 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, @@ -2347,6 +2346,8 @@ def __init__( - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens. For Speculative Decoding Target Language Models, this is always True. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -2443,6 +2444,8 @@ def from_pretrained( and ``return_pdfs=False`` for regular model. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. The values provided in ``top_ks`` tensor must be less than this maximum limit. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. *args : Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 9e021851b..b978b6193 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -242,6 +242,7 @@ QEffGemma3Attention, QEffGemma3CustomRMSNormAIC, QEffGemma3DecoderLayer, + QEffGemma3DecoderWrapper, QEffGemma3ForCausalLMModel, QEffGemma3ForConditionalGeneration, QEffGemma3TextModel, @@ -313,6 +314,7 @@ QEffLlamaRotaryEmbedding, ) from QEfficient.transformers.models.llama4.modeling_llama4 import ( + QEffLlama4DecoderWrapper, QEffLlama4ForCausalLM, QEffLlama4ForConditionalGeneration, QEffLlama4Router, @@ -325,9 +327,11 @@ QEffLlama4VisionModel, ) from QEfficient.transformers.models.llava.modeling_llava import ( + QEFFLlavaDecoderWrapper, QEffLlavaForConditionalGeneration, ) from QEfficient.transformers.models.llava_next.modeling_llava_next import ( + QEffLlavaNextDecoderWrapper, QEffLlavaNextForConditionalGeneration, ) from QEfficient.transformers.models.mistral.modeling_mistral import ( @@ -755,12 +759,16 @@ class SamplerTransform: _module_mapping = { QEffFalconForCausalLM, QEffGemmaForCausalLM, + QEffGemma3DecoderWrapper, QEffGPT2LMHeadModel, QEffGPTJForCausalLM, QEffGraniteForCausalLM, QEffGraniteMoeForCausalLM, QEffInternDecoderWrapper, QEffLlamaForCausalLM, + QEffLlama4DecoderWrapper, + QEFFLlavaDecoderWrapper, + QEffLlavaNextDecoderWrapper, QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index fd7b87dcd..5c86b6355 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -129,6 +129,7 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + token_bitmasks: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -179,6 +180,11 @@ def sampler_forward( random_numbers (`torch.Tensor`, *optional*): Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. + + token_bitmasks (`torch.Tensor`, *optional*): + Boolean mask used to guide token-level filtering during decoding. Each + element of this tensor indicates whether the corresponding token should be + kept (1) or masked (0). Shape: (batch_size, vocab_size) """ if vision_embeds is not None: forward_kwargs = dict( @@ -224,6 +230,13 @@ def sampler_forward( batch_index = torch.arange(batch_size).view(-1, 1) batch_index_reshaped = batch_index.view(-1) + + # Guided decoding + if token_bitmasks is not None and (token_bitmasks != 1).any(): + assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" + # Mask logits where token_bitmasks is 0 with -inf + logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min) + # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( input_ids=input_ids, diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 0460eeb3a..82a0843bc 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -14,7 +14,9 @@ from QEfficient.utils.logging_utils import logger -def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool: +def validate_sampler_inputs( + session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None +) -> bool: """ Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling. @@ -31,7 +33,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ValueError if partial support is detected or if user intent conflicts with QPC capabilities. """ - sampler_inputs = Constants.SAMPLER_INPUTS + sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set()) count = len(sampler_inputs & session_inputs) session_includes_sampler = True @@ -96,10 +98,9 @@ def get_sampling_inputs_and_outputs( """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + seq_len: int = example_inputs["input_ids"].shape[-1] - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) + example_inputs["last_accepted_output_tokens"] = torch.zeros((bs, seq_len), dtype=torch.int64) dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} example_inputs["past_repetition_penalty_buffer"] = torch.zeros( @@ -144,4 +145,8 @@ def get_sampling_inputs_and_outputs( example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} + if qaic_config.get("include_guided_decoding", False): + example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} + return example_inputs, output_names, dynamic_axes diff --git a/examples/performance/on_device_sampling.py b/examples/performance/on_device_sampling.py index b4e1f4e27..da9c5b43b 100644 --- a/examples/performance/on_device_sampling.py +++ b/examples/performance/on_device_sampling.py @@ -21,6 +21,7 @@ def main(args, **kwargs): include_sampler = None return_pdfs = None max_top_k_ids = None + include_guided_decoding = None sampling_params = None bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size if args.override_qaic_config is not None: @@ -29,6 +30,7 @@ def main(args, **kwargs): return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) np.random.seed(int(args.random_number)) + include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true" sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -47,13 +49,12 @@ def main(args, **kwargs): "include_sampler": include_sampler, "return_pdfs": return_pdfs, "max_top_k_ids": max_top_k_ids, + "include_guided_decoding": include_guided_decoding, }.items() if v is not None } print("qaic_config:") pprint(qaic_config) - print("sampling_params:") - pprint(sampling_params) # Load model with On Device Sampler enabled qeff_model = AutoModelForCausalLM.from_pretrained( @@ -63,6 +64,19 @@ def main(args, **kwargs): ) print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + if include_guided_decoding: + # Ideally this should come from a logits processor like xgrammar, but for the sake of the + # example, we generate a random bitmask + sampling_params.update( + { + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1) + ) + } + ) + print("sampling_params:") + pprint(sampling_params) + # Compile the model for inference generated_qpc_path = qeff_model.compile( prefill_seq_len=args.prompt_len, @@ -91,6 +105,7 @@ def main(args, **kwargs): generation_len=args.generation_len, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -109,7 +124,7 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ @@ -129,7 +144,27 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 26 + + 3. With guided decoding: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index f9aa35312..26cb6fda9 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -16,6 +16,7 @@ from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants from QEfficient.utils.test_utils import InternProcessor +from tests.transformers.models.image_text_to_text.test_continuous_batching import set_num_layers sampler_transform_configs = [ pytest.param( @@ -92,17 +93,41 @@ True, # is_vlm ), ] +guided_decoding_configs = [ + pytest.param( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model + Constants.INPUT_STR * 4, # prompts + 32, # prefill_seq_len + 64, # ctx_len + 20, # generation_len + 4, # full_batch_size + 1, # spec_length + False, # is_vlm + ), + pytest.param( + "OpenGVLab/InternVL2_5-1B", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm + ), +] def prepare_model_setup( - model: str, is_vlm: bool, num_hidden_layers: Optional[int], prompts: Union[List, Tuple], spec_length: Optional[int] + model: str, is_vlm: bool, num_hidden_layers: int, prompts: Union[List, Tuple], spec_length: Optional[int] ): additional_configs = {} additional_params = {} if is_vlm: config = AutoConfig.from_pretrained(model, trust_remote_code=True) - if num_hidden_layers is not None: - config.llm_config.num_hidden_layers = num_hidden_layers + config = set_num_layers(config, n_layer=num_hidden_layers) additional_configs["config"] = config additional_configs["kv_offload"] = True assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided." @@ -123,7 +148,7 @@ def prepare_model_setup( additional_params["processor"] = AutoProcessor.from_pretrained(model) qeff_class = QEFFAutoModelForImageTextToText else: - if num_hidden_layers is not None: + if num_hidden_layers != -1: additional_configs["num_hidden_layers"] = num_hidden_layers spec_length = (spec_length or 1) - 1 qeff_class = QEFFAutoModelForCausalLM @@ -165,6 +190,17 @@ def test_sampler_transform( }, **additional_configs, ) + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + "include_guided_decoding": True, + }, + **additional_configs, + ) model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, @@ -184,6 +220,16 @@ def test_sampler_transform( mxint8_kv_cache=True, mxfp6_matmul=True, ) + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, @@ -196,10 +242,12 @@ def test_sampler_transform( ) if is_vlm: model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1] model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path) model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) # Skip inputs/outputs buffers @@ -207,6 +255,12 @@ def test_sampler_transform( model_w_sampler_session.skip_buffers( set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")]) + ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")]) + ) model_wo_sampler_session.skip_buffers( set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) ) @@ -220,9 +274,15 @@ def test_sampler_transform( assert input_name in model_w_sampler_session.input_names, ( f"Sampler input {input_name} not found in QPC compiled with On Device Sampler" ) + assert input_name in model_w_sampler_w_guided_decoding_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding" + ) assert input_name not in model_wo_sampler_session.input_names, ( f"Sampler input {input_name} found in QPC compiled without On Device Sampler" ) + assert "token_bitmasks" in model_w_sampler_w_guided_decoding_session.input_names, ( + "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding" + ) @pytest.mark.on_qaic @@ -241,14 +301,14 @@ def test_greedy_sampling( is_vlm: bool, ): """ - Test greedy sampling with QPC compiled with and without On Device Sampling. + Test greedy sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models num_hidden_layers = 4 additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( model, is_vlm, num_hidden_layers, prompts, spec_length ) - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -258,7 +318,7 @@ def test_greedy_sampling( }, **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -343,14 +403,14 @@ def test_random_sampling( is_vlm: bool, ): """ - Test random sampling with QPC compiled with and without On Device Sampling. + Test random sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models - num_hidden_layers = None + num_hidden_layers = -1 additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( model, is_vlm, num_hidden_layers, prompts, spec_length ) - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -360,7 +420,7 @@ def test_random_sampling( }, **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -547,3 +607,118 @@ def test_random_sampling( assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( "Without sampler generated ids do not match" ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", + guided_decoding_configs, +) +def test_guided_decoding( + model: str, + prompts: Union[List[str], tuple[List[str], List[str]]], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: Optional[int], + is_vlm: bool, +): + """ + Test QPCs compiled with and without guided decoding. + """ + # Export and compile QEfficient models + num_hidden_layers = 2 + additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( + model, is_vlm, num_hidden_layers, prompts, spec_length + ) + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + "include_guided_decoding": True, + }, + **additional_configs, + ) + model_w_sampler_wo_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + }, + **additional_configs, + ) + model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_w_sampler_wo_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) + sampling_params = { + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32), + } + if is_vlm: + vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size + else: + vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size + model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + include_guided_decoding=True, + sampling_params={ + **sampling_params, + **{ + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(vocab_size,)), + (full_batch_size, 1), + ) + }, + }, + **additional_params, + ) + model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params=sampling_params, + **additional_params, + ) + assert ( + model_w_sampler_w_guided_decoding_exec_info.generated_ids + != model_w_sampler_wo_guided_decoding_exec_info.generated_ids + ).any(), "Sampler outputs with and without guided decoding should not match"