Skip to content

Commit eb4f52b

Browse files
committed
Support ir.Expr/ir.Var openmp strings by extending constant inference
1 parent 256fb29 commit eb4f52b

1 file changed

Lines changed: 169 additions & 3 deletions

File tree

src/numba/openmp/omp_lower.py

Lines changed: 169 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3084,6 +3084,159 @@ def remove_ssa_from_func_ir(func_ir):
30843084
func_ir._definitions = build_definitions(func_ir.blocks)
30853085

30863086

3087+
class _ExtendedConstantInference:
3088+
"""
3089+
Extended ConstantInference that supports binop for string concatenation.
3090+
"""
3091+
3092+
def __init__(self, func_ir):
3093+
from numba.core.consts import ConstantInference
3094+
3095+
self._base_inference = ConstantInference(func_ir)
3096+
self._func_ir = func_ir
3097+
3098+
def infer_constant(self, name, loc=None):
3099+
"""Infer a constant value, delegating to base inference first."""
3100+
from numba.core.errors import ConstantInferenceError
3101+
3102+
try:
3103+
return self._base_inference.infer_constant(name, loc=loc)
3104+
except ConstantInferenceError:
3105+
# If base inference fails, check if the variable's definition is
3106+
# an expression we can handle (like binop or call)
3107+
try:
3108+
defn = self._func_ir.get_definition(name)
3109+
if isinstance(defn, ir.Expr):
3110+
if defn.op == "binop":
3111+
return self.infer_expr(defn, loc=loc)
3112+
elif defn.op == "call":
3113+
return self.infer_expr(defn, loc=loc)
3114+
except (KeyError, AttributeError):
3115+
pass
3116+
raise
3117+
3118+
def _infer_value(self, val, loc=None):
3119+
"""
3120+
Infer a constant from a value which might be a variable name or an expression.
3121+
"""
3122+
from numba.core.errors import ConstantInferenceError
3123+
3124+
if isinstance(val, ir.Var):
3125+
return self.infer_constant(val.name, loc=val.loc)
3126+
elif isinstance(val, ir.Expr):
3127+
return self.infer_expr(val, loc=loc)
3128+
elif isinstance(val, str):
3129+
# Direct variable name
3130+
return self.infer_constant(val, loc=loc)
3131+
else:
3132+
raise ConstantInferenceError(f"Cannot infer value for {val}", loc=loc)
3133+
3134+
def infer_expr(self, expr, loc=None):
3135+
"""
3136+
Infer an expression, with added support for binop (string concatenation)
3137+
and format_value() calls.
3138+
"""
3139+
from numba.core.errors import ConstantInferenceError
3140+
3141+
if expr.op == "binop":
3142+
# Support binary operations for string concatenation
3143+
try:
3144+
lhs = self._infer_value(expr.lhs, loc=expr.loc)
3145+
rhs = self._infer_value(expr.rhs, loc=expr.loc)
3146+
# String concatenation
3147+
if isinstance(lhs, str) and isinstance(rhs, str):
3148+
return lhs + rhs
3149+
except ConstantInferenceError:
3150+
raise
3151+
# If it's not string concatenation
3152+
raise ConstantInferenceError(
3153+
f"Cannot infer binop: {lhs!r} + {rhs!r}", loc=expr.loc
3154+
)
3155+
elif expr.op == "call":
3156+
# Handle str() and format_value() calls
3157+
try:
3158+
func = expr.func
3159+
3160+
# Try to infer what function is being called
3161+
func_name = None
3162+
if isinstance(func, ir.Global):
3163+
if func.value is str:
3164+
func_name = "str"
3165+
elif isinstance(func, ir.Var):
3166+
# Try to resolve the variable to see what function it points to
3167+
try:
3168+
func_defn = self._func_ir.get_definition(func.name)
3169+
if isinstance(func_defn, ir.Expr) and func_defn.op == "global":
3170+
if func_defn.value is str:
3171+
func_name = "str"
3172+
elif (
3173+
hasattr(func_defn.value, "__name__")
3174+
and "format_value" in func_defn.value.__name__
3175+
):
3176+
func_name = "format_value"
3177+
else:
3178+
# Check if the variable name itself suggests what it is
3179+
if "format_value" in func.name:
3180+
func_name = "format_value"
3181+
except (KeyError, AttributeError):
3182+
if "format_value" in func.name:
3183+
func_name = "format_value"
3184+
3185+
# Handle str() calls
3186+
if func_name == "str" or (
3187+
isinstance(func, ir.Global) and func.value is str
3188+
):
3189+
if len(expr.args) >= 1:
3190+
arg_val = self._infer_value(expr.args[0], loc=expr.loc)
3191+
return str(arg_val)
3192+
3193+
# Handle format_value calls (used in f-strings)
3194+
if func_name == "format_value":
3195+
if len(expr.args) >= 1:
3196+
arg_val = self._infer_value(expr.args[0], loc=expr.loc)
3197+
return str(arg_val)
3198+
3199+
# If we don't recognize the function, don't try base inference
3200+
raise ConstantInferenceError(
3201+
f"Cannot infer call to unknown function: {func}", loc=expr.loc
3202+
)
3203+
3204+
except ConstantInferenceError:
3205+
raise
3206+
else:
3207+
# Delegate to base inference for other operations
3208+
return self._base_inference._infer_expr(expr)
3209+
3210+
3211+
def _try_infer_string_constant(arg, func_ir):
3212+
"""
3213+
Try to infer a constant string value from an IR node.
3214+
Uses extended ConstantInference that supports binop for string concatenation
3215+
and format_value calls used in f-strings.
3216+
3217+
Returns the string value if resolvable, None otherwise.
3218+
"""
3219+
from numba.core.errors import ConstantInferenceError
3220+
3221+
try:
3222+
inference = _ExtendedConstantInference(func_ir)
3223+
3224+
# For variables, use ConstantInference to resolve them
3225+
if isinstance(arg, ir.Var):
3226+
value = inference.infer_constant(arg.name, loc=arg.loc)
3227+
if isinstance(value, str):
3228+
return value
3229+
# For expressions, try using the extended inference
3230+
elif isinstance(arg, ir.Expr):
3231+
value = inference.infer_expr(arg)
3232+
if isinstance(value, str):
3233+
return value
3234+
except (ConstantInferenceError, AttributeError, NotImplementedError):
3235+
pass
3236+
3237+
return None
3238+
3239+
30873240
def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra):
30883241
"""Given the starting and ending block of the with-context,
30893242
replaces the head block with a new block that has the starting
@@ -3097,14 +3250,16 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
30973250
args = extra["args"]
30983251
arg = args[0]
30993252

3253+
pragma_value = None
3254+
31003255
# If OpenMP argument is not a constant or not a string then raise exception
3101-
# Accept ir.Const, ir.FreeVar, and ir.Global (closure variables)
3102-
if not isinstance(arg, (ir.Const, ir.FreeVar, ir.Global)):
3256+
# Accept ir.Const, ir.FreeVar, ir.Global, ir.Expr, and ir.Var
3257+
if not isinstance(arg, (ir.Const, ir.FreeVar, ir.Global, ir.Expr, ir.Var)):
31033258
raise NonconstantOpenmpSpecification(
31043259
f"Non-constant OpenMP specification at line {arg.loc}"
31053260
)
31063261

3107-
# Extract the actual string value from Const, FreeVar, or Global
3262+
# Extract the actual string value from various IR types
31083263
if isinstance(arg, ir.Const):
31093264
pragma_value = arg.value
31103265
if not isinstance(pragma_value, str):
@@ -3125,6 +3280,17 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
31253280
raise NonStringOpenmpSpecification(
31263281
f"Non-string OpenMP specification at line {arg.loc}"
31273282
)
3283+
else:
3284+
# Handle ir.Var and ir.Expr using ConstantInference
3285+
pragma_value = _try_infer_string_constant(arg, func_ir)
3286+
if pragma_value is None:
3287+
raise NonconstantOpenmpSpecification(
3288+
f"Cannot infer constant OpenMP specification at line {arg.loc}"
3289+
)
3290+
if not isinstance(pragma_value, str):
3291+
raise NonStringOpenmpSpecification(
3292+
f"Non-string OpenMP specification at line {arg.loc}"
3293+
)
31283294

31293295
if DEBUG_OPENMP >= 1:
31303296
print("args:", args, type(args))

0 commit comments

Comments
 (0)