@@ -3084,6 +3084,159 @@ def remove_ssa_from_func_ir(func_ir):
30843084 func_ir ._definitions = build_definitions (func_ir .blocks )
30853085
30863086
3087+ class _ExtendedConstantInference :
3088+ """
3089+ Extended ConstantInference that supports binop for string concatenation.
3090+ """
3091+
3092+ def __init__ (self , func_ir ):
3093+ from numba .core .consts import ConstantInference
3094+
3095+ self ._base_inference = ConstantInference (func_ir )
3096+ self ._func_ir = func_ir
3097+
3098+ def infer_constant (self , name , loc = None ):
3099+ """Infer a constant value, delegating to base inference first."""
3100+ from numba .core .errors import ConstantInferenceError
3101+
3102+ try :
3103+ return self ._base_inference .infer_constant (name , loc = loc )
3104+ except ConstantInferenceError :
3105+ # If base inference fails, check if the variable's definition is
3106+ # an expression we can handle (like binop or call)
3107+ try :
3108+ defn = self ._func_ir .get_definition (name )
3109+ if isinstance (defn , ir .Expr ):
3110+ if defn .op == "binop" :
3111+ return self .infer_expr (defn , loc = loc )
3112+ elif defn .op == "call" :
3113+ return self .infer_expr (defn , loc = loc )
3114+ except (KeyError , AttributeError ):
3115+ pass
3116+ raise
3117+
3118+ def _infer_value (self , val , loc = None ):
3119+ """
3120+ Infer a constant from a value which might be a variable name or an expression.
3121+ """
3122+ from numba .core .errors import ConstantInferenceError
3123+
3124+ if isinstance (val , ir .Var ):
3125+ return self .infer_constant (val .name , loc = val .loc )
3126+ elif isinstance (val , ir .Expr ):
3127+ return self .infer_expr (val , loc = loc )
3128+ elif isinstance (val , str ):
3129+ # Direct variable name
3130+ return self .infer_constant (val , loc = loc )
3131+ else :
3132+ raise ConstantInferenceError (f"Cannot infer value for { val } " , loc = loc )
3133+
3134+ def infer_expr (self , expr , loc = None ):
3135+ """
3136+ Infer an expression, with added support for binop (string concatenation)
3137+ and format_value() calls.
3138+ """
3139+ from numba .core .errors import ConstantInferenceError
3140+
3141+ if expr .op == "binop" :
3142+ # Support binary operations for string concatenation
3143+ try :
3144+ lhs = self ._infer_value (expr .lhs , loc = expr .loc )
3145+ rhs = self ._infer_value (expr .rhs , loc = expr .loc )
3146+ # String concatenation
3147+ if isinstance (lhs , str ) and isinstance (rhs , str ):
3148+ return lhs + rhs
3149+ except ConstantInferenceError :
3150+ raise
3151+ # If it's not string concatenation
3152+ raise ConstantInferenceError (
3153+ f"Cannot infer binop: { lhs !r} + { rhs !r} " , loc = expr .loc
3154+ )
3155+ elif expr .op == "call" :
3156+ # Handle str() and format_value() calls
3157+ try :
3158+ func = expr .func
3159+
3160+ # Try to infer what function is being called
3161+ func_name = None
3162+ if isinstance (func , ir .Global ):
3163+ if func .value is str :
3164+ func_name = "str"
3165+ elif isinstance (func , ir .Var ):
3166+ # Try to resolve the variable to see what function it points to
3167+ try :
3168+ func_defn = self ._func_ir .get_definition (func .name )
3169+ if isinstance (func_defn , ir .Expr ) and func_defn .op == "global" :
3170+ if func_defn .value is str :
3171+ func_name = "str"
3172+ elif (
3173+ hasattr (func_defn .value , "__name__" )
3174+ and "format_value" in func_defn .value .__name__
3175+ ):
3176+ func_name = "format_value"
3177+ else :
3178+ # Check if the variable name itself suggests what it is
3179+ if "format_value" in func .name :
3180+ func_name = "format_value"
3181+ except (KeyError , AttributeError ):
3182+ if "format_value" in func .name :
3183+ func_name = "format_value"
3184+
3185+ # Handle str() calls
3186+ if func_name == "str" or (
3187+ isinstance (func , ir .Global ) and func .value is str
3188+ ):
3189+ if len (expr .args ) >= 1 :
3190+ arg_val = self ._infer_value (expr .args [0 ], loc = expr .loc )
3191+ return str (arg_val )
3192+
3193+ # Handle format_value calls (used in f-strings)
3194+ if func_name == "format_value" :
3195+ if len (expr .args ) >= 1 :
3196+ arg_val = self ._infer_value (expr .args [0 ], loc = expr .loc )
3197+ return str (arg_val )
3198+
3199+ # If we don't recognize the function, don't try base inference
3200+ raise ConstantInferenceError (
3201+ f"Cannot infer call to unknown function: { func } " , loc = expr .loc
3202+ )
3203+
3204+ except ConstantInferenceError :
3205+ raise
3206+ else :
3207+ # Delegate to base inference for other operations
3208+ return self ._base_inference ._infer_expr (expr )
3209+
3210+
3211+ def _try_infer_string_constant (arg , func_ir ):
3212+ """
3213+ Try to infer a constant string value from an IR node.
3214+ Uses extended ConstantInference that supports binop for string concatenation
3215+ and format_value calls used in f-strings.
3216+
3217+ Returns the string value if resolvable, None otherwise.
3218+ """
3219+ from numba .core .errors import ConstantInferenceError
3220+
3221+ try :
3222+ inference = _ExtendedConstantInference (func_ir )
3223+
3224+ # For variables, use ConstantInference to resolve them
3225+ if isinstance (arg , ir .Var ):
3226+ value = inference .infer_constant (arg .name , loc = arg .loc )
3227+ if isinstance (value , str ):
3228+ return value
3229+ # For expressions, try using the extended inference
3230+ elif isinstance (arg , ir .Expr ):
3231+ value = inference .infer_expr (arg )
3232+ if isinstance (value , str ):
3233+ return value
3234+ except (ConstantInferenceError , AttributeError , NotImplementedError ):
3235+ pass
3236+
3237+ return None
3238+
3239+
30873240def _add_openmp_ir_nodes (func_ir , blocks , blk_start , blk_end , body_blocks , extra ):
30883241 """Given the starting and ending block of the with-context,
30893242 replaces the head block with a new block that has the starting
@@ -3097,14 +3250,16 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
30973250 args = extra ["args" ]
30983251 arg = args [0 ]
30993252
3253+ pragma_value = None
3254+
31003255 # If OpenMP argument is not a constant or not a string then raise exception
3101- # Accept ir.Const, ir.FreeVar, and ir.Global (closure variables)
3102- if not isinstance (arg , (ir .Const , ir .FreeVar , ir .Global )):
3256+ # Accept ir.Const, ir.FreeVar, ir.Global, ir.Expr, and ir.Var
3257+ if not isinstance (arg , (ir .Const , ir .FreeVar , ir .Global , ir . Expr , ir . Var )):
31033258 raise NonconstantOpenmpSpecification (
31043259 f"Non-constant OpenMP specification at line { arg .loc } "
31053260 )
31063261
3107- # Extract the actual string value from Const, FreeVar, or Global
3262+ # Extract the actual string value from various IR types
31083263 if isinstance (arg , ir .Const ):
31093264 pragma_value = arg .value
31103265 if not isinstance (pragma_value , str ):
@@ -3125,6 +3280,17 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
31253280 raise NonStringOpenmpSpecification (
31263281 f"Non-string OpenMP specification at line { arg .loc } "
31273282 )
3283+ else :
3284+ # Handle ir.Var and ir.Expr using ConstantInference
3285+ pragma_value = _try_infer_string_constant (arg , func_ir )
3286+ if pragma_value is None :
3287+ raise NonconstantOpenmpSpecification (
3288+ f"Cannot infer constant OpenMP specification at line { arg .loc } "
3289+ )
3290+ if not isinstance (pragma_value , str ):
3291+ raise NonStringOpenmpSpecification (
3292+ f"Non-string OpenMP specification at line { arg .loc } "
3293+ )
31283294
31293295 if DEBUG_OPENMP >= 1 :
31303296 print ("args:" , args , type (args ))
0 commit comments