Skip to content

Commit 2788e6e

Browse files
committed
Add automatic CCL list generation for prefill and decode when user does not provide lists
Signed-off-by: Vahid Janfaza <[email protected]>
1 parent 8673d2c commit 2788e6e

File tree

17 files changed

+295
-100
lines changed

17 files changed

+295
-100
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,17 +1126,14 @@ def compile(
11261126

11271127
# if ccl_enabled is True read Compute-Context-Length lists
11281128
if self.ccl_enabled:
1129-
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
1130-
logger.warning(
1131-
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1132-
)
1133-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1129+
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
1130+
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
1131+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
11341132
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
11351133
)
1136-
11371134
# For supporting VLLM and Disaggregated with CCL
1138-
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1139-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1135+
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1136+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
11401137
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
11411138
)
11421139

@@ -1774,17 +1771,14 @@ def compile(
17741771

17751772
# if ccl_enabled is True read Compute-Context-Length lists
17761773
if self.ccl_enabled:
1777-
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
1778-
logger.warning(
1779-
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1780-
)
1781-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1774+
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
1775+
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
1776+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
17821777
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
17831778
)
1784-
17851779
# For supporting VLLM and Disaggregated with CCL
1786-
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1787-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1780+
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1781+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
17881782
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
17891783
)
17901784

@@ -2873,16 +2867,13 @@ def compile(
28732867

28742868
# if ccl_enabled is True read Compute-Context-Length lists
28752869
if self.ccl_enabled:
2876-
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
2877-
logger.warning(
2878-
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
2879-
)
2880-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
2870+
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
2871+
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
2872+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
28812873
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
28822874
)
2883-
28842875
# For supporting VLLM and Disaggregated with CCL
2885-
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
2876+
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
28862877
if isinstance(comp_ctx_lengths_prefill, str):
28872878
import ast
28882879

@@ -2897,7 +2888,7 @@ def compile(
28972888
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
28982889
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
28992890

2900-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
2891+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
29012892
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
29022893
)
29032894
# --- Validation ---

QEfficient/utils/check_ccl_specializations.py

Lines changed: 205 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,210 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
from typing import List, Optional, Tuple
9+
10+
11+
def next_multiple_of_1024(n: int) -> int:
    """Round ``n`` up to the nearest multiple of 1024.

    Values that are already a multiple of 1024 are returned unchanged.
    Non-positive inputs map to 0.
    """
    if n <= 0:
        return 0
    # Adding 1023 before the floor division turns it into a ceiling division.
    blocks, _ = divmod(n + 1023, 1024)
    return blocks * 1024
