@@ -73,13 +73,13 @@ def backend_module(*args):
7373 if backend == "fa2" :
7474 _kernels = torch .ops .flashinfer_kernels
7575
76- run_func = _kernels .single_prefill_with_kv_cache
76+ run_func = _kernels .single_prefill_with_kv_cache . default
7777 else :
7878 _kernels_sm90 = torch .ops .flashinfer_kernels_sm90
7979
80- run_func = _kernels_sm90 .single_prefill_with_kv_cache_sm90
80+ run_func = _kernels_sm90 .single_prefill_with_kv_cache_sm90 . default
8181 else :
82- run_func = gen_single_prefill_module (backend , * args ).run
82+ run_func = gen_single_prefill_module (backend , * args ).run . default
8383
8484 # torch library for single_prefill_with_kv_cache
8585
@@ -180,24 +180,30 @@ def backend_module(*args):
180180 if backend == "fa2" :
181181 _kernels = torch .ops .flashinfer_kernels
182182
183- plan_func = _kernels .batch_prefill_with_kv_cache_plan
184- ragged_run_func = _kernels .batch_prefill_with_ragged_kv_cache_run
185- paged_run_func = _kernels .batch_prefill_with_paged_kv_cache_run
183+ plan_func = _kernels .batch_prefill_with_kv_cache_plan .default
184+ ragged_run_func = (
185+ _kernels .batch_prefill_with_ragged_kv_cache_run .default
186+ )
187+ paged_run_func = (
188+ _kernels .batch_prefill_with_paged_kv_cache_run .default
189+ )
186190 else :
187191 _kernels_sm90 = torch .ops .flashinfer_kernels_sm90
188192
189- plan_func = _kernels_sm90 .batch_prefill_with_kv_cache_sm90_plan
193+ plan_func = (
194+ _kernels_sm90 .batch_prefill_with_kv_cache_sm90_plan .default
195+ )
190196 ragged_run_func = (
191- _kernels_sm90 .batch_prefill_with_ragged_kv_cache_sm90_run
197+ _kernels_sm90 .batch_prefill_with_ragged_kv_cache_sm90_run . default
192198 )
193199 paged_run_func = (
194- _kernels_sm90 .batch_prefill_with_paged_kv_cache_sm90_run
200+ _kernels_sm90 .batch_prefill_with_paged_kv_cache_sm90_run . default
195201 )
196202 else :
197203 module = gen_batch_prefill_module (backend , * args )
198- plan_func = module .plan
199- ragged_run_func = module .ragged_run
200- paged_run_func = module .paged_run
204+ plan_func = module .plan . default
205+ ragged_run_func = module .ragged_run . default
206+ paged_run_func = module .paged_run . default
201207
202208 # torch library for ragged_run
203209
@@ -437,9 +443,9 @@ def get_batch_prefill_jit_module(module_name: str, jit_module: Any):
437443 if module_name in _batch_prefill_jit_modules :
438444 return _batch_prefill_jit_modules [module_name ]
439445
440- plan_func = jit_module .plan
441- ragged_run_func = jit_module .ragged_run
442- paged_run_func = jit_module .paged_run
446+ plan_func = jit_module .plan . default
447+ ragged_run_func = jit_module .ragged_run . default
448+ paged_run_func = jit_module .paged_run . default
443449
444450 # torch library for ragged_run
445451 @register_custom_op (
@@ -611,7 +617,7 @@ def single_prefill_with_kv_cache_with_jit_module(
611617 lse = torch .empty (
612618 (q .size (0 ), q .size (1 )), dtype = torch .float32 , device = device
613619 )
614- jit_module .run (
620+ jit_module .run . default (
615621 q ,
616622 k ,
617623 v ,
0 commit comments