Skip to content

Commit 48f2dfa

Browse files
committed
Dynamic Inserts
1 parent f7164c8 commit 48f2dfa

File tree

2 files changed

+53
-37
lines changed

2 files changed

+53
-37
lines changed

enzyme/Enzyme/CallDerivatives.cpp

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,39 +2595,22 @@ bool AdjointGenerator::handleKnownCallDerivatives(
25952595

25962596
// 2) STL std::list insertion
25972597
if (funcName == "_ZNSt8__detail15_List_node_base7_M_hookEPS0_") {
2598-
if (Mode == DerivativeMode::ReverseModeGradient) {
2599-
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2600-
return true;
2601-
}
2602-
if (gutils->isConstantValue(call.getArgOperand(0)))
2603-
return true;
2604-
SmallVector<Value *, 2> args;
2605-
for (auto &arg : call.args()) {
2606-
if (gutils->isConstantValue(arg))
2607-
args.push_back(gutils->getNewFromOriginal(arg));
2608-
else
2609-
args.push_back(gutils->invertPointerM(arg, BuilderZ));
2610-
}
2611-
BuilderZ.CreateCall(called, args);
2598+
// Only run in primal version to avoid memory management issues
2599+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
26122600
return true;
26132601
}
26142602

26152603
// 3) STL std::list transfer (splice operations)
26162604
if (funcName == "_ZNSt8__detail15_List_node_base11_M_transferEPS0_S1_") {
2617-
if (Mode == DerivativeMode::ReverseModeGradient) {
2618-
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2619-
return true;
2620-
}
2621-
if (gutils->isConstantValue(call.getArgOperand(0)))
2622-
return true;
2623-
SmallVector<Value *, 2> args;
2624-
for (auto &arg : call.args()) {
2625-
if (gutils->isConstantValue(arg))
2626-
args.push_back(gutils->getNewFromOriginal(arg));
2627-
else
2628-
args.push_back(gutils->invertPointerM(arg, BuilderZ));
2629-
}
2630-
BuilderZ.CreateCall(called, args);
2605+
// Only run in primal version to avoid memory management issues
2606+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
2607+
return true;
2608+
}
2609+
2610+
// 4) STL std::list size increment (called when adding elements)
2611+
if (funcName == "_ZNSt7__cxx1110_List_baseIdSaIdEE11_M_inc_sizeEm") {
2612+
// Only run in primal version to avoid memory management issues
2613+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
26312614
return true;
26322615
}
26332616

enzyme/test/Integration/ReverseMode/stl_list.cpp

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010

1111
template<typename T>
12-
extern void __enzyme_fwddiff(void*, int, T&, T&);
12+
extern double __enzyme_fwddiff(void*, int, T&, T&);
1313
template<typename T>
14-
extern void __enzyme_autodiff(void*, int, T&, T&);
14+
extern double __enzyme_autodiff(void*, int, T&, T&);
1515

1616

1717
double test_simple_list(std::list<double>& vals) {
@@ -26,23 +26,56 @@ double test_simple_list(std::list<double>& vals) {
2626
return result;
2727
}
2828

29+
double test_dynamic_list(std::list<double>& vals) {
30+
// dynamically insert new elements
31+
vals.push_back(4.0);
32+
vals.push_back(5.0);
33+
34+
// iterate over list
35+
double result = 0.0;
36+
for (const auto& val : vals) {
37+
result += val * val;
38+
}
39+
return result;
40+
}
41+
2942
void test_forward_list() {
30-
std::list<double> vals = {1.0, 2.0, 3.0};
31-
std::list<double> dvals = {1.0, 1.0, 1.0};
43+
{
44+
std::list<double> vals = {1.0, 2.0, 3.0};
45+
std::list<double> dvals = {1.0, 1.0, 1.0};
46+
47+
double ret = __enzyme_fwddiff((void*)test_simple_list, enzyme_dup, vals, dvals);
48+
APPROX_EQ( ret, 10., 1e-10);
49+
}
50+
{
51+
std::list<double> vals = {1.0, 2.0, 3.0};
52+
std::list<double> dvals = {1.0, 1.0, 1.0};
3253

33-
__enzyme_fwddiff((void*)test_simple_list, enzyme_dup, vals, dvals);
54+
double ret = __enzyme_fwddiff((void*)test_dynamic_list, enzyme_dup, vals, dvals);
55+
APPROX_EQ( ret, 12., 1e-10);
56+
}
3457
}
3558

3659
void test_reverse_list() {
37-
std::list<double> vals = {1.0, 2.0, 3.0};
38-
std::list<double> dvals = {1.0, 1.0, 1.0};
60+
{
61+
std::list<double> vals = {1.0, 2.0, 3.0};
62+
std::list<double> dvals = {1.0, 1.0, 1.0};
3963

40-
__enzyme_autodiff((void*)test_simple_list, enzyme_dup, vals, dvals);
64+
double ret = __enzyme_autodiff((void*)test_simple_list, enzyme_dup, vals, dvals);
65+
//APPROX_EQ( ret, 10., 1e-10);
66+
}
67+
{
68+
std::list<double> vals = {1.0, 2.0, 3.0};
69+
std::list<double> dvals = {1.0, 1.0, 1.0};
70+
71+
double ret = __enzyme_autodiff((void*)test_dynamic_list, enzyme_dup, vals, dvals);
72+
//APPROX_EQ( ret, 12., 1e-10);
73+
}
4174
}
4275

4376

4477
int main() {
4578
test_forward_list();
4679
test_reverse_list();
4780
return 0;
48-
}
81+
}

0 commit comments

Comments
 (0)