@@ -937,11 +937,13 @@ def __init__(
937937 self .model = model
938938 self .config = model .config
939939
940- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
941-
942940 self .vision_model = QEffVisionEncoderForTextImageToTextModel (model , ** kwargs )
943941 self .lang_model = QEffCausalLMForTextImageToTextModel (model , qaic_config = qaic_config , ** kwargs )
944942 self .continuous_batching = continuous_batching
943+ self .ccl_enabled = False
944+ if qaic_config :
945+ self .ccl_enabled = qaic_config .get ("ccl_enabled" , False )
946+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
945947 self .input_shapes , self .output_names = None , None
946948
947949 @property
@@ -985,6 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
985987 logger .warning ("Updating low_cpu_mem_usage=False" )
986988
987989 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
990+
988991 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
989992 return cls (
990993 model ,
@@ -1095,6 +1098,8 @@ def compile(
10951098 compile_dir : Optional [str ] = None ,
10961099 * ,
10971100 prefill_seq_len : Optional [int ] = None ,
1101+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1102+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
10981103 ctx_len : Optional [int ] = None ,
10991104 batch_size : int = 1 ,
11001105 full_batch_size : Optional [int ] = None ,
@@ -1179,10 +1184,21 @@ def compile(
11791184
11801185 output_names = self .model .get_output_names (kv_offload = True )
11811186
1187+ # if ccl_enabled is True read Compute-Context-Length lists
1188+ if self .ccl_enabled :
1189+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
1190+ logger .warning (
1191+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1192+ )
1193+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1194+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1195+ )
1196+
11821197 # For supporting VLLM and Disaggregated with CCL
1183- if "comp_ctx_lengths_prefill" in compiler_options :
1184- self .comp_ctx_lengths_prefill = compiler_options .pop ("comp_ctx_lengths_prefill" )
1185- self .comp_ctx_lengths_decode = compiler_options .pop ("comp_ctx_lengths_decode" )
1198+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
1199+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1200+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1201+ )
11861202
11871203 specializations , compiler_options = self .model .get_specializations (
11881204 batch_size = batch_size ,
@@ -1634,7 +1650,6 @@ def __init__(
16341650 raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
16351651 super ().__init__ (model , ** kwargs )
16361652
1637- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
16381653 self .model .qaic_config = qaic_config
16391654
16401655 # to handle internvl models
@@ -1648,6 +1663,10 @@ def __init__(
16481663 else :
16491664 self .model .config .use_cache = True
16501665 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
1666+ self .ccl_enabled = False
1667+ if qaic_config :
1668+ self .ccl_enabled = qaic_config .get ("ccl_enabled" , False )
1669+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
16511670
16521671 if self .model .qaic_config is not None and self .model .qaic_config .get ("num_kv_blocks" , None ) is not None :
16531672 BlockedKVAttentionTransform .apply (self .model , num_kv_blocks = self .model .qaic_config .get ("num_kv_blocks" ))
@@ -1687,6 +1706,7 @@ def from_pretrained(
16871706 logger .warning ("Updating low_cpu_mem_usage=False" )
16881707
16891708 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
1709+
16901710 from transformers import AutoConfig
16911711
16921712 config = AutoConfig .from_pretrained (pretrained_model_name_or_path , trust_remote_code = True )
@@ -1741,6 +1761,8 @@ def compile(
17411761 * ,
17421762 prefill_seq_len : Optional [int ] = None ,
17431763 ctx_len : Optional [int ] = None ,
1764+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1765+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
17441766 batch_size : int = 1 ,
17451767 full_batch_size : Optional [int ] = None ,
17461768 kv_cache_batch_size : Optional [int ] = None ,
@@ -1810,10 +1832,21 @@ def compile(
18101832 kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
18111833 output_names = self .model .get_output_names ()
18121834
1835+ # if ccl_enabled is True read Compute-Context-Length lists
1836+ if self .ccl_enabled :
1837+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
1838+ logger .warning (
1839+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1840+ )
1841+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1842+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1843+ )
1844+
18131845 # For supporting VLLM and Disaggregated with CCL
1814- if "comp_ctx_lengths_prefill" in compiler_options :
1815- self .comp_ctx_lengths_prefill = compiler_options .pop ("comp_ctx_lengths_prefill" )
1816- self .comp_ctx_lengths_decode = compiler_options .pop ("comp_ctx_lengths_decode" )
1846+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
1847+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1848+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1849+ )
18171850
18181851 # Get specializations from modelling file
18191852 # TODO: expose this via the auto class as well
@@ -2378,8 +2411,6 @@ def __init__(
23782411 # Set use_cache=True to get KV values as output during ONNX export
23792412 model .config .use_cache = True
23802413
2381- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
2382-
23832414 super ().__init__ (model , qaic_config = qaic_config , ** kwargs )
23842415 self .num_layers = model .config .num_hidden_layers
23852416 self .continuous_batching = continuous_batching
@@ -2388,6 +2419,10 @@ def __init__(
23882419 self .is_tlm = transformed
23892420
23902421 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
2422+ self .ccl_enabled = False
2423+ if qaic_config :
2424+ self .ccl_enabled = qaic_config .get ("ccl_enabled" , False )
2425+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
23912426
23922427 # ---Sampling---
23932428 # Note: SamplerTransform should be applied after all other transforms
@@ -2833,6 +2868,8 @@ def compile(
28332868 * ,
28342869 prefill_seq_len : int = 32 ,
28352870 ctx_len : int = 128 ,
2871+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
2872+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
28362873 batch_size : int = 1 ,
28372874 full_batch_size : Optional [int ] = None ,
28382875 kv_cache_batch_size : Optional [int ] = None ,
@@ -2924,10 +2961,18 @@ def compile(
29242961
29252962 """
29262963
2964+ # if ccl_enabled is True read Compute-Context-Length lists
2965+ if self .ccl_enabled :
2966+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
2967+ logger .warning (
2968+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
2969+ )
2970+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
2971+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
2972+ )
2973+
29272974 # For supporting VLLM and Disaggregated with CCL
2928- if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options :
2929- comp_ctx_lengths_prefill = compiler_options .pop ("comp_ctx_lengths_prefill" )
2930- comp_ctx_lengths_decode = compiler_options .pop ("comp_ctx_lengths_decode" )
2975+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
29312976 if isinstance (comp_ctx_lengths_prefill , str ):
29322977 import ast
29332978
@@ -2942,6 +2987,9 @@ def compile(
29422987 self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
29432988 self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
29442989
2990+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
2991+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len , prefill_seq_len
2992+ )
29452993 # --- Validation ---
29462994 if prefill_only is not None and not isinstance (prefill_only , bool ):
29472995 raise TypeError ("`prefill_only` must be a boolean." )
0 commit comments