Skip to content

Commit 03dac06

Browse files
committed
CommutationMatrix type
replace comm_matrix helper functions with a CommutationMatrix and overloaded linalg ops
1 parent f1d0b85 commit 03dac06

File tree

4 files changed

+77
-111
lines changed

4 files changed

+77
-111
lines changed

src/StructuralEquationModels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ const SEM = StructuralEquationModels
2424
# type hierarchy
2525
include("types.jl")
2626
include("objective_gradient_hessian.jl")
27+
28+
# helper objects and functions
29+
include("additional_functions/commutation_matrix.jl")
30+
2731
# fitted objects
2832
include("frontend/fit/SemFit.jl")
2933
# specification of models
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
3+
transpose_linear_indices(n, [m])
4+
5+
Put each linear index of the *n×m* matrix to the position of the
6+
corresponding element in the transposed matrix.
7+
8+
## Example
9+
`
10+
1 4
11+
2 5 => 1 2 3
12+
3 6 4 5 6
13+
`
14+
"""
15+
transpose_linear_indices(n::Integer, m::Integer = n) =
16+
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)
17+
18+
"""
19+
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}
20+
21+
A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
22+
If *vec(A)* is a vectorized form of a n×n matrix *A*,
23+
then ``C * vec(A) = vec(Aᵀ)``.
24+
"""
25+
struct CommutationMatrix <: AbstractMatrix{Int}
26+
n::Int
27+
::Int
28+
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*
29+
30+
CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
31+
end
32+
33+
Base.size(A::CommutationMatrix) = (A.n², A.n²)
34+
Base.size(A::CommutationMatrix, dim::Integer) =
35+
1 <= dim <= 2 ? A.: throw(ArgumentError("invalid matrix dimension $dim"))
36+
Base.length(A::CommutationMatrix) = A.^2
37+
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0
38+
39+
function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
40+
size(A, 2) == size(B, 1) || throw(
41+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
42+
)
43+
return B[A.transpose_inds, :]
44+
end
45+
46+
function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
47+
size(A, 2) == size(B, 1) || throw(
48+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
49+
)
50+
return SparseMatrixCSC(
51+
size(B, 1),
52+
size(B, 2),
53+
copy(B.colptr),
54+
A.transpose_inds[B.rowval],
55+
copy(B.nzval),
56+
)
57+
end
58+
59+
function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
60+
size(A, 2) == size(B, 1) || throw(
61+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
62+
)
63+
64+
@inbounds for (i, rowind) in enumerate(B.rowval)
65+
B.rowval[i] = A.transpose_inds[rowind]
66+
end
67+
return B
68+
end

src/additional_functions/helper.jl

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -148,106 +148,3 @@ function elimination_matrix(nobs)
148148
end
149149
return L
150150
end
151-
152-
function commutation_matrix(n; tosparse = false)
153-
M = zeros(n^2, n^2)
154-
155-
for i in 1:n
156-
for j in 1:n
157-
M[i+n*(j-1), j+n*(i-1)] = 1.0
158-
end
159-
end
160-
161-
if tosparse
162-
M = sparse(M)
163-
end
164-
165-
return M
166-
end
167-
168-
function commutation_matrix_pre_square(A)
169-
n2 = size(A, 1)
170-
n = Int(sqrt(n2))
171-
172-
ind = repeat(1:n, inner = n)
173-
indadd = (0:(n-1)) * n
174-
for i in 1:n
175-
ind[((i-1)*n+1):i*n] .+= indadd
176-
end
177-
178-
A_post = A[ind, :]
179-
180-
return A_post
181-
end
182-
183-
function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
184-
n2 = size(A, 1)
185-
n = Int(sqrt(n2))
186-
187-
ind = repeat(1:n, inner = n)
188-
indadd = (0:(n-1)) * n
189-
for i in 1:n
190-
ind[((i-1)*n+1):i*n] .+= indadd
191-
end
192-
193-
@views @inbounds B .+= A[ind, :]
194-
195-
return B
196-
end
197-
198-
function get_commutation_lookup(n2::Int64)
199-
n = Int(sqrt(n2))
200-
ind = repeat(1:n, inner = n)
201-
indadd = (0:(n-1)) * n
202-
for i in 1:n
203-
ind[((i-1)*n+1):i*n] .+= indadd
204-
end
205-
206-
lookup = Dict{Int64, Int64}()
207-
208-
for i in 1:n2
209-
j = findall(x -> (x == i), ind)[1]
210-
push!(lookup, i => j)
211-
end
212-
213-
return lookup
214-
end
215-
216-
function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
217-
for (i, rowind) in enumerate(A.rowval)
218-
A.rowval[i] = lookup[rowind]
219-
end
220-
end
221-
222-
function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
223-
lookup = get_commutation_lookup(size(A, 2))
224-
commutation_matrix_pre_square!(A, lookup)
225-
end
226-
227-
function commutation_matrix_pre_square(A::SparseMatrixCSC)
228-
B = copy(A)
229-
commutation_matrix_pre_square!(B)
230-
return B
231-
end
232-
233-
function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
234-
B = copy(A)
235-
commutation_matrix_pre_square!(B, lookup)
236-
return B
237-
end
238-
239-
function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
240-
n2 = size(A, 1)
241-
n = Int(sqrt(n2))
242-
243-
indadd = (0:(n-1)) * n
244-
245-
Threads.@threads for i in 1:n
246-
for j in 1:n
247-
row = i + indadd[j]
248-
@views @inbounds B[row, :] .+= A[row, :]
249-
end
250-
end
251-
252-
return B
253-
end

src/loss/ML/FIML.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Analytic gradients are available.
2424
## Implementation
2525
Subtype of `SemLossFunction`.
2626
"""
27-
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
27+
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, W} <: SemLossFunction
2828
inverses::INV #preallocated inverses of imp_cov
2929
choleskys::C #preallocated choleskys
3030
logdets::L #logdets of implied covmats
@@ -37,7 +37,7 @@ mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
3737

