Skip to content

Commit cdc4415

Browse files
Merge pull request #202 from alyst/spec_matrices
Cleanup Special Matrices code
2 parents f1d0b85 + c1e7a69 commit cdc4415

File tree

6 files changed

+160
-140
lines changed

6 files changed

+160
-140
lines changed

src/StructuralEquationModels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ const SEM = StructuralEquationModels
2424
# type hierarchy
2525
include("types.jl")
2626
include("objective_gradient_hessian.jl")
27+
28+
# helper objects and functions
29+
include("additional_functions/commutation_matrix.jl")
30+
2731
# fitted objects
2832
include("frontend/fit/SemFit.jl")
2933
# specification of models
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
3+
transpose_linear_indices(n, [m])
4+
5+
Put each linear index of the *n×m* matrix to the position of the
6+
corresponding element in the transposed matrix.
7+
8+
## Example
9+
`
10+
1 4
11+
2 5 => 1 2 3
12+
3 6 4 5 6
13+
`
14+
"""
15+
transpose_linear_indices(n::Integer, m::Integer = n) =
16+
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)
17+
18+
"""
19+
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}
20+
21+
A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
22+
If *vec(A)* is a vectorized form of a n×n matrix *A*,
23+
then ``C * vec(A) = vec(Aᵀ)``.
24+
"""
25+
struct CommutationMatrix <: AbstractMatrix{Int}
26+
n::Int
27+
::Int
28+
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*
29+
30+
CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
31+
end
32+
33+
Base.size(A::CommutationMatrix) = (A.n², A.n²)
34+
Base.size(A::CommutationMatrix, dim::Integer) =
35+
1 <= dim <= 2 ? A.: throw(ArgumentError("invalid matrix dimension $dim"))
36+
Base.length(A::CommutationMatrix) = A.^2
37+
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0
38+
39+
function Base.:(*)(A::CommutationMatrix, B::AbstractVector)
40+
size(A, 2) == size(B, 1) || throw(
41+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) elements"),
42+
)
43+
return B[A.transpose_inds]
44+
end
45+
46+
function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
47+
size(A, 2) == size(B, 1) || throw(
48+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
49+
)
50+
return B[A.transpose_inds, :]
51+
end
52+
53+
function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
54+
size(A, 2) == size(B, 1) || throw(
55+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
56+
)
57+
return SparseMatrixCSC(
58+
size(B, 1),
59+
size(B, 2),
60+
copy(B.colptr),
61+
A.transpose_inds[B.rowval],
62+
copy(B.nzval),
63+
)
64+
end
65+
66+
function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
67+
size(A, 2) == size(B, 1) || throw(
68+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
69+
)
70+
71+
@inbounds for (i, rowind) in enumerate(B.rowval)
72+
B.rowval[i] = A.transpose_inds[rowind]
73+
end
74+
return B
75+
end

src/additional_functions/helper.jl

Lines changed: 23 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function get_observed(rowind, data, semobserved; args = (), kwargs = NamedTuple(
4141
return observed_vec
4242
end
4343

44-
skipmissing_mean(mat::AbstractMatrix) =
44+
skipmissing_mean(mat::AbstractMatrix) =
4545
[mean(skipmissing(coldata)) for coldata in eachcol(mat)]
4646

4747
function F_one_person(imp_mean, meandiff, inverse, data, logdet)
@@ -111,143 +111,34 @@ function cov_and_mean(rows; corrected = false)
111111
return obs_cov, vec(obs_mean)
112112
end
113113

114-
function duplication_matrix(nobs)
115-
nobs = Int(nobs)
116-
n1 = Int(nobs * (nobs + 1) * 0.5)
117-
n2 = Int(nobs^2)
118-
Dt = zeros(n1, n2)
119-
120-
for j in 1:nobs
121-
for i in j:nobs
122-
u = zeros(n1)
123-
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
124-
T = zeros(nobs, nobs)
125-
T[j, i] = 1
126-
T[i, j] = 1
127-
Dt += u * transpose(vec(T))
114+
# n²×(n(n+1)/2) matrix to transform a vector of lower
115+
# triangular entries into a vectorized form of a n×n symmetric matrix,
116+
# opposite of elimination_matrix()
117+
function duplication_matrix(n::Integer)
118+
ntri = div(n * (n + 1), 2)
119+
D = zeros(n^2, ntri)
120+
for j in 1:n
121+
for i in j:n
122+
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
123+
D[j+n*(i-1), tri_ix] = 1
124+
D[i+n*(j-1), tri_ix] = 1
128125
end
129126
end
130-
D = transpose(Dt)
131127
return D
132128
end
133129

134-
function elimination_matrix(nobs)
135-
nobs = Int(nobs)
136-
n1 = Int(nobs * (nobs + 1) * 0.5)
137-
n2 = Int(nobs^2)
138-
L = zeros(n1, n2)
139-
140-
for j in 1:nobs
141-
for i in j:nobs
142-
u = zeros(n1)
143-
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
144-
T = zeros(nobs, nobs)
145-
T[i, j] = 1
146-
L += u * transpose(vec(T))
130+
# (n(n+1)/2)×n² matrix to transform a
131+
# vectorized form of a n×n symmetric matrix
132+
# into vector of its lower triangular entries,
133+
# opposite of duplication_matrix()
134+
function elimination_matrix(n::Integer)
135+
ntri = div(n * (n + 1), 2)
136+
L = zeros(ntri, n^2)
137+
for j in 1:n
138+
for i in j:n
139+
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
140+
L[tri_ix, i+n*(j-1)] = 1
147141
end
148142
end
149143
return L
150144
end
151-
152-
function commutation_matrix(n; tosparse = false)
153-
M = zeros(n^2, n^2)
154-
155-
for i in 1:n
156-
for j in 1:n
157-
M[i+n*(j-1), j+n*(i-1)] = 1.0
158-
end
159-
end
160-
161-
if tosparse
162-
M = sparse(M)
163-
end
164-
165-
return M
166-
end
167-
168-
function commutation_matrix_pre_square(A)
169-
n2 = size(A, 1)
170-
n = Int(sqrt(n2))
171-
172-
ind = repeat(1:n, inner = n)
173-
indadd = (0:(n-1)) * n
174-
for i in 1:n
175-
ind[((i-1)*n+1):i*n] .+= indadd
176-
end
177-
178-
A_post = A[ind, :]
179-
180-
return A_post
181-
end
182-
183-
function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
184-
n2 = size(A, 1)
185-
n = Int(sqrt(n2))
186-
187-
ind = repeat(1:n, inner = n)
188-
indadd = (0:(n-1)) * n
189-
for i in 1:n
190-
ind[((i-1)*n+1):i*n] .+= indadd
191-
end
192-
193-
@views @inbounds B .+= A[ind, :]
194-
195-
return B
196-
end
197-
198-
function get_commutation_lookup(n2::Int64)
199-
n = Int(sqrt(n2))
200-
ind = repeat(1:n, inner = n)
201-
indadd = (0:(n-1)) * n
202-
for i in 1:n
203-
ind[((i-1)*n+1):i*n] .+= indadd
204-
end
205-
206-
lookup = Dict{Int64, Int64}()
207-
208-
for i in 1:n2
209-
j = findall(x -> (x == i), ind)[1]
210-
push!(lookup, i => j)
211-
end
212-
213-
return lookup
214-
end
215-
216-
function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
217-
for (i, rowind) in enumerate(A.rowval)
218-
A.rowval[i] = lookup[rowind]
219-
end
220-
end
221-
222-
function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
223-
lookup = get_commutation_lookup(size(A, 2))
224-
commutation_matrix_pre_square!(A, lookup)
225-
end
226-
227-
function commutation_matrix_pre_square(A::SparseMatrixCSC)
228-
B = copy(A)
229-
commutation_matrix_pre_square!(B)
230-
return B
231-
end
232-
233-
function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
234-
B = copy(A)
235-
commutation_matrix_pre_square!(B, lookup)
236-
return B
237-
end
238-
239-
function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
240-
n2 = size(A, 1)
241-
n = Int(sqrt(n2))
242-
243-
indadd = (0:(n-1)) * n
244-
245-
Threads.@threads for i in 1:n
246-
for j in 1:n
247-
row = i + indadd[j]
248-
@views @inbounds B[row, :] .+= A[row, :]
249-
end
250-
end
251-
252-
return B
253-
end

src/loss/ML/FIML.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Analytic gradients are available.
2424
## Implementation
2525
Subtype of `SemLossFunction`.
2626
"""
27-
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
27+
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, W} <: SemLossFunction
2828
inverses::INV #preallocated inverses of imp_cov
2929
choleskys::C #preallocated choleskys
3030
logdets::L #logdets of implied covmats
@@ -37,7 +37,7 @@ mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
3737