16+
17+
18+
def floor_to_1000(n: int) -> int:
    """Round ``n`` down to a multiple of 1000.

    Non-positive inputs map to 0.
    """
    return (n // 1000) * 1000 if n > 0 else 0
23+
24+
25+
def is_power_of_two(n: int) -> bool:
    """Tell whether ``n`` is a positive integer power of two (1, 2, 4, 8, ...)."""
    if not n > 0:
        return False
    # ``n & -n`` isolates the lowest set bit; it equals ``n`` only when
    # exactly one bit is set.
    return (n & -n) == n
28+
29+
30+
def build_doubling_sequence(start: int, limit: int, max_elements: int, force_last: Optional[int] = None) -> List[int]:
    """
    Produce ``start, 2*start, 4*start, ...`` while the values stay <= ``limit``.

    At most ``max_elements`` values are returned. When ``force_last`` is given,
    the final element is made equal to it: it is appended if there is still
    room, otherwise it overwrites the current last element. ``force_last`` is
    honoured even when it is larger than ``limit``.

    If ``start`` already exceeds ``limit`` the result is the single element
    ``force_last`` (or ``limit`` when ``force_last`` is None).
    """
    if max_elements <= 0:
        return []

    if start > limit:
        only = limit if force_last is None else force_last
        return [only][:max_elements]

    values: List[int] = []
    current = start
    while len(values) < max_elements and current <= limit:
        values.append(current)
        current *= 2

    # Pin the tail to ``force_last`` when requested and not already there.
    if force_last is not None and (not values or values[-1] != force_last):
        if len(values) < max_elements:
            values.append(force_last)
        else:
            values[-1] = force_last

    # dict keys keep insertion order, so this drops repeats but keeps the
    # first occurrence of every value in place.
    unique = list(dict.fromkeys(values))
    return unique[:max_elements]
73+
74+
75+
def Automatic_CCL_Generation(
    CL: int,
    prefill_seq_len: int,
    comp_ctx_lengths_prefill: Optional[List[int]] = None,
    comp_ctx_lengths_decode: Optional[List[int]] = None,
) -> Tuple[List[int], List[int], int]:
    """
    Automatic Compute-Context-Length (CCL) list generation.

    The context length ``CL`` is rounded up to a multiple of 1024
    (``mapped_CL``) and the prefill/decode lists are built as doubling
    sequences whose starting points depend on how large ``mapped_CL`` is.

    Args:
        CL: Requested context length.
        prefill_seq_len: Prefill sequence length. A value of 1 produces one
            shared list for prefill and decode.
        comp_ctx_lengths_prefill: Optional user-supplied prefill list. Only
            consulted when ``prefill_seq_len > 1`` and ``mapped_CL > 4096``.
        comp_ctx_lengths_decode: Optional user-supplied decode list. Same
            conditions as ``comp_ctx_lengths_prefill``.

    Returns:
        ``(prefill_list, decode_list, mapped_CL)``, always in this order.
        When both lists are identical, the same list object is returned for
        both positions.
    """
    if CL <= 0:
        # Non-positive CL: fall back to the smallest bucket.
        mapped_CL = next_multiple_of_1024(max(CL, 1))
        seq = [mapped_CL]
        return seq, seq, mapped_CL

    mapped_CL = next_multiple_of_1024(CL)

    # Tiered starting points for the doubling sequences.
    if mapped_CL <= 4096:
        # Small contexts need a single specialization.
        seq = [mapped_CL]
        return seq, seq, mapped_CL
    elif mapped_CL <= 32768:
        decode_start, prefill_start = 4096, 4000
    elif mapped_CL <= 65536:
        decode_start, prefill_start = 8192, 8000
    else:
        # Everything above 65536 shares the same starting points.
        decode_start, prefill_start = 16384, 16000

    if prefill_seq_len > 1:
        # Passthrough if either list was provided. The order must be
        # (prefill, decode) to match every other return of this function and
        # the way callers unpack the result.
        if comp_ctx_lengths_decode is not None or comp_ctx_lengths_prefill is not None:
            return (
                comp_ctx_lengths_prefill if comp_ctx_lengths_prefill is not None else [],
                comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
                mapped_CL,
            )

        max_elems = 5

        # Decode: doubling sequence that always ends at mapped_CL.
        decode = build_doubling_sequence(
            start=decode_start,
            limit=mapped_CL,
            max_elements=max_elems,
            force_last=mapped_CL,
        )

        # Prefill:
        if is_power_of_two(CL):
            # Strict doubling, limit = CL, no forced non-doubling last element.
            prefill = build_doubling_sequence(
                start=prefill_start,
                limit=CL,
                max_elements=max_elems,
                force_last=None,
            )
        else:
            # End the prefill list at mapped_CL floored to a multiple of 1000.
            prefill_last = floor_to_1000(mapped_CL)
            prefill = build_doubling_sequence(
                start=prefill_start,
                limit=CL,
                max_elements=max_elems,
                force_last=prefill_last,
            )

        return prefill, decode, mapped_CL

    # prefill_seq_len <= 1: prefill and decode share one list.
    max_elems = 10
    grid_cap = 2097152  # upper cap for the doubling grid

    # mapped_CL is strictly greater than 4096 here (smaller values returned
    # above), so the sequence always starts at 4096.
    seq = build_doubling_sequence(
        start=4096,
        limit=min(mapped_CL, grid_cap),
        max_elements=max_elems,
        force_last=mapped_CL,  # identical lists end at mapped_CL
    )
    return seq, seq, mapped_CL
165+
8166

9167
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
    """
    Normalise the Compute-Context-Length (CCL) lists for prefill and decode.

    * Both lists ``None``: the lists (and possibly an adjusted ``ctx_len``)
      are generated by ``Automatic_CCL_Generation``.
    * Otherwise every supplied list is capped to ``ctx_len``, de-duplicated
      and sorted. A list that was not supplied stays ``None``.
    * Both lists supplied and ``prefill_seq_len == 1``: prefill and decode
      share the sorted union of both lists.
    * Both lists supplied and ``prefill_seq_len != 1``: prefill values that
      collide with a decode value (or an already accepted prefill value) are
      decremented until they are unique; values that would drop below 0 are
      discarded.

    Args:
        ccl_prefill: Prefill CCL list or ``None``.
        ccl_decode: Decode CCL list or ``None``.
        ctx_len: Maximum context length; required.
        prefill_seq_len: Prefill sequence length.

    Returns:
        ``(ccl_prefill, ccl_decode, ctx_len)``.

    Raises:
        TypeError: If ``ctx_len`` is ``None``.
    """
    if ctx_len is None:
        raise TypeError("`ctx_len` is required when loading the model with CCL.")

    # Automatic CCL generation: If both ccl_prefill and ccl_decode are None,
    # generate optimized context length lists for prefill and decode based on ctx_len
    if ccl_prefill is None and ccl_decode is None:
        ccl_prefill, ccl_decode, ctx_len = Automatic_CCL_Generation(ctx_len, prefill_seq_len, ccl_prefill, ccl_decode)
    else:
        # Cap to ctx_len BEFORE de-duplicating: capping can turn distinct
        # values into the same one. Sorting also makes the overlap
        # resolution below independent of set iteration order.
        if ccl_prefill is not None:
            ccl_prefill = sorted({min(x, ctx_len) for x in ccl_prefill})
        if ccl_decode is not None:
            ccl_decode = sorted({min(x, ctx_len) for x in ccl_decode})

        if ccl_prefill is not None and ccl_decode is not None:
            if prefill_seq_len == 1:
                # both prefill and decode ccl can share the same specializations
                # since prefill_seq_len=1, so a sorted union is used for both.
                ccl_union_all = sorted(set(ccl_prefill) | set(ccl_decode))
                ccl_prefill = ccl_union_all
                ccl_decode = ccl_union_all
            else:
                # Ensure no overlap between ccl_prefill and ccl_decode.
                taken = set(ccl_decode)
                unique_prefill = []
                for val in ccl_prefill:
                    while val >= 0 and val in taken:
                        val -= 1
                    if val >= 0:  # negative values are dropped
                        unique_prefill.append(val)
                        taken.add(val)
                ccl_prefill = sorted(unique_prefill)

    print("CCL Configuration:")
    print(f" - Prefill context lengths: {ccl_prefill}")
    print(f" - Decode context lengths: {ccl_decode}")
    print(f" - Max context length: {ctx_len}")
    return ccl_prefill, ccl_decode, ctx_len

examples/performance/compute_context_length/README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,22 @@ python basic_inference.py \
3737
--model-name meta-llama/Llama-3.2-1B \
3838
--prompt "Hello, how are you?" \
3939
--ctx-len 1024 \
40+
--ccl-enabled \
4041
--comp-ctx-lengths-prefill "256,500" \
4142
--comp-ctx-lengths-decode "512,1024" \
4243
--generation-len 100
4344
```
4445

46+
# For automatic CCL list generation, omit the CCL lists and pass only the `--ccl-enabled` flag
47+
```bash
48+
python basic_inference.py \
49+
--model-name meta-llama/Llama-3.2-1B \
50+
--prompt "Hello, how are you?" \
51+
--ctx-len 1024 \
52+
--ccl-enabled \
53+
--generation-len 100
54+
```
55+
4556
### Vision-Language Models
4657

4758
Run VLM inference with CCL:
@@ -55,11 +66,22 @@ python vlm_inference.py \
5566
--model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
5667
--query "Describe this image" \
5768
--image-url "https://..." \
69+
--ccl-enabled \
5870
--comp-ctx-lengths-prefill "4096" \
5971
--comp-ctx-lengths-decode "6144,8192" \
6072
--ctx-len 8192
6173
```
6274

75+
# For automatic CCL list generation, omit the CCL lists and pass only the `--ccl-enabled` flag
76+
```bash
77+
python vlm_inference.py \
78+
--model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
79+
--query "Describe this image" \
80+
--image-url "https://..." \
81+
--ccl-enabled \
82+
--ctx-len 8192
83+
```
84+
6385
## Available Examples
6486

6587
### Text-Only Models

examples/performance/compute_context_length/basic_inference.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ def main():
5454
parser.add_argument(
5555
"--comp-ctx-lengths-prefill",
5656
type=lambda x: [int(i) for i in x.split(",")],
57-
default="256,500",
57+
default=None,
5858
help="Comma-separated list of context lengths for prefill phase (e.g., '256,500')",
5959
)
6060
parser.add_argument(
6161
"--comp-ctx-lengths-decode",
6262
type=lambda x: [int(i) for i in x.split(",")],
63-
default="512,1024",
63+
default=None,
6464
help="Comma-separated list of context lengths for decode phase (e.g., '512,1024')",
6565
)
6666
parser.add_argument(
@@ -107,11 +107,7 @@ def main():
107107
args = parser.parse_args()
108108

109109
print(f"Loading model: {args.model_name}")
110-
print("CCL Configuration:")
111-
print(f" - Prefill context lengths: {args.comp_ctx_lengths_prefill}")
112-
print(f" - Decode context lengths: {args.comp_ctx_lengths_decode}")
113-
print(f" - Max context length: {args.ctx_len}")
114-
print(f" - Continuous batching: {args.continuous_batching}")
110+
print(f"Continuous batching: {args.continuous_batching}")
115111

116112
# Load model with CCL configuration
117113
model = QEFFAutoModelForCausalLM.from_pretrained(

0 commit comments

Comments
 (0)