Skip to content

Commit 8afedcb

Browse files
Alexey Stukalovalyst
authored andcommitted
MissingPattern: transpose data
for faster EM MVN
1 parent 47df271 commit 8afedcb

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

src/observed/EM.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,20 @@ function em_mvn(
3232
)
3333
n_man = SEM.n_man(patterns[1])
3434

35-
### precompute for full cases
35+
# precompute for full cases
3636
𝔼x_full = zeros(n_man)
3737
𝔼xxᵀ_full = zeros(n_man, n_man)
38-
if nmissed_vars(patterns[1]) == 0
39-
fullpat = patterns[1]
40-
sum!(reshape(𝔼x_full, 1, n_man), fullpat.data)
41-
mul!(𝔼xxᵀ_full, fullpat.data', fullpat.data)
42-
else
43-
@warn "No full cases pattern found"
38+
nobs_full = 0
39+
for pat in patterns
40+
if nmissed_vars(pat) == 0
41+
𝔼x_full .+= sum(pat.data, dims = 2)
42+
mul!(𝔼xxᵀ_full, pat.data, pat.data', 1, 1)
43+
nobs_full += n_obs(pat)
44+
end
45+
end
46+
if nobs_full == 0
47+
@warn "No full cases in data"
4448
end
45-
46-
# ess = 𝔼x, 𝔼xxᵀ, ismissing, missingRows, n_obs
47-
# estepFn = (em_model, data) -> estep(em_model, data, EXsum, EXXsum, ismissing, missingRows, n_obs)
4849

4950
# initialize
5051
Σ₀, μ = start_em(patterns; kwargs...)
@@ -121,8 +122,8 @@ function em_step!(
121122
𝔼xxᵀuo = fill!(similar(Σuo), 0)
122123
𝔼xxᵀuu = n_obs(pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo'))
123124

124-
# loop trough data
125-
@inbounds for rowdata in eachrow(pat.data)
125+
# loop through observations
126+
@inbounds for rowdata in eachcol(pat.data)
126127
mul!(𝔼xᵢu, Σuo, Σoo_chol \ (rowdata - μo))
127128
𝔼xᵢu .+= μu
128129
mul!(𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu', 1, 1)

src/observed/missing_pattern.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct SemObservedMissingPattern{T, S}
55
nobserved::Int
66
nmissed::Int
77
rows::Vector{Int} # rows in original data
8-
data::Matrix{T} # non-missing submatrix of data
8+
data::Matrix{T} # non-missing submatrix of data (vars × observations)
99

1010
obs_mean::Vector{S} # means of observed vars
1111
obs_cov::Symmetric{S, Matrix{S}} # covariance of observed vars
@@ -35,7 +35,7 @@ function SemObservedMissingPattern(
3535
sum(obs_mask),
3636
sum(miss_mask),
3737
rows,
38-
pat_data,
38+
permutedims(pat_data),
3939
dropdims(pat_mean, dims = 1),
4040
Symmetric(pat_cov),
4141
)

0 commit comments

Comments
 (0)