|
5 | 5 | # |
6 | 6 | # ----------------------------------------------------------------------------- |
7 | 7 |
|
| 8 | +from typing import List, Optional, Tuple |
| 9 | + |
| 10 | + |
def next_multiple_of_1024(n: int) -> int:
    """Round ``n`` up to a multiple of 1024.

    Non-positive inputs map to 0; values that are already a multiple of
    1024 are returned unchanged.
    """
    step = 1024
    if n > 0:
        # Adding (step - 1) before the floor division turns it into a ceiling.
        return (n + (step - 1)) // step * step
    return 0
| 16 | + |
| 17 | + |
def floor_to_1000(n: int) -> int:
    """Round ``n`` down to a multiple of 1000.

    Non-positive inputs map to 0.
    """
    if n > 0:
        thousands = n // 1000
        return thousands * 1000
    return 0
| 23 | + |
| 24 | + |
def is_power_of_two(n: int) -> bool:
    """Tell whether ``n`` is a positive integer with exactly one bit set."""
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only for powers of two.
    return not n & (n - 1)
| 28 | + |
| 29 | + |
def build_doubling_sequence(start: int, limit: int, max_elements: int, force_last: Optional[int] = None) -> List[int]:
    """
    Produce ``start, 2*start, 4*start, ...`` while the value stays <= ``limit``
    and the list holds fewer than ``max_elements`` entries.

    When ``force_last`` is given, the result is made to end with it: it is
    appended if there is room, otherwise it overwrites the final entry. The
    forced value is not checked against ``limit``.

    Returns an empty list when ``max_elements`` is not positive. When
    ``start`` is already above ``limit`` the result is the single element
    ``force_last`` (or ``limit`` when no forced value is given).
    """
    if max_elements <= 0:
        return []

    if start > limit:
        only = limit if force_last is None else force_last
        return [only][:max_elements]

    values: List[int] = []
    current = start
    while current <= limit and len(values) < max_elements:
        values.append(current)
        current *= 2

    # Make sure the tail is the forced value, without growing past max_elements.
    if force_last is not None and (not values or values[-1] != force_last):
        if values and len(values) >= max_elements:
            values[-1] = force_last
        else:
            values.append(force_last)

    # dict keys keep insertion order, so this drops repeats but keeps ordering.
    unique = list(dict.fromkeys(values))
    return unique[:max_elements]
| 73 | + |
| 74 | + |
def Automatic_CCL_Generation(
    CL: int,
    prefill_seq_len: int,
    comp_ctx_lengths_prefill: Optional[List[int]] = None,
    comp_ctx_lengths_decode: Optional[List[int]] = None,
) -> Tuple[List[int], List[int], int]:
    """
    Automatic Compute-Context-Length Lists Generation

    Purpose:
        Compute prefill and decode ccl lists based on an input context
        length (CL), prefill sequence length, and optional pre-specified lists.

    Args:
        CL: Requested context length. It is rounded up to a multiple of 1024
            (``mapped_CL``); non-positive values are treated as 1.
        prefill_seq_len: Prefill sequence length. Values > 1 produce separate
            prefill/decode lists; otherwise both lists are identical.
        comp_ctx_lengths_prefill: Optional user-supplied prefill list.
        comp_ctx_lengths_decode: Optional user-supplied decode list.
            The supplied lists are only consulted when ``prefill_seq_len > 1``
            and ``mapped_CL > 4096``; in that case they are passed through
            (a missing one becomes ``[]``).

    Returns:
        ``(prefill_list, decode_list, mapped_CL)`` - always in this order.
        In the "identical lists" cases both entries are the same list object.
    """
    if CL <= 0:
        # For non-positive CL, minimal identical sequences
        mapped_CL = next_multiple_of_1024(max(CL, 1))
        seq = [mapped_CL]
        return seq, seq, mapped_CL

    mapped_CL = next_multiple_of_1024(CL)

    # Small contexts need a single specialization shared by both phases.
    if mapped_CL <= 4096:
        seq = [mapped_CL]
        return seq, seq, mapped_CL

    if prefill_seq_len > 1:
        # Passthrough if either list was provided. The order matches every
        # other return of this function: prefill first, then decode.
        if comp_ctx_lengths_decode is not None or comp_ctx_lengths_prefill is not None:
            return (
                comp_ctx_lengths_prefill if comp_ctx_lengths_prefill is not None else [],
                comp_ctx_lengths_decode if comp_ctx_lengths_decode is not None else [],
                mapped_CL,
            )

        # Tiered starts: larger contexts begin their doubling grid higher.
        if mapped_CL <= 32768:
            decode_start, prefill_start = 4096, 4000
        elif mapped_CL <= 65536:
            decode_start, prefill_start = 8192, 8000
        else:
            decode_start, prefill_start = 16384, 16000

        max_elems = 5

        # Decode: ensure last = mapped_CL
        decode = build_doubling_sequence(
            start=decode_start,
            limit=mapped_CL,
            max_elements=max_elems,
            force_last=mapped_CL,
        )

        # Prefill: strict doubling up to CL when CL is a power of two;
        # otherwise the list is forced to end at mapped_CL floored to 1000.
        prefill_last = None if is_power_of_two(CL) else floor_to_1000(mapped_CL)
        prefill = build_doubling_sequence(
            start=prefill_start,
            limit=CL,
            max_elements=max_elems,
            force_last=prefill_last,
        )

        return prefill, decode, mapped_CL

    # prefill_seq_len <= 1 -> identical lists for prefill and decode.
    max_elems = 10
    grid_cap = 2097152  # upper cap for doubling grid

    # mapped_CL > 4096 is guaranteed here by the early return above.
    seq = build_doubling_sequence(
        start=4096,
        limit=min(mapped_CL, grid_cap),
        max_elements=max_elems,
        force_last=mapped_CL,  # identical lists end at mapped_CL
    )
    return seq, seq, mapped_CL
