@@ -5,49 +5,62 @@ using RecursiveFactorization
55using SparseBandedMatrices
66
77@inline exphalf (x) = exp (x) * oftype (x, 0.5 )
8- function 🦋 ! (wv, :: Val{SEED} = Val (888 )) where {SEED}
8+ function generate_rand_butterfly_vals ! (wv, :: Val{SEED} = Val (888 )) where {SEED}
99 T = eltype (wv)
1010 mrng = VectorizedRNG. MutableXoshift (SEED)
1111 GC. @preserve mrng begin rand! (exphalf, VectorizedRNG. Xoshift (mrng), wv, static (0 ),
1212 T (- 0.05 ), T (0.1 )) end
1313end
1414
1515function 🦋generate_random! (A, :: Val{SEED} = Val (888 )) where {SEED}
16- Usz = 2 * size (A, 1 )
17- Vsz = 2 * size (A, 2 )
18- uv = similar (A, Usz + Vsz)
19- 🦋! (uv, Val (SEED))
20- (uv,)
16+ uv = similar (A, 4 * size (A, 1 ))
17+ generate_rand_butterfly_vals! (uv, Val (SEED))
18+ uv
2119end
22-
23- function 🦋workspace (A, b, B:: Matrix{T} , U:: Adjoint{T, Matrix{T}} , V:: Matrix{T} , thread, :: Val{SEED} = Val (888 )) where {T, SEED}
24- M = size (A, 1 )
25- if (M % 4 != 0 )
26- A = pad! (A)
20+ struct 🦋workspace{T}
21+ A:: Matrix{T}
22+ b:: Vector{T}
23+ ws:: Vector{T}
24+ U:: Matrix{T}
25+ V:: Matrix{T}
26+ out:: Vector{T}
27+ function 🦋workspace (A, b, :: Val{SEED} = Val (888 )) where {SEED}
28+ M = size (A, 1 )
29+ out = similar (b, M)
30+ if (M % 4 != 0 )
31+ A = pad! (A)
32+ xn = 4 - M % 4
33+ b = [b; rand (xn)]
34+ end
35+ U, V = (similar (A), similar (A))
36+ ws = 🦋generate_random! (A)
37+ materializeUV (U, V, ws)
38+ new {eltype(A)} (A, b, ws, U, V, out)
2739 end
28- B = similar (A)
29- ws = 🦋generate_random! (copyto! (B, A))
30- 🦋mul! (copyto! (B, A), ws)
31- U, V = materializeUV (B, ws)
32- F = RecursiveFactorization. lu! (B, thread)
33- out = similar (b, M)
34-
35- U, V, F, out
40+ end
41+
42+ function 🦋lu! (workspace: :🦋workspace, M, thread)
43+ (;A, b, ws, U, V, out) = workspace
44+ 🦋mul! (A, ws)
45+ F = RecursiveFactorization. lu! (A, Val (false ), thread)
46+ sol = V * (F \ (U' * b))
47+ out .= @view sol[1 : M]
48+ out
3649end
3750
3851const butterfly_workspace = 🦋workspace;
3952
4053function 🦋mul_level! (A, u, v)
4154 M, N = size (A)
4255 @assert M == length (u) && N == length (v)
43- Mh = M >>> 1
44- Nh = N >>> 1
45- @turbo for n in 1 : Nh
46- for m in 1 : Mh
56+ M_half = M >>> 1
57+ N_half = N >>> 1
58+ @turbo for n in 1 : N_half
59+ for m in 1 : M_half
4760 A11 = A[m, n]
48- A21 = A[m + Mh , n]
49- A12 = A[m, n + Nh ]
50- A22 = A[m + Mh , n + Nh ]
61+ A21 = A[m + M_half , n]
62+ A12 = A[m, n + N_half ]
63+ A22 = A[m + M_half , n + N_half ]
5164
5265 T1 = A11 + A12
5366 T2 = A21 + A22
@@ -59,32 +72,32 @@ function 🦋mul_level!(A, u, v)
5972 C22 = T3 - T4
6073
6174 u1 = u[m]
62- u2 = u[m + Mh ]
75+ u2 = u[m + M_half ]
6376 v1 = v[n]
64- v2 = v[n + Nh ]
77+ v2 = v[n + N_half ]
6578
6679 A[m, n] = u1 * C11 * v1
67- A[m + Mh , n] = u2 * C21 * v1
68- A[m, n + Nh ] = u1 * C12 * v2
69- A[m + Mh , n + Nh ] = u2 * C22 * v2
80+ A[m + M_half , n] = u2 * C21 * v1
81+ A[m, n + N_half ] = u1 * C12 * v2
82+ A[m + M_half , n + N_half ] = u2 * C22 * v2
7083 end
7184 end
7285end
7386
74- function 🦋mul! (A, (uv,) )
87+ function 🦋mul! (A, uv )
7588 M, N = size (A)
7689 @assert M == N
77- Mh = M >>> 1
90+ M_half = M >>> 1
7891
79- U₁ = @view (uv[1 : Mh ])
80- V₁ = @view (uv[(Mh + 1 ): (M)])
81- U₂ = @view (uv[(1 + M): (M + Mh )])
82- V₂ = @view (uv[(1 + M + Mh ): (2 * M)])
92+ U₁ = @view (uv[1 : M_half ])
93+ V₁ = @view (uv[(M_half + 1 ): (M)])
94+ U₂ = @view (uv[(1 + M): (M + M_half )])
95+ V₂ = @view (uv[(1 + M + M_half ): (2 * M)])
8396
84- 🦋mul_level! (@view (A[1 : Mh , 1 : Mh ]), U₁, V₁)
85- 🦋mul_level! (@view (A[Mh + 1 : M, 1 : Mh ]), U₂, V₁)
86- 🦋mul_level! (@view (A[1 : Mh, Mh + 1 : M]), U₁, V₂)
87- 🦋mul_level! (@view (A[Mh + 1 : M, Mh + 1 : M]), U₂, V₂)
97+ 🦋mul_level! (@view (A[1 : M_half , 1 : M_half ]), U₁, V₁)
98+ 🦋mul_level! (@view (A[M_half + 1 : M, 1 : M_half ]), U₂, V₁)
99+ 🦋mul_level! (@view (A[1 : M_half, M_half + 1 : M]), U₁, V₂)
100+ 🦋mul_level! (@view (A[M_half + 1 : M, M_half + 1 : M]), U₂, V₂)
88101
89102 U = @view (uv[(1 + 2 * M): (3 * M)])
90103 V = @view (uv[(1 + 3 * M): (4 * M)])
@@ -106,7 +119,14 @@ function diagnegbottom(x)
106119 Diagonal (y), Diagonal (z)
107120end
108121
109- function 🦋2 !(C, A:: Diagonal , B:: Diagonal )
122+ function 🦋! (C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
123+ setdiagonal! (C, [A. diag; - B. diag], true )
124+ setdiagonal! (C, A. diag, true )
125+ setdiagonal! (C, B. diag, false )
126+ C
127+ end
128+
129+ function 🦋! (C, A:: Diagonal , B:: Diagonal )
110130 @assert size (A) == size (B)
111131 A1 = size (A, 1 )
112132
@@ -120,61 +140,35 @@ function 🦋2!(C, A::Diagonal, B::Diagonal)
120140 C
121141end
122142
123- function 🦋! (A:: Matrix , C:: SparseBandedMatrix , X:: Diagonal , Y:: Diagonal )
124- @assert size (X) == size (Y)
125- if (size (X, 1 ) + size (Y, 1 ) != size (A, 1 ))
126- x = size (A, 1 ) - size (X, 1 ) - size (Y, 1 )
127- setdiagonal! (C, [X. diag; rand (x); - Y. diag], true )
128- setdiagonal! (C, X. diag, true )
129- setdiagonal! (C, Y. diag, false )
130- else
131- setdiagonal! (C, [X. diag; - Y. diag], true )
132- setdiagonal! (C, X. diag, true )
133- setdiagonal! (C, Y. diag, false )
134- end
135-
136- C
137- end
138-
139- function 🦋2 !(C:: SparseBandedMatrix , A:: Diagonal , B:: Diagonal )
140- setdiagonal! (C, [A. diag; - B. diag], true )
141- setdiagonal! (C, A. diag, true )
142- setdiagonal! (C, B. diag, false )
143- C
144- end
145-
146- function materializeUV (A, (uv,))
147- M, N = size (A)
148- Mh = M >>> 1
149- Nh = N >>> 1
143+ function materializeUV (U, V, uv)
144+ M = size (U, 1 )
145+ M_half = M >>> 1
150146
151- U₁u, U₁l = diagnegbottom (@view (uv[1 : Mh ])) # Mh
152- U₂u, U₂l = diagnegbottom (@view (uv[(1 + Mh + Nh ): (M + Nh )])) # M2
153- V₁u, V₁l = diagnegbottom (@view (uv[(Mh + 1 ): (Mh + Nh )])) # Nh
154- V₂u, V₂l = diagnegbottom (@view (uv[(1 + 2 * Mh + Nh ): (2 * Mh + N )])) # N2
155- Uu, Ul = diagnegbottom (@view (uv[(1 + M + N ): (2 * M + N )])) # M
156- Vu, Vl = diagnegbottom (@view (uv[(1 + 2 * M + N ): (2 * M + 2 * N )])) # N
147+ U₁u, U₁l = diagnegbottom (@view (uv[1 : M_half ])) # M_half
148+ U₂u, U₂l = diagnegbottom (@view (uv[(1 + 2 * M_half ): (M + M_half )])) # M_half
149+ V₁u, V₁l = diagnegbottom (@view (uv[(M_half + 1 ): (2 * M_half )])) # M_half
150+ V₂u, V₂l = diagnegbottom (@view (uv[(1 + 3 * M_half ): (2 * M_half + M )])) # M_half
151+ Uu, Ul = diagnegbottom (@view (uv[(1 + 2 * M ): (3 * M)])) # M
152+ Vu, Vl = diagnegbottom (@view (uv[(1 + 3 * M): (4 * M)])) # M
157153
158- Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
154+ Bu2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
159155
160- 🦋2 !(view (Bu2, 1 : Mh , 1 : Nh ), U₁u, U₁l)
161- 🦋2 !(view (Bu2, Mh + 1 : M, Nh + 1 : N ), U₂u, U₂l)
156+ 🦋! (view (Bu2, 1 : M_half , 1 : M_half ), U₁u, U₁l)
157+ 🦋! (view (Bu2, M_half + 1 : M, M_half + 1 : M ), U₂u, U₂l)
162158
163- Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
164- 🦋! (A, Bu1, Uu, Ul)
159+ Bu1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
160+ 🦋! (Bu1, Uu, Ul)
165161
166- Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
162+ Bv2 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
167163
168- 🦋2 !(view (Bv2, 1 : Mh , 1 : Nh ), V₁u, V₁l)
169- 🦋2 !(view (Bv2, Mh + 1 : M, Nh + 1 : N ), V₂u, V₂l)
164+ 🦋! (view (Bv2, 1 : M_half , 1 : M_half ), V₁u, V₁l)
165+ 🦋! (view (Bv2, M_half + 1 : M, M_half + 1 : M ), V₂u, V₂l)
170166
171- Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, N )
172- 🦋! (A, Bv1, Vu, Vl)
167+ Bv1 = SparseBandedMatrix {typeof(uv[1])} (undef, M, M )
168+ 🦋! (Bv1, Vu, Vl)
173169
174- U = (Bu2 * Bu1)'
175- V = Bv2 * Bv1
176-
177- U, V
170+ mul! (U, Bu2, Bu1)
171+ mul! (V, Bv2, Bv1)
178172end
179173
180174function pad! (A)
0 commit comments