55#
66# -----------------------------------------------------------------------------
77
8- from typing import List , Optional , Tuple
8+ from typing import List , Optional , Set , Tuple
99
1010
1111def next_multiple_of_1024 (n : int ) -> int :
@@ -23,57 +23,59 @@ def floor_to_1000(n: int) -> int:
2323
2424
2525def is_power_of_two (n : int ) -> bool :
26- """Return True if n is a power of two (n> 0 and n&(n-1)== 0)."""
26+ """Return True if n is a power of two (n > 0 and n & (n - 1) == 0)."""
2727 return n > 0 and (n & (n - 1 )) == 0
2828
2929
30- def build_doubling_sequence ( start : int , limit : int , max_elements : int , force_last : Optional [ int ] = None ) -> List [ int ] :
30+ def band_index_from_mapped_cl ( mapped_cl : int ) -> int :
3131 """
32- Build an increasing sequence starting at 'start', doubling each step,
33- not exceeding 'limit', with total length <= max_elements.
34- If 'force_last' is provided, ensure the last element equals 'force_last'
35- (replacing/appending as needed), even if it exceeds 'limit' .
32+ Compute band index ∈ {0,1,2} from mapped_cl using bit arithmetic.
33+
34+ Bands (upper bounds): 2^15=32768 → idx=0, 2^16=65536 → idx=1, 2^17=131072 → idx=2.
35+ For mapped_cl > 131072, clamp to idx=2 .
3636 """
37- if max_elements <= 0 :
38- return []
37+ # ceil(log2(mapped_cl)) == bit_length(mapped_cl - 1)
38+ ceil_log2 = (mapped_cl - 1 ).bit_length ()
39+ # map to {0,1,2} by subtracting 15 (the exponent for 32768) and clamping
40+ idx = max (0 , min (2 , ceil_log2 - 15 ))
41+ return idx
3942
40- # If start is already beyond limit, return [force_last or limit] as a single element.
41- if start > limit :
42- seq = [force_last if force_last is not None else limit ]
43- return seq [:max_elements ]
44-
45- seq : List [int ] = []
46- val = start
47-
48- while val <= limit and len (seq ) < max_elements :
49- seq .append (val )
50- next_val = val * 2
51- if next_val > limit or len (seq ) >= max_elements :
52- break
53- val = next_val
54-
55- # Add/replace last element if a 'force_last' is requested
56- if force_last is not None :
57- if len (seq ) == 0 :
58- seq = [force_last ]
59- elif seq [- 1 ] != force_last :
60- if len (seq ) < max_elements :
61- seq .append (force_last )
62- else :
63- seq [- 1 ] = force_last
6443
65- # Deduplicate while preserving order
66- dedup = []
67- seen = set ()
68- for x in seq :
69- if x not in seen :
70- dedup .append (x )
71- seen .add (x )
72- return dedup [:max_elements ]
44+ def build_doubling_set (start : int , limit : int , max_elements : int ) -> Set [int ]:
45+ """
46+ Build a STRICT doubling set: {start, start*2, start*4, ...} up to 'limit',
47+ collecting at most 'max_elements' values. Returns a set; caller will sort.
48+ """
49+ values : Set [int ] = set ()
50+ if max_elements <= 0 or start <= 0 or limit <= 0 :
51+ return values
52+
53+ v = start
54+ while v <= limit and len (values ) < max_elements :
55+ values .add (v )
56+ v *= 2
57+ return values
58+
59+
60+ def ensure_last (sorted_seq : List [int ], last_value : int , max_elements : int ) -> List [int ]:
61+ """
62+ Ensure the last element equals 'last_value' by appending or replacing the final element,
63+ keeping length <= max_elements. If the sequence is empty, return [last_value].
64+ """
65+ if max_elements <= 0 :
66+ return []
67+ if not sorted_seq :
68+ return [last_value ][:max_elements ]
69+ if sorted_seq [- 1 ] != last_value :
70+ if len (sorted_seq ) < max_elements :
71+ sorted_seq .append (last_value )
72+ else :
73+ sorted_seq [- 1 ] = last_value
74+ return sorted_seq [:max_elements ]
7375
7476
75- def Automatic_CCL_Generation (
76- CL : int ,
77+ def automatic_ccl_generation (
78+ ctx_len : int ,
7779 prefill_seq_len : int ,
7880 comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
7981 comp_ctx_lengths_decode : Optional [List [int ]] = None ,
@@ -82,93 +84,102 @@ def Automatic_CCL_Generation(
8284 Automatic Compute-Context-Length Lists Generation
8385
8486 Purpose:
85- Compute decode and prefill ccl lists based on an input context
86- length (CL), prefill sequence length, and optional pre-specified lists.
87+ Compute decode and prefill CCL lists based on an input context length (CL),
88+ prefill sequence length, and optional pre-specified lists.
89+
90+ High-level rules (unchanged from the previous implementation):
91+ - prefill_seq_len > 1:
92+ * If either list is provided, pass them through unchanged.
93+ * decode: doubles from tiered start; MUST end at mapped_CL (last forced to mapped_CL).
94+ * prefill:
95+ • If CL is power of two: STRICT doubling from tiered start, bounded by CL (no forced non-doubling last).
96+ • Else: doubles from tiered start, bounded by CL, and last element = floor_to_1000(mapped_CL).
97+ * Max 5 elements per list.
98+ - prefill_seq_len == 1:
99+ * decode and prefill are IDENTICAL.
100+ * start at 4096, double up to 10 elements.
101+ * upper grid cap computed dynamically (start * 2^(max_elements-1)); last = mapped_CL.
102+ * If mapped_CL < 4096, both lists are [mapped_CL].
87103 """
88-
89- if CL <= 0 :
90- mapped_CL = next_multiple_of_1024 (max (CL , 1 ))
91- # For non-positive CL, minimal identical sequences
92- seq = [mapped_CL ]
93- return seq , seq , mapped_CL
94-
95- mapped_CL = next_multiple_of_1024 (CL )
96-
97- # Tiered starts
98- if mapped_CL <= 4096 :
99- seq = [mapped_CL ]
100- return seq , seq , mapped_CL
101- elif mapped_CL <= 32768 :
102- decode_start , prefill_start = 4096 , 4000
103- elif mapped_CL <= 65536 :
104- decode_start , prefill_start = 8192 , 8000
105- elif mapped_CL <= 131072 :
106- decode_start , prefill_start = 16384 , 16000
107- else :
108- decode_start , prefill_start = 16384 , 16000
109-
110- # If prefill_seq_len > 1:
104+ # Handle non-positive CL
105+ if ctx_len <= 0 :
106+ mapped_cl = next_multiple_of_1024 (1 )
107+ seq = [mapped_cl ]
108+ return seq , seq , mapped_cl
109+
110+ mapped_cl = next_multiple_of_1024 (ctx_len )
111+
112+ # Early small-ctx_len case for identical lists
113+ if mapped_cl <= 4096 :
114+ seq = [mapped_cl ]
115+ return seq , seq , mapped_cl
116+
117+ # Compute tier starts via band index (no hard-coded chain)
118+ idx = band_index_from_mapped_cl (mapped_cl )
119+ decode_start = 4096 << idx # 4096, 8192, 16384
120+ PREFILL_STARTS = {0 : 4000 , 1 : 8000 , 2 : 16000 }
121+ prefill_start = PREFILL_STARTS [idx ]
122+
123+ # Branch: prefill_seq_len > 1
111124 if prefill_seq_len > 1 :
112125 # Passthrough if either provided
113126 if comp_ctx_lengths_decode is not None or comp_ctx_lengths_prefill is not None :
114127 return (
115- comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
116128 comp_ctx_lengths_prefill if comp_ctx_lengths_prefill is not None else [],
117- mapped_CL ,
129+ comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
130+ mapped_cl ,
118131 )
119132
133+ # Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 5.
120134 max_elems = 5
121135
122- # Decode: ensure last = mapped_CL
123- decode = build_doubling_sequence (
124- start = decode_start ,
125- limit = mapped_CL ,
126- max_elements = max_elems ,
127- force_last = mapped_CL ,
128- )
129-
130- # Prefill:
131- if is_power_of_two (CL ):
132- # Strict doubling, limit = CL, no forced non-doubling last
133- prefill = build_doubling_sequence (
134- start = prefill_start ,
135- limit = CL ,
136- max_elements = max_elems ,
137- force_last = None ,
138- )
136+ # ---- Decode: strict doubling up to mapped_cl, then enforce last = mapped_cl
137+ decode_set = build_doubling_set (start = decode_start , limit = mapped_cl , max_elements = max_elems )
138+ decode_list = sorted (decode_set )
139+ decode_list = ensure_last (decode_list , last_value = mapped_cl , max_elements = max_elems )
140+
141+ # ---- Prefill:
142+ if is_power_of_two (ctx_len ):
143+ # STRICT doubling only, bounded by ctx_len; do NOT force a non-doubling last
144+ prefill_set = build_doubling_set (start = prefill_start , limit = ctx_len , max_elements = max_elems )
145+ prefill_list = sorted (prefill_set )[:max_elems ]
139146 else :
140- prefill_last = floor_to_1000 (mapped_CL )
141- prefill = build_doubling_sequence (
142- start = prefill_start ,
143- limit = CL ,
144- max_elements = max_elems ,
145- force_last = prefill_last ,
146- )
147+ # Doubles bounded by ctx_len, but last must equal floor_to_1000(mapped_cl)
148+ prefill_last = floor_to_1000 (mapped_cl )
149+ prefill_set = build_doubling_set (start = prefill_start , limit = ctx_len , max_elements = max_elems )
150+ prefill_list = sorted (prefill_set )
151+ prefill_list = ensure_last (prefill_list , last_value = prefill_last , max_elements = max_elems )
147152
148- return prefill , decode , mapped_CL
153+ # NOTE: the return order is (prefill list, decode list, mapped_cl) — prefill first, then decode.
154+ return prefill_list , decode_list , mapped_cl
149155
150- # UPDATED : prefill_seq_len == 1 → identical lists
156+ # Branch : prefill_seq_len == 1 → identical lists
151157 else :
158+ # When prefill_seq_len == 1 (as in MoE models), prefill and decode can use the same specializations, so the CCL lists can be twice as long.
159+ # Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 10.
152160 max_elems = 10
153- grid_cap = 2097152 # upper cap for doubling grid
161+ start_identical = 4096
154162
155- if mapped_CL < 4096 :
156- seq = [mapped_CL ]
157- else :
158- seq = build_doubling_sequence (
159- start = 4096 ,
160- limit = min (mapped_CL , grid_cap ),
161- max_elements = max_elems ,
162- force_last = mapped_CL , # identical lists end at mapped_CL
163- )
164- return seq , seq , mapped_CL
163+ if mapped_cl < start_identical :
164+ seq = [mapped_cl ]
165+ return seq , seq , mapped_cl
166+
167+ # Dynamic grid cap: start * 2^(max_elems - 1)
168+ grid_cap = start_identical * (1 << (max_elems - 1 ))
169+ limit = min (mapped_cl , grid_cap )
170+
171+ seq_set = build_doubling_set (start = start_identical , limit = limit , max_elements = max_elems )
172+ seq_list = sorted (seq_set )
173+ seq_list = ensure_last (seq_list , last_value = mapped_cl , max_elements = max_elems )
174+
175+ return seq_list , seq_list , mapped_cl
165176
166177
167178def process_ccl_specializations (ccl_prefill , ccl_decode , ctx_len , prefill_seq_len ):
168179 # Automatic CCL generation: If both ccl_prefill and ccl_decode are None,
169180 # generate optimized context length lists for prefill and decode based on ctx_len
170181 if ccl_prefill is None and ccl_decode is None :
171- ccl_prefill , ccl_decode , ctx_len = Automatic_CCL_Generation (ctx_len , prefill_seq_len , ccl_prefill , ccl_decode )
182+ ccl_prefill , ccl_decode , ctx_len = automatic_ccl_generation (ctx_len , prefill_seq_len , ccl_prefill , ccl_decode )
172183 else :
173184 if prefill_seq_len == 1 :
174185 if ccl_prefill is not None and ccl_decode is not None :
0 commit comments