| 165 | + |
8 | 166 |
|
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
    """Normalize, or automatically generate, the prefill/decode CCL lists.

    Behavior:
        * Both lists ``None``: the lists (and a possibly adjusted ``ctx_len``)
          come from ``Automatic_CCL_Generation``.
        * Otherwise every supplied list is capped to ``ctx_len``, de-duplicated
          and sorted. A list that was not supplied stays ``None``.
        * Both lists supplied and ``prefill_seq_len == 1``: both outputs are the
          same sorted union of the two capped lists.
        * Both lists supplied and ``prefill_seq_len != 1``: prefill values that
          collide with a decode value (or an already placed prefill value) are
          decremented until free; values that would drop below 0 are discarded.

    Args:
        ccl_prefill: Optional list of prefill compute-context-lengths.
        ccl_decode: Optional list of decode compute-context-lengths.
        ctx_len: Maximum context length; must not be ``None``.
        prefill_seq_len: Prefill sequence length.

    Returns:
        ``(ccl_prefill, ccl_decode, ctx_len)``.

    Raises:
        TypeError: If ``ctx_len`` is ``None``.
    """
    if ctx_len is None:
        raise TypeError("`ctx_len` is required when loading the model with CCL.")

    if ccl_prefill is None and ccl_decode is None:
        # Automatic CCL generation: derive both lists from ctx_len.
        ccl_prefill, ccl_decode, ctx_len = Automatic_CCL_Generation(ctx_len, prefill_seq_len, ccl_prefill, ccl_decode)
    else:
        # Cap to ctx_len *before* de-duplicating so that several values above
        # ctx_len collapse into a single entry instead of repeating.
        if ccl_prefill is not None:
            ccl_prefill = sorted({min(x, ctx_len) for x in ccl_prefill})
        if ccl_decode is not None:
            ccl_decode = sorted({min(x, ctx_len) for x in ccl_decode})

        if ccl_prefill is not None and ccl_decode is not None:
            if prefill_seq_len == 1:
                # Prefill and decode can share the same specializations since
                # prefill_seq_len=1, so both use the sorted union.
                ccl_union_all = sorted(set(ccl_prefill) | set(ccl_decode))
                ccl_prefill = ccl_union_all
                ccl_decode = ccl_union_all
            else:
                # Ensure no overlap between ccl_prefill and ccl_decode.
                taken = set(ccl_decode)
                disjoint_prefill = []
                for val in ccl_prefill:
                    while val >= 0 and val in taken:
                        val -= 1
                    if val >= 0:  # Prevent negative values
                        taken.add(val)
                        disjoint_prefill.append(val)
                ccl_prefill = sorted(disjoint_prefill)

    print("CCL Configuration:")
    print(f" - Prefill context lengths: {ccl_prefill}")
    print(f" - Decode context lengths: {ccl_decode}")
    print(f" - Max context length: {ctx_len}")
    return ccl_prefill, ccl_decode, ctx_len
0 commit comments