3838
mult::T
3939

40-
commutation_indices::U
40+
commutator::CommutationMatrix
4141

4242
interaction::W
4343
end
@@ -64,8 +64,6 @@ function SemFIML(; observed, specification, kwargs...)
6464
∇ind =
6565
[findall(x -> !(x[1] ind || x[2] ind), ∇ind) for ind in patterns_not(observed)]
6666

67-
commutation_indices = get_commutation_lookup(get_n_nodes(specification)^2)
68-
6967
return SemFIML(
7068
inverses,
7169
choleskys,
@@ -75,7 +73,7 @@ function SemFIML(; observed, specification, kwargs...)
7573
meandiff,
7674
imp_inv,
7775
mult,
78-
commutation_indices,
76+
CommutationMatrix(get_n_nodes(specification)),
7977
nothing,
8078
)
8179
end
@@ -163,10 +161,9 @@ function ∇F_fiml_outer(JΣ, Jμ, imply, model, semfiml)
163161
Iₙ = sparse(1.0I, size(A(imply))...)
164162
P = kron(F⨉I_A⁻¹(imply), F⨉I_A⁻¹(imply))
165163
Q = kron(S(imply) * I_A⁻¹(imply)', Iₙ)
166-
#commutation_matrix_pre_square_add!(Q, Q)
167-
Q2 = commutation_matrix_pre_square(Q, semfiml.commutation_indices)
164+
Q .+= semfiml.commutator * Q
168165

169-
∇Σ = P * (∇S(imply) + (Q + Q2) * ∇A(imply))
166+
∇Σ = P * (∇S(imply) + Q * ∇A(imply))
170167

171168
∇μ =
172169
F⨉I_A⁻¹(imply) * ∇M(imply) +

test/unit_tests/matrix_helpers.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using StructuralEquationModels, Test, Random, SparseArrays, LinearAlgebra
2+
using StructuralEquationModels:
3+
CommutationMatrix, transpose_linear_indices, duplication_matrix, elimination_matrix
4+
5+
Random.seed!(73721)
6+
7+
n = 4
8+
m = 5
9+
10+
@testset "Commutation matrix" begin
11+
# transpose linear indices
12+
A = rand(n, m)
13+
@test reshape(A[transpose_linear_indices(n, m)], m, n) == A'
14+
# commutation matrix multiplication
15+
K = CommutationMatrix(n)
16+
# test K array interface methods
17+
@test size(K) == (n^2, n^2)
18+
@test size(K, 1) == n^2
19+
@test length(K) == n^4
20+
nn_linind = LinearIndices((n, n))
21+
@test K[nn_linind[3, 2], nn_linind[2, 3]] == 1
22+
@test K[nn_linind[3, 2], nn_linind[3, 2]] == 0
23+
24+
B = rand(n, n)
25+
@test_throws DimensionMismatch K * rand(n, m)
26+
@test K * vec(B) == vec(B')
27+
C = sprand(n, n, 0.5)
28+
@test K * vec(C) == vec(C')
29+
# lmul!
30+
D = sprand(n^2, n^2, 0.1)
31+
E = copy(D)
32+
F = Matrix(E)
33+
lmul!(K, D)
34+
@test D == K * E
35+
@test Matrix(D) == K * F
36+
end
37+
38+
@testset "Duplication / elimination matrix" begin
39+
A = rand(m, m)
40+
A = A * A'
41+
42+
# dupication
43+
D = duplication_matrix(m)
44+
@test D * A[tril(trues(size(A)))] == vec(A)
45+
46+
# elimination
47+
E = elimination_matrix(m)
48+
@test E * vec(A) == A[tril(trues(size(A)))]
49+
end

test/unit_tests/unit_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ end
77
@safetestset "SemObs" begin
88
include("data_input_formats.jl")
99
end
10+
11+
@safetestset "Matrix algebra helper functions" begin
12+
include("matrix_helpers.jl")
13+
end

0 commit comments

Comments
 (0)