Skip to content

Commit 974a450

Browse files
authored
[Nonlinear] Change Vector{T} to AbstractVector{T} in _extract_reverse_pass and add test (#1934)
1 parent 24d2a0d commit 974a450

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/Nonlinear/ReverseAD/reverse_mode.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ end
312312

313313
"""
314314
_extract_reverse_pass(
315-
g::Vector{T},
315+
g::AbstractVector{T},
316316
d::NLPEvaluator,
317317
f::Union{_FunctionStorage,_SubexpressionStorage},
318318
) where {T}
@@ -321,7 +321,7 @@ Fill the gradient vector `g` with the values from the reverse pass. Assumes you
321321
have already called `_reverse_eval_all(d, x)`.
322322
"""
323323
function _extract_reverse_pass(
324-
g::Vector{T},
324+
g::AbstractVector{T},
325325
d::NLPEvaluator,
326326
f::Union{_FunctionStorage,_SubexpressionStorage},
327327
) where {T}

test/Nonlinear/ReverseAD.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,35 @@ function test_gradient_nested_subexpressions()
850850
return
851851
end
852852

853+
function test_gradient_view()
854+
x = MOI.VariableIndex(1)
855+
y = MOI.VariableIndex(2)
856+
model = Nonlinear.Model()
857+
Nonlinear.set_objective(model, :(($x - 1)^2 + 4 * ($y - $x^2)^2))
858+
evaluator = Nonlinear.Evaluator(
859+
model,
860+
Nonlinear.SparseReverseMode(),
861+
MOI.VariableIndex[x, y],
862+
)
863+
MOI.initialize(evaluator, [:Grad])
864+
X_input = [0.0; 1.0; 2.0; 3.0]
865+
for idx in [1:2, 3:4, [1, 3], 4:-2:2]
866+
x_input = view(X_input, idx)
867+
x_vec = Vector(x_input)
868+
∇f = fill(Inf, 2)
869+
∇fv = fill(Inf, 2)
870+
MOI.eval_objective_gradient(evaluator, ∇f, x_input)
871+
MOI.eval_objective_gradient(evaluator, ∇fv, x_vec)
872+
@test ∇f == ∇fv
873+
∇f = fill(Inf, 2)
874+
∇fv = view(fill(Inf, 4), idx)
875+
MOI.eval_objective_gradient(evaluator, ∇f, x_input)
876+
MOI.eval_objective_gradient(evaluator, ∇fv, x_input)
877+
@test ∇f == ∇fv
878+
end
879+
return
880+
end
881+
853882
function _dense_hessian(hessian_sparsity, V, n)
854883
I = [i for (i, _) in hessian_sparsity]
855884
J = [j for (_, j) in hessian_sparsity]

0 commit comments

Comments
 (0)