diff --git a/src/kirin/dialects/math/__init__.py b/src/kirin/dialects/math/__init__.py index 5facdaac3..c67856ce2 100644 --- a/src/kirin/dialects/math/__init__.py +++ b/src/kirin/dialects/math/__init__.py @@ -104,6 +104,10 @@ def isnan(x: float) -> bool: ... def lgamma(x: float) -> float: ... +@lowering.wraps(stmts.log) +def log(x: float, base: float) -> float: ... + + @lowering.wraps(stmts.log10) def log10(x: float) -> float: ... diff --git a/src/kirin/dialects/math/_gen.py b/src/kirin/dialects/math/_gen.py index 63be0b4d9..78d637ba4 100644 --- a/src/kirin/dialects/math/_gen.py +++ b/src/kirin/dialects/math/_gen.py @@ -32,6 +32,8 @@ def builtin_math_functions(): # 3.10 compat "cbrt", "exp2", + # 3.13 compat + "fma", ): continue @@ -40,12 +42,32 @@ def builtin_math_functions(): sig = inspect.signature(obj) yield name, obj, sig except: # noqa: E722 - continue + if name == "log": + sig = inspect.Signature( + parameters=[ + inspect.Parameter( + "x", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=object, + ), + inspect.Parameter( + "base", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=object, + ), + ], + return_annotation=object, + ) + yield name, obj, sig + + else: + continue with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f: f.write("# This file is generated by gen.py\n") - f.write("from kirin import ir, types, lowering2\n") + f.write("from kirin import ir, types, lowering\n") f.write("from kirin.decl import statement, info\n") f.write("from kirin.dialects.math.dialect import dialect\n") f.write("\n") @@ -58,6 +80,8 @@ def builtin_math_functions(): ) if "is" in name: ret_type = "types.Bool" + elif name in {"trunc", "ceil", "floor"}: + ret_type = "types.Int" else: ret_type = "types.Float" f.write(textwrap.dedent(f""" @@ -66,7 +90,7 @@ class {name}(ir.Statement): \"\"\"{name} statement, wrapping the math.{name} function \"\"\" name = "{name}" - traits = frozenset({{ir.Pure(), lowering2.FromPythonCall()}}) + traits = frozenset({{ir.Pure(), lowering.FromPythonCall()}}) {fields} result: ir.ResultValue = info.result({ret_type}) """)) @@ -109,7 +133,7 @@ class MathMethodTable(MethodTable): f.write("pi = pymath.pi\n") f.write("e = pymath.e\n") f.write("tau = pymath.tau\n") - f.write("from kirin import lowering2\n") + f.write("from kirin import lowering\n") for name, obj, sig in builtin_math_functions(): if "is" in name: @@ -119,8 +143,8 @@ class MathMethodTable(MethodTable): else: ret_type = "float" f.write(textwrap.dedent(f""" - @lowering2.wraps(stmts.{name}) - def {name}({", ".join(f"{arg}: {ret_type}" for arg in sig.parameters.keys())}) -> {ret_type}: ... + @lowering.wraps(stmts.{name}) + def {name}({", ".join(f"{arg}: float" for arg in sig.parameters.keys())}) -> {ret_type}: ... """)) f.write("\n") diff --git a/src/kirin/dialects/math/interp.py b/src/kirin/dialects/math/interp.py index f99d713a1..2ca9409c3 100644 --- a/src/kirin/dialects/math/interp.py +++ b/src/kirin/dialects/math/interp.py @@ -124,6 +124,11 @@ def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma): values = frame.get_values(stmt.args) return (math.lgamma(values[0]),) + @impl(stmts.log) + def log(self, interp, frame: Frame, stmt: stmts.log): + values = frame.get_values(stmt.args) + return (math.log(values[0], values[1]),) + @impl(stmts.log10) def log10(self, interp, frame: Frame, stmt: stmts.log10): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/math/stmts.py b/src/kirin/dialects/math/stmts.py index dc0d22f3c..0a165b0d2 100644 --- a/src/kirin/dialects/math/stmts.py +++ b/src/kirin/dialects/math/stmts.py @@ -237,6 +237,17 @@ class lgamma(ir.Statement): result: ir.ResultValue = info.result(types.Float) +@statement(dialect=dialect) +class log(ir.Statement): + """log statement, wrapping the math.log function""" + + name = "log" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + x: ir.SSAValue = info.argument(types.Float) + base: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(types.Float) + + @statement(dialect=dialect) class log10(ir.Statement): """log10 statement, wrapping the math.log10 function""" diff --git a/test/dialects/math/test_basic.py b/test/dialects/math/test_basic.py index 0a445b215..376848a53 100644 --- a/test/dialects/math/test_basic.py +++ b/test/dialects/math/test_basic.py @@ -236,6 +236,16 @@ def test_lgamma(): assert (lgamma_func(0.42) - truth) < 1e-6 +@basic +def log_func(x, base): + return math.log(x, base) + + +def test_log(): + truth = pymath.log(0.42, 0.42) + assert (log_func(0.42, 0.42) - truth) < 1e-6 + + @basic def log10_func(x): return math.log10(x)