Skip to content

Commit b769fc0

Browse files
committed
Add automatic CCL list generation for prefill and decode when user does not provide lists
Signed-off-by: Vahid Janfaza <[email protected]>
1 parent 2788e6e commit b769fc0

File tree

2 files changed

+122
-111
lines changed

2 files changed

+122
-111
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def compile(
11271127
# if ccl_enabled is True read Compute-Context-Length lists
11281128
if self.ccl_enabled:
11291129
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
1130-
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
1130+
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
11311131
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
11321132
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
11331133
)
@@ -1772,7 +1772,7 @@ def compile(
17721772
# if ccl_enabled is True read Compute-Context-Length lists
17731773
if self.ccl_enabled:
17741774
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
1775-
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
1775+
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
17761776
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
17771777
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
17781778
)
@@ -2868,7 +2868,7 @@ def compile(
28682868
# if ccl_enabled is True read Compute-Context-Length lists
28692869
if self.ccl_enabled:
28702870
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
2871-
print("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
2871+
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
28722872
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
28732873
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
28742874
)

QEfficient/utils/check_ccl_specializations.py

Lines changed: 119 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import List, Optional, Tuple
8+
from typing import List, Optional, Set, Tuple
99

1010

1111
def next_multiple_of_1024(n: int) -> int:
@@ -23,57 +23,59 @@ def floor_to_1000(n: int) -> int:
2323

2424

2525
def is_power_of_two(n: int) -> bool:
26-
"""Return True if n is a power of two (n>0 and n&(n-1)==0)."""
26+
"""Return True if n is a power of two (n > 0 and n & (n - 1) == 0)."""
2727
return n > 0 and (n & (n - 1)) == 0
2828

2929

30-
def build_doubling_sequence(start: int, limit: int, max_elements: int, force_last: Optional[int] = None) -> List[int]:
30+
def band_index_from_mapped_cl(mapped_cl: int) -> int:
3131
"""
32-
Build an increasing sequence starting at 'start', doubling each step,
33-
not exceeding 'limit', with total length <= max_elements.
34-
If 'force_last' is provided, ensure the last element equals 'force_last'
35-
(replacing/appending as needed), even if it exceeds 'limit'.
32+
Compute band index ∈ {0,1,2} from mapped_cl using bit arithmetic.
33+
34+
Bands (upper bounds): 2^15=32768 → idx=0, 2^16=65536 → idx=1, 2^17=131072 → idx=2.
35+
For mapped_cl > 131072, clamp to idx=2.
3636
"""
37-
if max_elements <= 0:
38-
return []
37+
# ceil(log2(mapped_cl)) == bit_length(mapped_cl - 1)
38+
ceil_log2 = (mapped_cl - 1).bit_length()
39+
# map to {0,1,2} by subtracting 15 (the exponent for 32768) and clamping
40+
idx = max(0, min(2, ceil_log2 - 15))
41+
return idx
3942

40-
# If start is already beyond limit, return [force_last or limit] as a single element.
41-
if start > limit:
42-
seq = [force_last if force_last is not None else limit]
43-
return seq[:max_elements]
44-
45-
seq: List[int] = []
46-
val = start
47-
48-
while val <= limit and len(seq) < max_elements:
49-
seq.append(val)
50-
next_val = val * 2
51-
if next_val > limit or len(seq) >= max_elements:
52-
break
53-
val = next_val
54-
55-
# Add/replace last element if a 'force_last' is requested
56-
if force_last is not None:
57-
if len(seq) == 0:
58-
seq = [force_last]
59-
elif seq[-1] != force_last:
60-
if len(seq) < max_elements:
61-
seq.append(force_last)
62-
else:
63-
seq[-1] = force_last
6443

65-
# Deduplicate while preserving order
66-
dedup = []
67-
seen = set()
68-
for x in seq:
69-
if x not in seen:
70-
dedup.append(x)
71-
seen.add(x)
72-
return dedup[:max_elements]
44+
def build_doubling_set(start: int, limit: int, max_elements: int) -> Set[int]:
45+
"""
46+
Build a STRICT doubling set: {start, start*2, start*4, ...} up to 'limit',
47+
collecting at most 'max_elements' values. Returns a set; caller will sort.
48+
"""
49+
values: Set[int] = set()
50+
if max_elements <= 0 or start <= 0 or limit <= 0:
51+
return values
52+
53+
v = start
54+
while v <= limit and len(values) < max_elements:
55+
values.add(v)
56+
v *= 2
57+
return values
58+
59+
60+
def ensure_last(sorted_seq: List[int], last_value: int, max_elements: int) -> List[int]:
61+
"""
62+
Ensure the last element equals 'last_value' by appending or replacing the final element,
63+
keeping length <= max_elements. If the sequence is empty, return [last_value].
64+
"""
65+
if max_elements <= 0:
66+
return []
67+
if not sorted_seq:
68+
return [last_value][:max_elements]
69+
if sorted_seq[-1] != last_value:
70+
if len(sorted_seq) < max_elements:
71+
sorted_seq.append(last_value)
72+
else:
73+
sorted_seq[-1] = last_value
74+
return sorted_seq[:max_elements]
7375

7476

75-
def Automatic_CCL_Generation(
76-
CL: int,
77+
def automatic_ccl_generation(
78+
ctx_len: int,
7779
prefill_seq_len: int,
7880
comp_ctx_lengths_prefill: Optional[List[int]] = None,
7981
comp_ctx_lengths_decode: Optional[List[int]] = None,
@@ -82,93 +84,102 @@ def Automatic_CCL_Generation(
8284
Automatic Compute-Context-Length Lists Generation
8385
8486
Purpose:
85-
Compute decode and prefill ccl lists based on an input context
86-
length (CL), prefill sequence length, and optional pre-specified lists.
87+
Compute decode and prefill CCL lists based on an input context length (CL),
88+
prefill sequence length, and optional pre-specified lists.
89+
90+
High-level rules (unchanged from your finalized logic):
91+
- prefill_seq_len > 1:
92+
* If either list is provided, pass them through unchanged.
93+
* decode: doubles from tiered start; MUST end at mapped_CL (last forced to mapped_CL).
94+
* prefill:
95+
• If CL is power of two: STRICT doubling from tiered start, bounded by CL (no forced non-doubling last).
96+
• Else: doubles from tiered start, bounded by CL, and last element = floor_to_1000(mapped_CL).
97+
* Max 5 elements per list.
98+
- prefill_seq_len == 1:
99+
* decode and prefill are IDENTICAL.
100+
* start at 4096, double up to 10 elements.
101+
* upper grid cap computed dynamically (start * 2^(max_elements-1)); last = mapped_CL.
102+
* If mapped_CL < 4096, both lists are [mapped_CL].
87103
"""
88-
89-
if CL <= 0:
90-
mapped_CL = next_multiple_of_1024(max(CL, 1))
91-
# For non-positive CL, minimal identical sequences
92-
seq = [mapped_CL]
93-
return seq, seq, mapped_CL
94-
95-
mapped_CL = next_multiple_of_1024(CL)
96-
97-
# Tiered starts
98-
if mapped_CL <= 4096:
99-
seq = [mapped_CL]
100-
return seq, seq, mapped_CL
101-
elif mapped_CL <= 32768:
102-
decode_start, prefill_start = 4096, 4000
103-
elif mapped_CL <= 65536:
104-
decode_start, prefill_start = 8192, 8000
105-
elif mapped_CL <= 131072:
106-
decode_start, prefill_start = 16384, 16000
107-
else:
108-
decode_start, prefill_start = 16384, 16000
109-
110-
# If prefill_seq_len > 1:
104+
# Handle non-positive CL
105+
if ctx_len <= 0:
106+
mapped_cl = next_multiple_of_1024(1)
107+
seq = [mapped_cl]
108+
return seq, seq, mapped_cl
109+
110+
mapped_cl = next_multiple_of_1024(ctx_len)
111+
112+
# Early small-ctx_len case for identical lists
113+
if mapped_cl <= 4096:
114+
seq = [mapped_cl]
115+
return seq, seq, mapped_cl
116+
117+
# Compute tier starts via band index (no hard-coded chain)
118+
idx = band_index_from_mapped_cl(mapped_cl)
119+
decode_start = 4096 << idx # 4096, 8192, 16384
120+
PREFILL_STARTS = {0: 4000, 1: 8000, 2: 16000}
121+
prefill_start = PREFILL_STARTS[idx]
122+
123+
# Branch: prefill_seq_len > 1
111124
if prefill_seq_len > 1:
112125
# Passthrough if either provided
113126
if comp_ctx_lengths_decode is not None or comp_ctx_lengths_prefill is not None:
114127
return (
115-
comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
116128
comp_ctx_lengths_prefill if comp_ctx_lengths_prefill is not None else [],
117-
mapped_CL,
129+
comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
130+
mapped_cl,
118131
)
119132

133+
# Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 5.
120134
max_elems = 5
121135

122-
# Decode: ensure last = mapped_CL
123-
decode = build_doubling_sequence(
124-
start=decode_start,
125-
limit=mapped_CL,
126-
max_elements=max_elems,
127-
force_last=mapped_CL,
128-
)
129-
130-
# Prefill:
131-
if is_power_of_two(CL):
132-
# Strict doubling, limit = CL, no forced non-doubling last
133-
prefill = build_doubling_sequence(
134-
start=prefill_start,
135-
limit=CL,
136-
max_elements=max_elems,
137-
force_last=None,
138-
)
136+
# ---- Decode: strict doubling up to mapped_cl, then enforce last = mapped_cl
137+
decode_set = build_doubling_set(start=decode_start, limit=mapped_cl, max_elements=max_elems)
138+
decode_list = sorted(decode_set)
139+
decode_list = ensure_last(decode_list, last_value=mapped_cl, max_elements=max_elems)
140+
141+
# ---- Prefill:
142+
if is_power_of_two(ctx_len):
143+
# STRICT doubling only, bounded by ctx_len; do NOT force a non-doubling last
144+
prefill_set = build_doubling_set(start=prefill_start, limit=ctx_len, max_elements=max_elems)
145+
prefill_list = sorted(prefill_set)[:max_elems]
139146
else:
140-
prefill_last = floor_to_1000(mapped_CL)
141-
prefill = build_doubling_sequence(
142-
start=prefill_start,
143-
limit=CL,
144-
max_elements=max_elems,
145-
force_last=prefill_last,
146-
)
147+
# Doubles bounded by ctx_len, but last must equal floor_to_1000(mapped_cl)
148+
prefill_last = floor_to_1000(mapped_cl)
149+
prefill_set = build_doubling_set(start=prefill_start, limit=ctx_len, max_elements=max_elems)
150+
prefill_list = sorted(prefill_set)
151+
prefill_list = ensure_last(prefill_list, last_value=prefill_last, max_elements=max_elems)
147152

148-
return prefill, decode, mapped_CL
153+
# NOTE: return order preserved from your last snippet (prefill first, then decode)
154+
return prefill_list, decode_list, mapped_cl
149155

150-
# UPDATED: prefill_seq_len == 1 → identical lists
156+
# Branch: prefill_seq_len == 1 → identical lists
151157
else:
158+
# When prefill_seq_len=1 such as in MoE models, prefilling and decoding processes can use the same specializations and we can double the length of Ccl lists.
159+
# Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 10.
152160
max_elems = 10
153-
grid_cap = 2097152 # upper cap for doubling grid
161+
start_identical = 4096
154162

155-
if mapped_CL < 4096:
156-
seq = [mapped_CL]
157-
else:
158-
seq = build_doubling_sequence(
159-
start=4096,
160-
limit=min(mapped_CL, grid_cap),
161-
max_elements=max_elems,
162-
force_last=mapped_CL, # identical lists end at mapped_CL
163-
)
164-
return seq, seq, mapped_CL
163+
if mapped_cl < start_identical:
164+
seq = [mapped_cl]
165+
return seq, seq, mapped_cl
166+
167+
# Dynamic grid cap: start * 2^(max_elems - 1)
168+
grid_cap = start_identical * (1 << (max_elems - 1))
169+
limit = min(mapped_cl, grid_cap)
170+
171+
seq_set = build_doubling_set(start=start_identical, limit=limit, max_elements=max_elems)
172+
seq_list = sorted(seq_set)
173+
seq_list = ensure_last(seq_list, last_value=mapped_cl, max_elements=max_elems)
174+
175+
return seq_list, seq_list, mapped_cl
165176

166177

167178
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
168179
# Automatic CCL generation: If both ccl_prefill and ccl_decode are None,
169180
# generate optimized context length lists for prefill and decode based on ctx_len
170181
if ccl_prefill is None and ccl_decode is None:
171-
ccl_prefill, ccl_decode, ctx_len = Automatic_CCL_Generation(ctx_len, prefill_seq_len, ccl_prefill, ccl_decode)
182+
ccl_prefill, ccl_decode, ctx_len = automatic_ccl_generation(ctx_len, prefill_seq_len, ccl_prefill, ccl_decode)
172183
else:
173184
if prefill_seq_len == 1:
174185
if ccl_prefill is not None and ccl_decode is not None:

0 commit comments

Comments
 (0)