Skip to content

Commit ba5111c

Browse files
committed
Add test_modify_list
Currently fails.
1 parent c14abdc commit ba5111c

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

enzyme/test/Integration/ReverseMode/stl_list.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#include <list>
99

1010

11-
template<typename T>
12-
extern double __enzyme_fwddiff(void*, int, T&, T&);
11+
template<typename ...T>
12+
extern double __enzyme_fwddiff(void*, T...);
1313
template<typename T>
1414
extern double __enzyme_autodiff(void*, int, T&, T&);
1515

@@ -23,14 +23,41 @@ double test_iterate_list(std::list<double>& vals) {
2323
return result;
2424
}
2525

26+
struct S {
27+
S(double r) : x(r) {};
28+
double x = 0.0;
29+
};
30+
31+
double test_modify_list(std::list<S> vals, double x) {
32+
vals.front().x = x;
33+
34+
// iterate over list
35+
double result = 0.0;
36+
for (const auto& val : vals) {
37+
result += val.x * val.x;
38+
}
39+
return result;
40+
}
41+
2642
void test_forward_list() {
43+
// diff all values of list
2744
{
2845
std::list<double> vals = {1.0, 2.0, 3.0};
2946
std::list<double> dvals = {1.0, 1.0, 1.0};
3047

3148
double ret = __enzyme_fwddiff((void*)test_iterate_list, enzyme_dup, vals, dvals);
3249
APPROX_EQ(ret, 12., 1e-10);
3350
}
51+
52+
// list is const, then first value set to active
53+
{
54+
std::list<S> vals = {S{1.0}, S{2.0}, S{3.0}};
55+
double x = 3.0;
56+
double dx = 1.0;
57+
58+
double ret = __enzyme_fwddiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, x, dx);
59+
APPROX_EQ(ret, 6., 1e-10);
60+
}
3461
}
3562

3663
void test_reverse_list() {

0 commit comments

Comments
 (0)