diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index fd02c7a6b0..b856b49a24 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -660,6 +660,64 @@ gentbl_cc_library( ], ) +td_library( + name = "PerfifyDialectFiles", + srcs = [ + "Dialect/Perfify/Dialect.td", + "Dialect/Perfify/Ops.td", + ], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfaces", + ], +) + +gentbl_cc_library( + name = "PerfifyDialectIncGen", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + "-dialect=perfify", + ], + "Dialect/Perfify/PerfifyDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=perfify", + ], + "Dialect/Perfify/PerfifyDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Dialect/Perfify/Dialect.td", + deps = [ + ":PerfifyDialectFiles", + ], +) + +gentbl_cc_library( + name = "PerfifyOpsIncGen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "Dialect/Perfify/PerfifyOps.h.inc", + ), + ( + ["-gen-op-defs"], + "Dialect/Perfify/PerfifyOps.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Dialect/Perfify/Ops.td", + deps = [ + ":PerfifyDialectFiles", + ], +) + cc_library( name = "CheckedRewrite", hdrs = ["CheckedRewrite.h"], @@ -845,6 +903,7 @@ cc_library( "Dialect/*.cpp", "Dialect/Distributed/*.cpp", "Dialect/Tessera/*.cpp", + "Dialect/Perfify/*.cpp", ]) + [ "Utils.cpp", ], @@ -854,6 +913,7 @@ cc_library( "Dialect/*.h", "Dialect/Distributed/*.h", "Dialect/Tessera/*.h", + "Dialect/Perfify/*.h", ]) + [ "Utils.h", ], @@ -878,6 +938,8 @@ cc_library( ":EnzymeXLADialectUtils", ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", + ":PerfifyDialectIncGen", + ":PerfifyOpsIncGen", ":RaisingTransformOps", ":RaisingTransformOpsImplIncGen", ":RaisingTransformOpsIncGen", diff --git a/src/enzyme_ad/jax/Dialect/Perfify/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.cpp new file mode 100644 index 0000000000..7945eb6981 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.cpp @@ -0,0 +1,14 @@ +#include "Dialect.h" + +#include "mlir/IR/Builders.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.cpp.inc" + +// Initialize the dialect +void mlir::enzyme::perfify::PerfifyDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc" + >(); +} diff --git a/src/enzyme_ad/jax/Dialect/Perfify/Dialect.h b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.h new file mode 100644 index 0000000000..33899055e7 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.h @@ -0,0 +1,19 @@ +#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H +#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Types.h" + +// Include the dialect +#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.h.inc" + +// Operations +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.h.inc" + +#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Perfify/Dialect.td b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.td new file mode 100644 index 0000000000..3ac22ac7a0 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Perfify/Dialect.td @@ -0,0 +1,32 @@ +#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD +#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD + +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/Traits.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Perfify dialect definition. +//===----------------------------------------------------------------------===// + +def PerfifyDialect : Dialect { + let name = "perfify"; + let summary = "A dialect for specifying and proving runtime bounds"; + let description = [{ + Lets users specify a bound on the number of steps/latency (per a predefined cost model) that a function or other operation should take. + Leverages SAT solvers to automatically prove this, or interactive theorem provers to allow for complete proofs. + }]; + let cppNamespace = "::mlir::enzyme::perfify"; +} + +//===----------------------------------------------------------------------===// +// Base Perfify operation definition. +//===----------------------------------------------------------------------===// + +class PerfifyOp traits = []> + : Op; + +class PerfifyType : TypeDef; // may need to be modified + +#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/Dialect/Perfify/Ops.cpp b/src/enzyme_ad/jax/Dialect/Perfify/Ops.cpp new file mode 100644 index 0000000000..f2badddb6b --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Perfify/Ops.cpp @@ -0,0 +1,15 @@ +#include "mlir/IR/Builders.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "Dialect.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +using namespace mlir; +using namespace mlir::enzyme::perfify; + +namespace mlir::perfify {} // namespace mlir::perfify + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc" diff --git a/src/enzyme_ad/jax/Dialect/Perfify/Ops.td b/src/enzyme_ad/jax/Dialect/Perfify/Ops.td new file mode 100644 index 0000000000..3f9b8f861a --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Perfify/Ops.td @@ -0,0 +1,50 @@ +#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD +#define ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD + +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "Dialect.td" + +// Perfify.cost op +def CostOp : PerfifyOp<"cost", []> { + // summary + // description + // arguments + let arguments = (ins StrAttr:$target_op, + APIntAttr:$cycle_cost); + let assemblyFormat = "$target_op $cycle_cost attr-dict"; + +} + +def ArgOp : PerfifyOp<"arg", []> { + let arguments = (ins I64Attr:$val); + let assemblyFormat = "$val attr-dict"; + let results = (outs I64); +} + +def AssumeOp : PerfifyOp<"assume", [HasParent<"ConditionsOp">, Terminator]> { + let arguments = (ins I1:$precondition); + let assemblyFormat = "$precondition attr-dict"; +} + +def ConditionsOp : PerfifyOp<"conditions", [HasParent<"AssumptionsOp">, Terminator]> { + let arguments = (ins FlatSymbolRefAttr:$func_handle, + BoolAttr:$verify_huh); + let regions = (region AnyRegion:$precondition, AnyRegion:$postcondition); + + let assemblyFormat = [{ + $func_handle $verify_huh attr-dict `pre` + $precondition + `post` + $postcondition + }]; +} + +def AssumptionsOp : PerfifyOp<"assumptions", [Terminator]> { + let regions = (region AnyRegion:$body); + let assemblyFormat = [{$body attr-dict}]; +} + +#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD \ No newline at end of file diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 46ce854a3b..05b17ca64d 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -89,6 +89,7 @@ #include "src/enzyme_ad/jax/Passes/Passes.h" #include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Perfify/Dialect.h" #include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h" #include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" @@ -213,6 +214,7 @@ void registerDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/test/lit_tests/perfify/roundtrip.mlir b/test/lit_tests/perfify/roundtrip.mlir new file mode 100644 index 0000000000..e92a1ba3cf --- /dev/null +++ b/test/lit_tests/perfify/roundtrip.mlir @@ -0,0 +1,47 @@ +// RUN: enzymexlamlir-opt %s | FileCheck %s +module { + func.func @foo() {func.return} + perfify.assumptions { // operation in the dialect + perfify.cost "arith.mul" 3 // op + perfify.cost "func.return" 0 + perfify.cost "scf.yield" 0 + + + perfify.conditions @foo true pre { + %b0 = perfify.arg 0 // op + %c0 = arith.constant 0 + %cmp = arith.cmpi eq, %c0, %b0 : i64 + perfify.assume %cmp + } post { + // %cost = perfify.fn_cost : perfify.cost + // %c9 = perfify.constant_cost 9 : perfify.cost // then our cost is 9 + // %cmp = arith.cmpi eq, %cost, %c9 + %b0 = perfify.arg 0 // op + %c0 = arith.constant 0 + %cmp = arith.cmpi eq, %c0, %b0 : i64 + perfify.assume %cmp + } + } +} + +// CHECK: module { +// CHECK-NEXT: func.func @foo() { +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: perfify.assumptions { +// CHECK-NEXT: perfify.cost "arith.mul" 3 : i64 +// CHECK-NEXT: perfify.cost "func.return" 0 : i64 +// CHECK-NEXT: perfify.cost "scf.yield" 0 : i64 +// CHECK-NEXT: perfify.conditions @foo true pre { +// CHECK-NEXT: %0 = perfify.arg 0 +// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64 +// CHECK-NEXT: %1 = arith.cmpi eq, %c0_i64, %0 : i64 +// CHECK-NEXT: perfify.assume %1 +// CHECK-NEXT: } post { +// CHECK-NEXT: %0 = perfify.arg 0 +// CHECK-NEXT: %c0_i64 = arith.constant 0 : i64 +// CHECK-NEXT: %1 = arith.cmpi eq, %c0_i64, %0 : i64 +// CHECK-NEXT: perfify.assume %1 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file