
Commit 06f9f08

Adding ccl_enabled flag during model loading and passing CCL lists during compilation process (#623)
In these changes, instead of passing the CCL lists during model loading, I pass a flag called ccl_enabled to specify whether the Compute-Context-Length (CCL) feature is enabled, and move passing of the CCL lists to the compilation process.

Signed-off-by: Vahid Janfaza <[email protected]>
Co-authored-by: Hem Agnihotri <[email protected]>
1 parent 3974a08 commit 06f9f08

24 files changed with 3,125 additions and 104 deletions.
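For orientation, the sketch below shows the resulting call pattern: the ccl_enabled flag now goes into qaic_config at load time, while the CCL lists are passed to compile(). It is a minimal illustration distilled from the updated examples/performance/compute_context_length/basic_inference.py and README in this commit, not the example script itself; the model name and length values are taken from the README example.

```python
from QEfficient import QEFFAutoModelForCausalLM

# Loading: only the ccl_enabled flag is supplied via qaic_config now.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    qaic_config={"ccl_enabled": True},
)

# Compilation: the CCL lists (together with ctx_len / prefill_seq_len) move here.
qpc_path = model.compile(
    num_cores=16,
    prefill_seq_len=32,
    ctx_len=1024,
    comp_ctx_lengths_prefill=[500, 1000],
    comp_ctx_lengths_decode=[512, 1024],
)
```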

QEfficient/transformers/models/modeling_auto.py

Lines changed: 62 additions & 14 deletions
@@ -937,11 +937,13 @@ def __init__(
         self.model = model
         self.config = model.config

-        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
-
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
         self.continuous_batching = continuous_batching
+        self.ccl_enabled = False
+        if qaic_config:
+            self.ccl_enabled = qaic_config.get("ccl_enabled", False)
+        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
         self.input_shapes, self.output_names = None, None

     @property
@@ -985,6 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
@@ -1095,6 +1098,8 @@ def compile(
         compile_dir: Optional[str] = None,
         *,
         prefill_seq_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         ctx_len: Optional[int] = None,
         batch_size: int = 1,
         full_batch_size: Optional[int] = None,
@@ -1179,10 +1184,21 @@ def compile(

         output_names = self.model.get_output_names(kv_offload=True)

+        # if ccl_enabled is True read Compute-Context-Length lists
+        if self.ccl_enabled:
+            if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
+                logger.warning(
+                    "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
+                )
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )
+
         # For supporting VLLM and Disaggregated with CCL
-        if "comp_ctx_lengths_prefill" in compiler_options:
-            self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
-            self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
+        if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )

         specializations, compiler_options = self.model.get_specializations(
             batch_size=batch_size,
@@ -1634,7 +1650,6 @@ def __init__(
             raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
         super().__init__(model, **kwargs)

-        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
         self.model.qaic_config = qaic_config

         # to handle internvl models
@@ -1648,6 +1663,10 @@ def __init__(
         else:
             self.model.config.use_cache = True
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.ccl_enabled = False
+        if qaic_config:
+            self.ccl_enabled = qaic_config.get("ccl_enabled", False)
+        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

         if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
             BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
@@ -1687,6 +1706,7 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
         from transformers import AutoConfig

         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -1741,6 +1761,8 @@ def compile(
         *,
         prefill_seq_len: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         batch_size: int = 1,
         full_batch_size: Optional[int] = None,
         kv_cache_batch_size: Optional[int] = None,
@@ -1810,10 +1832,21 @@ def compile(
         kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
         output_names = self.model.get_output_names()

+        # if ccl_enabled is True read Compute-Context-Length lists
+        if self.ccl_enabled:
+            if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
+                logger.warning(
+                    "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
+                )
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )
+
         # For supporting VLLM and Disaggregated with CCL
-        if "comp_ctx_lengths_prefill" in compiler_options:
-            self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
-            self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
+        if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )

         # Get specializations from modelling file
         # TODO: expose this via the auto class as well
@@ -2378,8 +2411,6 @@ def __init__(
         # Set use_cache=True to get KV values as output during ONNX export
         model.config.use_cache = True

-        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
-
         super().__init__(model, qaic_config=qaic_config, **kwargs)
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
@@ -2388,6 +2419,10 @@ def __init__(
         self.is_tlm = transformed

         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.ccl_enabled = False
+        if qaic_config:
+            self.ccl_enabled = qaic_config.get("ccl_enabled", False)
+        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

         # ---Sampling---
         # Note: SamplerTransform should be applied after all other transforms
@@ -2833,6 +2868,8 @@ def compile(
         *,
         prefill_seq_len: int = 32,
         ctx_len: int = 128,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         batch_size: int = 1,
         full_batch_size: Optional[int] = None,
         kv_cache_batch_size: Optional[int] = None,
@@ -2924,10 +2961,18 @@ def compile(
         """

+        # if ccl_enabled is True read Compute-Context-Length lists
+        if self.ccl_enabled:
+            if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
+                logger.warning(
+                    "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
+                )
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )
+
         # For supporting VLLM and Disaggregated with CCL
-        if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options:
-            comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
-            comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
+        if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
             if isinstance(comp_ctx_lengths_prefill, str):
                 import ast

@@ -2942,6 +2987,9 @@ def compile(
             self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
             self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )
         # --- Validation ---
         if prefill_only is not None and not isinstance(prefill_only, bool):
             raise TypeError("`prefill_only` must be a boolean.")

QEfficient/transformers/spd/spd_transform_forward.py

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,7 @@ def tlm_forward(
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    comp_ctx_lengths: Optional[torch.LongTensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
@@ -123,6 +124,7 @@ def tlm_forward(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
+        comp_ctx_lengths=comp_ctx_lengths,
        batch_index=batch_index,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,

QEfficient/utils/check_ccl_specializations.py

Lines changed: 1 addition & 8 deletions
@@ -6,14 +6,7 @@
 # -----------------------------------------------------------------------------


-def process_ccl_specializations(qaic_config):
-    if qaic_config is None:
-        return None, None
-    ccl_prefill = qaic_config.pop("comp_ctx_lengths_prefill", None)
-    ccl_decode = qaic_config.pop("comp_ctx_lengths_decode", None)
-    ctx_len = qaic_config.pop("ctx_len", None)
-    prefill_seq_len = qaic_config.pop("prefill_seq_len", 128)
-
+def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
     if ccl_prefill is None or ccl_decode is None:
         return None, None
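With the signature change above, callers pass the CCL lists and compile-time lengths explicitly instead of handing over qaic_config. A minimal sketch of the new interface, assuming the module path mirrors the file path shown above; any normalization the helper applies beyond the None check visible in this hunk (e.g. clamping values to ctx_len) is assumed, not shown:

```python
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations

# New explicit-argument signature (previously the values were popped from qaic_config).
ccl_prefill, ccl_decode = process_ccl_specializations(
    [500, 1000],   # comp_ctx_lengths_prefill
    [512, 1024],   # comp_ctx_lengths_decode
    1024,          # ctx_len
    32,            # prefill_seq_len
)

# Per the unchanged context lines, missing lists still fall back to the non-CCL default.
assert process_ccl_specializations(None, None, 1024, 32) == (None, None)
```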

examples/performance/README.md

Lines changed: 50 additions & 0 deletions
@@ -95,6 +95,56 @@ python on_device_sampling.py \
     --top-p 0.89
 ```

+### Compute-Context-Length
+
+Calculating the context length dynamically during inference to get the best performance within each context-length window.
+
+#### compute_context_length/basic_inference.py
+Configure CCL parameters: 1) ccl-enabled: activates the CCL feature, 2) comp-ctx-lengths-prefill: list of context lengths to be used during prefilling, and 3) comp-ctx-lengths-decode: list of context lengths to be used during decoding.
+
+**Usage for Text-only models:**
+```bash
+python compute_context_length/basic_inference.py \
+    --model-name meta-llama/Llama-3.1-8B \
+    --num-cores 16 \
+    --prefill-seq-len 32 \
+    --ctx-len 1024 \
+    --ccl-enabled \
+    --comp-ctx-lengths-prefill 500,1000 \
+    --comp-ctx-lengths-decode 512,1024
+```
+
+**Usage for VLM models such as mllama and llava:**
+```bash
+python compute_context_length/vlm_inference.py \
+    --model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
+    --hf-token "" \
+    --num-cores 16 \
+    --prefill-seq-len 32 \
+    --ctx-len 8192 \
+    --img-size 560 \
+    --ccl-enabled \
+    --comp-ctx-lengths-prefill 4096 \
+    --comp-ctx-lengths-decode 6144,8192
+```
+
+**Usage with other MoE and Multimodal models:**
+For the various models available in the compute_context_length directory, such as gemma3, gpt_oss, granite_vision, internvl, llama4_cb, llama4_multi_image, llama4, mistral3, molmo, qwen2_5_vl, qwen2_5_vl_cb, and qwen3moe, use the related inference script and only change the model-name and CCL configuration in that script. The following is an example for each model:
+```bash
+python compute_context_length/gemma3.py
+python compute_context_length/gpt_oss.py
+python compute_context_length/granite_vision.py
+python compute_context_length/internvl.py
+python compute_context_length/llama4_cb.py
+python compute_context_length/llama4_multi_image.py
+python compute_context_length/llama4.py
+python compute_context_length/mistral3.py
+python compute_context_length/molmo.py
+python compute_context_length/qwen2_5_vl.py
+python compute_context_length/qwen2_5_vl_cb.py
+python compute_context_length/qwen3moe.py
+```
+
 ## Performance Tips

 1. **Speculative Decoding**: Best for long-form generation where draft model is much faster than target

examples/performance/compute_context_length/README.md

Lines changed: 4 additions & 1 deletion
@@ -68,7 +68,7 @@ python vlm_inference.py \
 Basic CCL usage with text-only language models.

 **Supported Models:**
-- Llama (3.2, 3.3)
+- Llama (3.2, 3.3, swiftkv)
 - Gemma/Gemma-2
 - Mistral
 - Phi/Phi-3
@@ -77,6 +77,9 @@ Basic CCL usage with text-only language models.
 - GPT-2, GPT-J
 - CodeGen
 - OLMo-2
+- Mistral/Mixtral
+- Qwen2
+- Falcon

 **Command-Line Arguments:**
 - `--model-name`: HuggingFace model ID (default: meta-llama/Llama-3.2-1B)

examples/performance/compute_context_length/basic_inference.py

Lines changed: 9 additions & 3 deletions
@@ -46,6 +46,11 @@ def main():
         default=1024,
         help="Maximum context length",
     )
+    parser.add_argument(
+        "--ccl-enabled",
+        action="store_true",
+        help="Enable compute-context-length (CCL) feature",
+    )
     parser.add_argument(
         "--comp-ctx-lengths-prefill",
         type=lambda x: [int(i) for i in x.split(",")],
@@ -113,9 +118,7 @@ def main():
         args.model_name,
         continuous_batching=args.continuous_batching,
         qaic_config={
-            "comp_ctx_lengths_prefill": args.comp_ctx_lengths_prefill,
-            "comp_ctx_lengths_decode": args.comp_ctx_lengths_decode,
-            "ctx_len": args.ctx_len,  # Required for CCL validation
+            "ccl_enabled": args.ccl_enabled,
         },
     )

@@ -132,6 +135,9 @@ def main():

     if args.continuous_batching:
         compile_kwargs["full_batch_size"] = args.full_batch_size
+    if args.ccl_enabled:
+        compile_kwargs["comp_ctx_lengths_prefill"] = args.comp_ctx_lengths_prefill
+        compile_kwargs["comp_ctx_lengths_decode"] = args.comp_ctx_lengths_decode

     qpc_path = model.compile(**compile_kwargs)
     print(f"Model compiled successfully to: {qpc_path}")