3838
mult::T
3939

40-
commutation_indices::U
40+
commutator::CommutationMatrix
4141

4242
interaction::W
4343
end
@@ -64,8 +64,6 @@ function SemFIML(; observed, specification, kwargs...)
6464
∇ind =
6565
[findall(x -> !(x[1] ind || x[2] ind), ∇ind) for ind in patterns_not(observed)]
6666

67-
commutation_indices = get_commutation_lookup(get_n_nodes(specification)^2)
68-
6967
return SemFIML(
7068
inverses,
7169
choleskys,
@@ -75,7 +73,7 @@ function SemFIML(; observed, specification, kwargs...)
7573
meandiff,
7674
imp_inv,
7775
mult,
78-
commutation_indices,
76+
CommutationMatrix(get_n_nodes(specification)),
7977
nothing,
8078
)
8179
end
@@ -163,10 +161,9 @@ function ∇F_fiml_outer(JΣ, Jμ, imply, model, semfiml)
163161
Iₙ = sparse(1.0I, size(A(imply))...)
164162
P = kron(F⨉I_A⁻¹(imply), F⨉I_A⁻¹(imply))
165163
Q = kron(S(imply) * I_A⁻¹(imply)', Iₙ)
166-
#commutation_matrix_pre_square_add!(Q, Q)
167-
Q2 = commutation_matrix_pre_square(Q, semfiml.commutation_indices)
164+
Q .+= semfiml.commutator * Q
168165

169-
∇Σ = P * (∇S(imply) + (Q + Q2) * ∇A(imply))
166+
∇Σ = P * (∇S(imply) + Q * ∇A(imply))
170167

171168
∇μ =
172169
F⨉I_A⁻¹(imply) * ∇M(imply) +

0 commit comments

Comments
 (0)