Skip to content

Commit 30814c2

Browse files
authored
Add triton call forward rule (#1612)
* Add triton call forward rule * Layouts * Add test
1 parent b55ffd4 commit 30814c2

File tree

3 files changed

+350
-47
lines changed

3 files changed

+350
-47
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
//===- TritonExtAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains the external model implementation of the automatic
10+
// differentiation op interfaces for the MLIR triton_ext dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
15+
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
16+
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
17+
18+
#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
19+
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
20+
21+
using namespace mlir;
22+
using namespace mlir::enzyme;
23+
using namespace mlir::enzymexla;
24+
25+
namespace {
26+
27+
// this assumes no tuple in either args or results.
28+
static std::optional<unsigned>
29+
findAliasedOperand(ArrayAttr outputOperandAliases, unsigned outputIndex) {
30+
for (auto attr : outputOperandAliases) {
31+
auto alias = cast<stablehlo::OutputOperandAliasAttr>(attr);
32+
if (alias.getOutputTupleIndices()[0] != outputIndex)
33+
continue;
34+
assert(alias.getOutputTupleIndices().size() == 1);
35+
assert(alias.getOperandTupleIndices().empty());
36+
return alias.getOperandIndex();
37+
}
38+
return std::nullopt;
39+
}
40+
41+
class AutoDiffTritonCallFwd
42+
: public AutoDiffOpInterface::ExternalModel<AutoDiffTritonCallFwd,
43+
triton_ext::TritonCallOp> {
44+
public:
45+
LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
46+
MGradientUtils *gutils) const {
47+
DerivativeMode mode = DerivativeMode::ForwardMode;
48+
49+
auto callOp = cast<triton_ext::TritonCallOp>(orig);
50+
51+
for (auto [i, arg] : llvm::enumerate(callOp.getInputs())) {
52+
if (!isa<TensorType>(arg.getType())) {
53+
orig->emitError()
54+
<< "unsupported forward rule of triton kernel call with non array "
55+
"return at return #"
56+
<< i << " of type " << arg.getType() << ".";
57+
return failure();
58+
}
59+
}
60+
61+
for (auto [i, res] : llvm::enumerate(callOp->getResults())) {
62+
if (!isa<TensorType>(res.getType())) {
63+
orig->emitError()
64+
<< "unsupported forward rule of triton kernel call with non array "
65+
"return at return #"
66+
<< i << " of type " << res.getType() << ".";
67+
return failure();
68+
}
69+
}
70+
71+
auto output_operand_aliases = callOp.getOutputOperandAliases();
72+
auto operandLayouts = dyn_cast_or_null<ArrayAttr>(
73+
callOp.getOperandLayouts().value_or(nullptr));
74+
auto resultLayouts = dyn_cast_or_null<ArrayAttr>(
75+
callOp.getResultLayouts().value_or(nullptr));
76+
77+
Operation *callee =
78+
SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getFn());
79+
auto fn = cast<FunctionOpInterface>(callee);
80+
81+
size_t width = gutils->width;
82+
83+
int numInputs = callOp.getInputs().size();
84+
int narg = numInputs + orig->getNumResults();
85+
86+
std::vector<DIFFE_TYPE> RetActivity;
87+
std::vector<bool> returnPrimal;
88+
std::vector<bool> returnShadow;
89+
90+
// Unless there is aliasing, returns values arguments are assumed to
91+
// appended to the argument list in the triton kernel.
92+
SmallVector<unsigned> operandIndexMap;
93+
94+
unsigned argCnt = 0;
95+
96+
std::vector<DIFFE_TYPE> ArgActivity;
97+
for (auto arg : callOp.getInputs()) {
98+
auto act = gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
99+
: DIFFE_TYPE::DUP_ARG;
100+
operandIndexMap.push_back(argCnt);
101+
ArgActivity.push_back(act);
102+
argCnt++;
103+
if (act == DIFFE_TYPE::DUP_ARG)
104+
argCnt++;
105+
}
106+
107+
for (auto [i, res] : llvm::enumerate(callOp.getResults())) {
108+
auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
109+
if (!aliasedOperandIndex.has_value()) {
110+
auto act = gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
111+
: DIFFE_TYPE::DUP_ARG;
112+
ArgActivity.push_back(act);
113+
} else {
114+
narg--;
115+
}
116+
}
117+
118+
auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);
119+
120+
bool freeMemory = true;
121+
122+
std::vector<bool> volatile_args(narg, false);
123+
124+
auto forwardFn = gutils->Logic.CreateForwardDiff(
125+
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
126+
freeMemory, width,
127+
/* addedType */ nullptr, type_args, volatile_args,
128+
/* augmented */ nullptr, gutils->omp, gutils->postpasses,
129+
gutils->verifyPostPasses, gutils->strongZero);
130+
131+
SmallVector<Value> fwdArguments;
132+
SmallVector<Type> returnTypes;
133+
134+
// let's assume the same layout for a value and its shadow.
135+
SmallVector<Attribute> newOperandLayouts;
136+
SmallVector<Attribute> newResultLayouts;
137+
138+
unsigned argIdx = 0;
139+
for (auto &&[arg, act] : llvm::zip(callOp.getInputs(), ArgActivity)) {
140+
fwdArguments.push_back(gutils->getNewFromOriginal(arg));
141+
142+
if (operandLayouts) {
143+
newOperandLayouts.push_back(operandLayouts[argIdx]);
144+
if (act == DIFFE_TYPE::DUP_ARG)
145+
newOperandLayouts.push_back(operandLayouts[argIdx]);
146+
}
147+
argIdx++;
148+
149+
if (act == DIFFE_TYPE::DUP_ARG)
150+
fwdArguments.push_back(gutils->invertPointerM(arg, builder));
151+
}
152+
153+
SmallVector<Attribute> newOutputOperandAliases;
154+
155+
unsigned naliased = 0;
156+
for (auto &&[i, res] : llvm::enumerate(callOp->getResults())) {
157+
auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
158+
159+
DIFFE_TYPE act;
160+
if (aliasedOperandIndex.has_value()) {
161+
naliased++;
162+
163+
act = ArgActivity[*aliasedOperandIndex];
164+
165+
auto newOperandIndex = operandIndexMap[*aliasedOperandIndex];
166+
int64_t newResultIndex = returnTypes.size();
167+
newOutputOperandAliases.push_back(
168+
stablehlo::OutputOperandAliasAttr::get(
169+
callOp.getContext(), ArrayRef<int64_t>{newResultIndex},
170+
newOperandIndex, ArrayRef<int64_t>{}));
171+
172+
if (act == DIFFE_TYPE::DUP_ARG) {
173+
newOutputOperandAliases.push_back(
174+
stablehlo::OutputOperandAliasAttr::get(
175+
callOp.getContext(), ArrayRef<int64_t>{newResultIndex + 1},
176+
newOperandIndex + 1, ArrayRef<int64_t>{}));
177+
}
178+
} else {
179+
act = ArgActivity[i - naliased + numInputs];
180+
}
181+
182+
if (resultLayouts) {
183+
newResultLayouts.push_back(resultLayouts[i]);
184+
if (act == DIFFE_TYPE::DUP_ARG)
185+
newResultLayouts.push_back(resultLayouts[i]);
186+
}
187+
188+
returnTypes.push_back(res.getType());
189+
if (act == DIFFE_TYPE::DUP_ARG)
190+
returnTypes.push_back(
191+
cast<AutoDiffTypeInterface>(res.getType()).getShadowType(width));
192+
}
193+
194+
SmallVector<FlatSymbolRefAttr, 2> nestedRefs = {
195+
FlatSymbolRefAttr::get(
196+
forwardFn->getParentOfType<mlir::ModuleOp>().getSymNameAttr()),
197+
FlatSymbolRefAttr::get(
198+
StringAttr::get(callOp.getContext(), forwardFn.getName()))};
199+
auto fnRef = SymbolRefAttr::get(
200+
callOp.getContext(),
201+
forwardFn->getParentOfType<triton_ext::TritonModuleOp>().getSymName(),
202+
nestedRefs);
203+
204+
Value gridx = gutils->getNewFromOriginal(callOp.getGridx()),
205+
gridy = gutils->getNewFromOriginal(callOp.getGridy()),
206+
gridz = gutils->getNewFromOriginal(callOp.getGridz());
207+
208+
Value clusterx = gutils->getNewFromOriginal(callOp.getClusterx()),
209+
clustery = gutils->getNewFromOriginal(callOp.getClustery()),
210+
clusterz = gutils->getNewFromOriginal(callOp.getClusterz());
211+
212+
Attribute newOperandLayoutsAttr =
213+
operandLayouts ? ArrayAttr::get(callOp.getContext(), newOperandLayouts)
214+
: nullptr;
215+
Attribute newResultLayoutsAttr =
216+
resultLayouts ? ArrayAttr::get(callOp.getContext(), newResultLayouts)
217+
: nullptr;
218+
219+
auto fwdCallOp = triton_ext::TritonCallOp::create(
220+
builder, callOp.getLoc(), TypeRange(returnTypes),
221+
/*fn*/ fnRef,
222+
223+
gridx, gridy, gridz,
224+
225+
clusterx, clustery, clusterz,
226+
227+
ValueRange(fwdArguments),
228+
/* backendConfig */ StringAttr::get(callOp.getContext(), ""),
229+
newOperandLayoutsAttr, newResultLayoutsAttr,
230+
/* argAttrs */ mlir::ArrayAttr::get(callOp.getContext(), {}),
231+
/* resAttrs */ mlir::ArrayAttr::get(callOp.getContext(), {}),
232+
ArrayAttr::get(callOp.getContext(), newOutputOperandAliases),
233+
/* xla_side_effect_free */ nullptr);
234+
235+
SmallVector<Value> primals;
236+
primals.reserve(callOp->getNumResults());
237+
238+
naliased = 0;
239+
int fwdIndex = 0;
240+
for (auto &&[i, ret] : llvm::enumerate(callOp.getResults())) {
241+
auto fwdRet = fwdCallOp.getResult(fwdIndex);
242+
primals.push_back(fwdRet);
243+
244+
fwdIndex++;
245+
246+
auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
247+
248+
DIFFE_TYPE act;
249+
if (aliasedOperandIndex.has_value()) {
250+
act = ArgActivity[*aliasedOperandIndex];
251+
naliased++;
252+
} else {
253+
act = ArgActivity[i - naliased + numInputs];
254+
}
255+
256+
if (act == DIFFE_TYPE::DUP_ARG) {
257+
gutils->setDiffe(ret, fwdCallOp.getResult(fwdIndex), builder);
258+
fwdIndex++;
259+
}
260+
}
261+
262+
auto newOp = gutils->getNewFromOriginal(orig);
263+
gutils->replaceOrigOpWith(orig, primals);
264+
gutils->erase(newOp);
265+
266+
return success();
267+
}
268+
};
269+
270+
} // end anonymous namespace
271+
272+
void mlir::enzyme::registerTritonExtDialectAutoDiffInterface(
273+
mlir::DialectRegistry &registry) {
274+
registry.addExtension(+[](MLIRContext *context,
275+
triton_ext::TritonExtDialect *) {
276+
triton_ext::TritonCallOp::attachInterface<AutoDiffTritonCallFwd>(*context);
277+
});
278+
}

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1515
void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
1616
void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
1717
void registerTritonDialectAutoDiffInterface(mlir::DialectRegistry &registry);
18+
void registerTritonExtDialectAutoDiffInterface(mlir::DialectRegistry &registry);
1819

1920
static inline void
2021
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
@@ -23,6 +24,7 @@ registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
2324
registerCHLODialectAutoDiffInterface(registry);
2425
registerEnzymeXLADialectAutoDiffInterface(registry);
2526
registerTritonDialectAutoDiffInterface(registry);
27+
registerTritonExtDialectAutoDiffInterface(registry);
2628
}
2729
} // namespace enzyme
2830
} // namespace mlir

0 commit comments

Comments
 (0)