Skip to content

Commit 8019bbc

Browse files
Alexey Stukalovalyst
authored andcommitted
test_grad/hess(): check that alt calls give same results
1 parent eb95b6a commit 8019bbc

File tree

1 file changed

+25
-31
lines changed

1 file changed

+25
-31
lines changed

test/examples/helper.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,38 @@ function test_gradient(model, params; rtol = 1e-10, atol = 0)
88
@test nparams(model) == length(params)
99

1010
true_grad = FiniteDiff.finite_difference_gradient(Base.Fix1(objective!, model), params)
11-
gradient = similar(params)
1211

13-
# F and G
14-
fill!(gradient, NaN)
15-
gradient!(gradient, model, params)
16-
@test gradient true_grad rtol = rtol atol = atol
12+
gradient_G = fill!(similar(params), NaN)
13+
gradient!(gradient_G, model, params)
14+
gradient_FG = fill!(similar(params), NaN)
15+
objective_gradient!(gradient_FG, model, params)
1716

18-
# only G
19-
fill!(gradient, NaN)
20-
objective_gradient!(gradient, model, params)
21-
@test gradient true_grad rtol = rtol atol = atol
17+
@test gradient_G == gradient_FG
18+
19+
#@info "G norm = $(norm(gradient_G - true_grad, Inf))"
20+
@test gradient_G true_grad rtol = rtol atol = atol
2221
end
2322

2423
function test_hessian(model, params; rtol = 1e-4, atol = 0)
2524
true_hessian =
2625
FiniteDiff.finite_difference_hessian(Base.Fix1(objective!, model), params)
27-
hessian = similar(params, size(true_hessian))
28-
gradient = similar(params)
29-
30-
# H
31-
fill!(hessian, NaN)
32-
hessian!(hessian, model, params)
33-
@test hessian true_hessian rtol = rtol atol = atol
34-
35-
# F and H
36-
fill!(hessian, NaN)
37-
objective_hessian!(hessian, model, params)
38-
@test hessian true_hessian rtol = rtol atol = atol
39-
40-
# G and H
41-
fill!(hessian, NaN)
42-
gradient_hessian!(gradient, hessian, model, params)
43-
@test hessian true_hessian rtol = rtol atol = atol
44-
45-
# F, G and H
46-
fill!(hessian, NaN)
47-
objective_gradient_hessian!(gradient, hessian, model, params)
48-
@test hessian true_hessian rtol = rtol atol = atol
26+
gradient = fill!(similar(params), NaN)
27+
28+
hessian_H = fill!(similar(parent(true_hessian)), NaN)
29+
hessian!(hessian_H, model, params)
30+
31+
hessian_FH = fill!(similar(hessian_H), NaN)
32+
objective_hessian!(hessian_FH, model, params)
33+
34+
hessian_GH = fill!(similar(hessian_H), NaN)
35+
gradient_hessian!(gradient, hessian_GH, model, params)
36+
37+
hessian_FGH = fill!(similar(hessian_H), NaN)
38+
objective_gradient_hessian!(gradient, hessian_FGH, model, params)
39+
40+
@test hessian_H == hessian_FH == hessian_GH == hessian_FGH
41+
42+
@test hessian_H true_hessian rtol = rtol atol = atol
4943
end
5044

5145
fitmeasure_names_ml = Dict(

0 commit comments

Comments
 (0)