1818#include " src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1919#include " src/enzyme_ad/jax/Passes/Tessera/Passes.h"
2020#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
21+ #include " mlir/IR/BuiltinDialect.h"
2122
2223using namespace mlir ;
2324using namespace mlir ::enzyme;
@@ -59,6 +60,20 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
5960 funcOp.getBody ().cloneInto (&tesseraDefineOp.getBody (),
6061 tesseraDefineOp.getBody ().end (),
6162 mapper);
63+
64+ // Now walk through the cloned operations and convert func.return to tessera.return
65+ tesseraDefineOp.walk ([&](func::ReturnOp returnOp) {
66+ rewriter.setInsertionPoint (returnOp);
67+ rewriter.replaceOpWithNewOp <tessera::ReturnOp>(returnOp, returnOp.getOperands ());
68+ });
69+
70+ // Convert func.call to tessera.call
71+ tesseraDefineOp.walk ([&](func::CallOp callOp) {
72+ rewriter.setInsertionPoint (callOp);
73+ rewriter.replaceOpWithNewOp <tessera::CallOp>(callOp, callOp.getResultTypes (),
74+ callOp.getOperands (),
75+ callOp->getAttrs ());
76+ });
6277 }
6378
6479 rewriter.eraseOp (funcOp);
@@ -81,7 +96,7 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8196 Operation *calleeOp = SymbolTable::lookupSymbolIn (moduleOp, calleeAttr);
8297
8398 // Only convert if the callee is a Tessera DefineOp
84- if (isa<tessera::DefineOp>(calleeOp))
99+ if (! isa<tessera::DefineOp>(calleeOp))
85100 return rewriter.notifyMatchFailure (callOp, " Callee is not a Tessera DefineOp" );
86101
87102 rewriter.replaceOpWithNewOp <tessera::CallOp>(callOp, callOp.getResultTypes (),
@@ -122,15 +137,28 @@ namespace mlir::enzyme::tessera {
122137struct FuncToTesseraPass
123138 : public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
124139
140+ StringRef getArgument () const final { return " func-to-tessera" ; }
141+ StringRef getDescription () const final { return " Convert func dialect to tessera dialect." ; }
142+
143+ void getDependentDialects (DialectRegistry ®istry) const override {
144+ registry.insert <tessera::TesseraDialect>();
145+ }
146+
125147 void runOnOperation () override {
126148 MLIRContext *ctx = &getContext ();
149+
150+ ConversionTarget target (*ctx);
151+ target.addLegalDialect <tessera::TesseraDialect>();
152+ target.addLegalDialect <BuiltinDialect>();
153+ target.addIllegalDialect <func::FuncDialect>();
154+
127155 RewritePatternSet patterns (ctx);
128156
129157 patterns.add <FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
130158
131- if (failed (applyPatternsAndFoldGreedily (getOperation (),
132- std::move (patterns))))
159+ if (failed (applyFullConversion (getOperation (), target, std::move (patterns))))
133160 signalPassFailure ();
161+
134162}
135163};
136164
0 commit comments