Skip to content

Commit 6335968

Browse files
authored
limited-memory SR1 operator (#14)
1 parent 66d92a2 commit 6335968

File tree

4 files changed

+207
-0
lines changed

4 files changed

+207
-0
lines changed

src/lsr1.jl

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
export LSR1Operator #, InverseLSR1Operator
2+
3+
"A data type to hold information relative to LSR1 operators."
4+
type LSR1Data
5+
mem :: Int
6+
scaling :: Bool
7+
scaling_factor :: Float64
8+
s :: Array
9+
y :: Array
10+
ys :: Vector
11+
a :: Array
12+
as :: Array
13+
insert :: Int
14+
15+
function LSR1Data(n :: Int, mem :: Int;
16+
dtype :: DataType=Float64, scaling :: Bool=false, inverse :: Bool=true)
17+
return new(max(mem, 1),
18+
scaling,
19+
1.0,
20+
zeros(dtype, n, mem),
21+
zeros(dtype, n, mem),
22+
zeros(dtype, mem),
23+
zeros(dtype, n, mem),
24+
zeros(dtype, mem),
25+
1)
26+
end
27+
end
28+
29+
30+
"A type for limited-memory SR1 approximations."
31+
type LSR1Operator <: AbstractLinearOperator
32+
nrow :: Int
33+
ncol :: Int
34+
dtype :: DataType
35+
symmetric :: Bool
36+
hermitian :: Bool
37+
prod :: Function # apply the operator to a vector
38+
tprod :: Nullable{Function} # apply the transpose operator to a vector
39+
ctprod :: Nullable{Function} # apply the transpose conjugate operator to a vector
40+
inverse :: Bool
41+
data :: LSR1Data
42+
end
43+
44+
"Construct a limited-memory SR1 approximation in forward form."
45+
function LSR1Operator(n, mem :: Int=5; dtype :: DataType=Float64, scaling :: Bool=false)
46+
lsr1_data = LSR1Data(n, mem, dtype=dtype, scaling=scaling, inverse=false)
47+
48+
function lsr1_multiply(data :: LSR1Data, x :: Array)
49+
# Multiply operator with a vector.
50+
51+
if dtype == typeof(x[1])
52+
q = copy(x)
53+
else
54+
result_type = promote_type(dtype, typeof(x[1]))
55+
q = convert(Array{result_type}, x)
56+
end
57+
58+
data.scaling && (q[:] /= data.scaling_factor)
59+
60+
for i = 1 : data.mem
61+
k = mod(data.insert + i - 2, data.mem) + 1
62+
if data.ys[k] != 0
63+
q[:] += dot(data.a[:, k], x) / data.as[k] * data.a[:, k]
64+
end
65+
end
66+
return q
67+
end
68+
69+
return LSR1Operator(n, n, dtype, true, true,
70+
x -> lsr1_multiply(lsr1_data, x),
71+
Nullable{Function}(),
72+
Nullable{Function}(),
73+
false,
74+
lsr1_data)
75+
end
76+
77+
78+
"Push a new {s,y} pair into a L-SR1 operator."
79+
function push!(op :: LSR1Operator, s :: Vector, y :: Vector)
80+
81+
# op.counters.updates += 1
82+
data = op.data
83+
Bs = op * s
84+
ymBs = y - Bs
85+
ys = dot(y, s)
86+
87+
well_defined = abs(dot(ymBs, s)) >= 1.0e-8 + 1.0e-8 * norm(s) * norm(ymBs)
88+
89+
sufficient_curvature = true
90+
scaling_condition = true
91+
y_neq_s = true
92+
if data.scaling
93+
sufficient_curvature = abs(ys) >= 1.0e-8
94+
if sufficient_curvature
95+
scaling_factor = ys / dot(y, y)
96+
scaling_condition = norm(y - s / scaling_factor) >= 1.0e-8
97+
end
98+
end
99+
100+
if ~(well_defined && sufficient_curvature && scaling_condition && y_neq_s)
101+
# op.counters.rejects += 1
102+
return op
103+
end
104+
105+
data.s[:, data.insert] = s
106+
data.y[:, data.insert] = y
107+
data.ys[data.insert] = ys
108+
109+
# update scaling factor
110+
data.scaling && (data.scaling_factor = ys / dot(y, y))
111+
112+
# update next insertion position
113+
data.insert = mod(data.insert, data.mem) + 1
114+
115+
# update rank-1 terms
116+
for i = 1 : data.mem
117+
k = mod(data.insert + i - 2, data.mem) + 1
118+
if data.ys[k] != 0.0
119+
data.a[:, k] = data.y[:, k] - data.s[:, k] / data.scaling_factor
120+
for j = 1 : i-1
121+
l = mod(data.insert + j - 2, data.mem) + 1
122+
if data.ys[l] != 0.0
123+
data.a[:, k] -= dot(data.a[:, l], data.s[:, k]) / data.as[l] * data.a[:, l]
124+
end
125+
end
126+
data.as[k] = dot(data.a[:, k], data.s[:, k])
127+
end
128+
end
129+
130+
return op
131+
end
132+
133+
134+
"Extract the diagonal of a L-SR1 operator in forward mode."
135+
function diag(op :: LSR1Operator)
136+
op.inverse && throw("only the diagonal of a forward L-SR1 approximation is available")
137+
data = op.data
138+
139+
d = ones(op.nrow)
140+
data.scaling && (d[:] /= data.scaling_factor)
141+
142+
for i = 1 : data.mem
143+
k = mod(data.insert + i - 2, data.mem) + 1
144+
if data.ys[k] != 0.0
145+
for j = 1 : op.nrow
146+
d[j] += data.a[j, k]^2 / data.as[k]
147+
end
148+
end
149+
end
150+
return d
151+
end

src/qn.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ import Base.push!
44
import Base.diag
55

66
include("lbfgs.jl")
7+
include("lsr1.jl")

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ using LinearOperators
55
include("test_linop.jl")
66
include("test_cat.jl")
77
include("test_lbfgs.jl")
8+
include("test_lsr1.jl")

test/test_lsr1.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
ϵ = eps(Float64)
2+
rtol = sqrt(ϵ)
3+
4+
# test limited-memory SR1
5+
n = 10
6+
mem = 5
7+
B = LSR1Operator(n, mem)
8+
9+
@assert norm(diag(B) - diag(full(B))) <= rtol
10+
11+
@assert B.data.insert == 1
12+
@test norm(full(B) - eye(n)) <= ϵ
13+
14+
# Test that only valid updates are accepted.
15+
s = rand(n)
16+
y = B * s
17+
push!(B, s, y); @assert B.data.insert == 1
18+
19+
# Insert a few {s,y} pairs.
20+
for i = 1 : mem+2
21+
s = rand(n)
22+
y = rand(n)
23+
push!(B, s, y)
24+
end
25+
26+
@test check_hermitian(B)
27+
@assert norm(diag(B) - diag(full(B))) <= rtol
28+
29+
# test against full SR1 without scaling
30+
mem = n
31+
LB = LSR1Operator(n, mem)
32+
B = eye(n)
33+
34+
function sr1!(B, s, y)
35+
# dense SR1 update
36+
ymBs = y - B * s
37+
denom = dot(ymBs, s)
38+
if abs(denom) >= 1.0e-8 + 1.0e-8 * norm(s) * norm(ymBs)
39+
B = B + ymBs * ymBs' / denom
40+
end
41+
return B
42+
end
43+
44+
@assert norm(full(LB) - B) < rtol * norm(B)
45+
@assert norm(diag(LB) - diag(B)) < rtol * norm(diag(B))
46+
47+
for k = 1 : mem
48+
s = rand(n)
49+
y = rand(n)
50+
B = sr1!(B, s, y)
51+
LB = push!(LB, s, y)
52+
@assert norm(full(LB) - B) < rtol * norm(B)
53+
@assert norm(diag(LB) - diag(B)) < rtol * norm(diag(B))
54+
end

0 commit comments

Comments
 (0)