diff --git a/Project.toml b/Project.toml index 5876dac72..afe564e53 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" +NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" diff --git a/src/Dagger.jl b/src/Dagger.jl index 7eea42df3..b29254d5d 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -8,7 +8,7 @@ import MemPool import MemPool: DRef, FileRef, poolget, poolset import Base: collect, reduce, view - +import NextLA import LinearAlgebra import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LU, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, Cholesky, diagind, ishermitian, issymmetric, I import Random @@ -127,6 +127,7 @@ include("array/mul.jl") include("array/cholesky.jl") include("array/trsm.jl") include("array/lu.jl") +include("array/qr.jl") # GPU include("gpu.jl") diff --git a/src/array/linalg.jl b/src/array/linalg.jl index 278a3244a..3bf8e20e0 100644 --- a/src/array/linalg.jl +++ b/src/array/linalg.jl @@ -105,8 +105,10 @@ function LinearAlgebra.LAPACK.chkfinite(A::DArray) end DMatrix{T}(::LinearAlgebra.UniformScaling, m::Int, n::Int, IBlocks::Blocks) where T = DMatrix(Matrix{T}(I, m, n), IBlocks) +DMatrix(::LinearAlgebra.UniformScaling{T}, m::Int, n::Int, IBlocks::Blocks) where T = DMatrix(Matrix{T}(I, m, n), IBlocks) DMatrix{T}(::LinearAlgebra.UniformScaling, size::Tuple, IBlocks::Blocks) where T = DMatrix(Matrix{T}(I, size), IBlocks) +DMatrix(::LinearAlgebra.UniformScaling{T}, size::Tuple, IBlocks::Blocks) where T = DMatrix(Matrix{T}(I, size), IBlocks) function LinearAlgebra.inv(F::LU{T,<:DMatrix}) where T n = size(F, 1) diff --git a/src/array/qr.jl b/src/array/qr.jl new file mode 100644 index 000000000..163a2a893 --- /dev/null +++ b/src/array/qr.jl @@ -0,0 +1,796 @@ +import LinearAlgebra: QRCompactWY, AdjointQ, BlasFloat, QRCompactWYQ, AbstractQ, StridedVecOrMat, I +# Maps Tm DArray → p (CAQR domain count) so that porgqr!/pormqr! can +# reconstruct the butterfly tree without requiring p as an argument. +# The algorithms used in this implementation are based on https://www.netlib.org/lapack/lawnspdf/lawn222.pdf +const _CAQR_P_MAP = WeakKeyDict{DArray, Int}() + +Base.:(*)(Q::QRCompactWYQ{T, M}, b::Number) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b +Base.:(*)(b::Number, Q::QRCompactWYQ{T, M}) where {T<:Number, M<:DMatrix{T}} = DMatrix(Q) * b + +Base.:(*)(Q::AdjointQ{T, QRCompactWYQ{T, M, C}}, b::Number) where {T<:Number, M<:DMatrix{T}, C<:M} = DMatrix(Q) * b +Base.:(*)(b::Number, Q::AdjointQ{T, QRCompactWYQ{T, M, C}}) where {T<:Number, M<:DMatrix{T}, C<:M} = DMatrix(Q) * b + +LinearAlgebra.lmul!(B::QRCompactWYQ{T, <:DMatrix{T}}, A::DMatrix{T}) where {T} = pormqr!('L', 'N', B.factors, B.T, A) +function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, <:DMatrix{T}}}, A::DMatrix{T}) where {T} + trans = T <: Complex ? 'C' : 'T' + pormqr!('L', trans, B.Q.factors, B.Q.T, A) +end +function _apply_dense_qr!(side::Char, trans::Char, Q::QRCompactWYQ{T, <:DMatrix{T}}, M::AbstractMatrix{T}) where {T} + # Distribute M with block sizes compatible with the Q factorization. + # _repartition_pormqr will further adjust, but we need a sane starting point. + bs = Q.factors.partitioning.blocksize # (mb, nb) + if side == 'L' + Cd = distribute(M, Blocks(bs[1], bs[1])) + else + Cd = distribute(M, Blocks(bs[2], bs[2])) + end + pormqr!(side, trans, Q.factors, Q.T, Cd) + copyto!(M, collect(Cd)) + return M +end +LinearAlgebra.lmul!(B::QRCompactWYQ{T, <:DMatrix{T}}, A::StridedVecOrMat{T}) where {T} = _apply_dense_qr!('L', 'N', B, A) +function LinearAlgebra.lmul!(B::AdjointQ{T, <:QRCompactWYQ{T, <:DMatrix{T}}}, A::StridedVecOrMat{T}) where {T} + trans = T <: Complex ? 'C' : 'T' + return _apply_dense_qr!('L', trans, B.Q, A) +end + +LinearAlgebra.rmul!(A::DMatrix{T}, B::QRCompactWYQ{T, <:DMatrix{T}}) where {T} = pormqr!('R', 'N', B.factors, B.T, A) +function LinearAlgebra.rmul!(A::DMatrix{T}, B::AdjointQ{T, <:QRCompactWYQ{T, <:DMatrix{T}}}) where {T} + trans = T <: Complex ? 'C' : 'T' + pormqr!('R', trans, B.Q.factors, B.Q.T, A) +end +LinearAlgebra.rmul!(A::StridedVecOrMat{T}, B::QRCompactWYQ{T, <:DMatrix{T}}) where {T} = _apply_dense_qr!('R', 'N', B, A) +function LinearAlgebra.rmul!(A::StridedVecOrMat{T}, B::AdjointQ{<:Any, <:QRCompactWYQ{T, <:DMatrix{T}, <:DMatrix{T}}}) where {T<:Union{Float32,Float64}} + trans = T <: Complex ? 'C' : 'T' + return _apply_dense_qr!('R', trans, B.Q, A) +end +function LinearAlgebra.rmul!(A::StridedVecOrMat{T}, B::AdjointQ{<:Any, <:QRCompactWYQ{T, <:DMatrix{T}, <:DMatrix{T}}}) where {T<:Union{ComplexF32,ComplexF64}} + trans = T <: Complex ? 'C' : 'T' + return _apply_dense_qr!('R', trans, B.Q, A) +end + +""" + _infer_caqr_p(A, Tm) -> Int + +Infer the CAQR domain count `p` from the reflector matrix `A` and the +T-factor matrix `Tm`. First checks the `_CAQR_P_MAP` cache; falls back +to heuristic detection from Tm column count. +""" +function _infer_caqr_p(A::DMatrix, Tm::DMatrix) + # Check the explicit cache first + haskey(_CAQR_P_MAP, Tm) && return _CAQR_P_MAP[Tm] + Ant = size(A.chunks, 2) + Tnt = size(Tm.chunks, 2) + Tnt == Ant && return 1 # no tree T → flat QR + return 1 # fallback — should not reach here if qr! populated the cache +end + +@inline function _tile_len(cum::AbstractVector{<:Integer}, idx::Int) + prev = idx == 1 ? 0 : cum[idx-1] + return cum[idx] - prev +end + +function _tile_lens(cum::AbstractVector{<:Integer}) + out = Vector{Int}(undef, length(cum)) + prev = 0 + for i in eachindex(cum) + out[i] = cum[i] - prev + prev = cum[i] + end + return out +end + +@inline function _locate_cut(cum::AbstractVector{<:Integer}, pos::Int) + idx = searchsortedfirst(cum, pos) + prev = idx == 1 ? 0 : cum[idx-1] + return idx, pos - prev +end + +function _is_uniform_square_tiling(A::DMatrix) + rcum = A.subdomains.cumlength[1] + ccum = A.subdomains.cumlength[2] + (isempty(rcum) || isempty(ccum)) && return false + r = _tile_lens(rcum) + c = _tile_lens(ccum) + return all(==(r[1]), r) && all(==(c[1]), c) && r[1] == c[1] +end + +@inline _use_irregular_qr_tiling(A::DMatrix) = !_is_uniform_square_tiling(A) + +function _largest_square_redistribution_block(A::DMatrix) + m, n = size(A) + mb, nb = A.partitioning.blocksize + lim = min(m, n, mb, nb) + lim <= 1 && return nothing + + g = gcd(m, n) + b = min(g, lim) + while b > 1 + if g % b == 0 + return b + end + b -= 1 + end + return nothing +end + +function _panel_steps(A::DMatrix) + rcum = A.subdomains.cumlength[1] + ccum = A.subdomains.cumlength[2] + m, n = size(A) + + steps = Vector{NTuple{7, Int}}() + dpos = 1 + cpos = 1 + while cpos <= n && dpos <= m + rd, lr = _locate_cut(rcum, dpos) + kc, lc = _locate_cut(ccum, cpos) + row_rem = rcum[rd] - dpos + 1 + col_rem = ccum[kc] - cpos + 1 + b = min(row_rem, col_rem, m - dpos + 1) + b > 0 || break + rend = _tile_len(rcum, rd) + cend = lc + b - 1 + push!(steps, (rd, kc, lr, rend, lc, cend, b)) + dpos += b + cpos += b + end + return steps, rcum, ccum +end + +function _geqrf_irregular!(A::DMatrix{T}, Tm::DMatrix{T}) where {T<:Number} + Ac = A.chunks + Tc = Tm.chunks + mt, nt = size(Ac) + trans = T <: Complex ? 'C' : 'T' + steps, _, ccum = _panel_steps(A) + + spawn_datadeps() do + for (rd, kc, lr, rend, lc, cend, b) in steps + Av = view(Ac[rd, kc], lr:rend, lc:cend) + Tv = view(Tc[rd, kc], :, lc:cend) + Dagger.@spawn NextLA.geqrt!(InOut(Av), InOut(Tv)) + + for n in kc:nt + tc1 = n == kc ? cend + 1 : 1 + tc2 = _tile_len(ccum, n) + tc1 <= tc2 || continue + Av = view(Ac[rd, kc], lr:rend, lc:cend) + Tv = view(Tc[rd, kc], :, lc:cend) + Cv = view(Ac[rd, n], lr:rend, tc1:tc2) + Dagger.@spawn NextLA.unmqr!('L', trans, In(Av), In(Tv), InOut(Cv)) + end + + for m in rd+1:mt + A1 = view(Ac[rd, kc], lr:lr+b-1, lc:cend) + A2 = view(Ac[m, kc], :, lc:cend) + Tv = view(Tc[m, kc], :, lc:cend) + Dagger.@spawn NextLA.tsqrt!(InOut(A1), InOut(A2), InOut(Tv)) + for n in kc:nt + tc1 = n == kc ? cend + 1 : 1 + tc2 = _tile_len(ccum, n) + tc1 <= tc2 || continue + C1 = view(Ac[rd, n], lr:lr+b-1, tc1:tc2) + C2 = view(Ac[m, n], :, tc1:tc2) + V = view(Ac[m, kc], :, lc:cend) + Tv = view(Tc[m, kc], :, lc:cend) + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(C1), InOut(C2), In(V), In(Tv)) + end + end + end + end + return A +end + +function _porgqr_irregular!(trans::Char, A::DMatrix{T}, Tm::DMatrix{T}, Q::DMatrix{T}) where {T<:Number} + Ac = A.chunks + Tc = Tm.chunks + Qc = Q.chunks + mt, _ = size(Ac) + qmt, qnt = size(Qc) + qmt == mt || throw(ArgumentError("Q row tiling must match A row tiling for irregular QR")) + steps, _, _ = _panel_steps(A) + + spawn_datadeps() do + if trans == 'N' + for (rd, kc, lr, rend, lc, cend, b) in Iterators.reverse(steps) + for m in qmt:-1:rd+1, n in 1:qnt + C1 = view(Qc[rd, n], lr:lr+b-1, :) + C2 = view(Qc[m, n], :, :) + V = view(Ac[m, kc], :, lc:cend) + Tv = view(Tc[m, kc], :, lc:cend) + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(C1), InOut(C2), In(V), In(Tv)) + end + for n in 1:qnt + Av = view(Ac[rd, kc], lr:rend, lc:cend) + Tv = view(Tc[rd, kc], :, lc:cend) + Cv = view(Qc[rd, n], lr:rend, :) + Dagger.@spawn NextLA.unmqr!('L', trans, In(Av), In(Tv), InOut(Cv)) + end + end + else + for (rd, kc, lr, rend, lc, cend, b) in steps + for n in 1:qnt + Av = view(Ac[rd, kc], lr:rend, lc:cend) + Tv = view(Tc[rd, kc], :, lc:cend) + Cv = view(Qc[rd, n], lr:rend, :) + Dagger.@spawn NextLA.unmqr!('L', trans, In(Av), In(Tv), InOut(Cv)) + end + for m in rd+1:qmt, n in 1:qnt + C1 = view(Qc[rd, n], lr:lr+b-1, :) + C2 = view(Qc[m, n], :, :) + V = view(Ac[m, kc], :, lc:cend) + Tv = view(Tc[m, kc], :, lc:cend) + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(C1), InOut(C2), In(V), In(Tv)) + end + end + end + end + return Q +end + +function _pormqr_irregular!(side::Char, trans::Char, A::DMatrix{T}, Tm::DMatrix{T}, C::DMatrix{T}) where {T<:Number} + Q = DMatrix(QRCompactWYQ(A, Tm)) + Qop = + trans == 'N' ? Q : + trans == 'T' ? transpose(Q) : + trans == 'C' ? adjoint(Q) : + throw(ArgumentError("trans must be 'N', 'T', or 'C', got '$trans'")) + + updated = side == 'L' ? Qop * C : + side == 'R' ? C * Qop : + throw(ArgumentError("side must be 'L' or 'R', got '$side'")) + copyto!(C, updated) + return C +end + +function DMatrix(Q::QRCompactWYQ{T, <:DMatrix{T}}) where {T} + DQ = DMatrix(I*one(T), size(Q), Q.factors.partitioning) + p = _infer_caqr_p(Q.factors, Q.T) + porgqr!('N', Q.factors, Q.T, DQ; p=p) + return DQ +end + +function DMatrix(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:DMatrix{T}}}) where {T} + DQ = DMatrix(I*one(T), size(AQ), AQ.Q.factors.partitioning) + trans = T <: Complex ? 'C' : 'T' + p = _infer_caqr_p(AQ.Q.factors, AQ.Q.T) + porgqr!(trans, AQ.Q.factors, AQ.Q.T, DQ; p=p) + return DQ +end + +Base.collect(Q::QRCompactWYQ{T, <:DMatrix{T}}) where {T} = collect(DMatrix(Q)) +Base.collect(AQ::AdjointQ{T, <:QRCompactWYQ{T, <:DMatrix{T}}}) where {T} = collect(DMatrix(AQ)) + +function _repartition_pormqr(A, Tm, C, side::Char, trans::Char) + partA = A.partitioning.blocksize + partTm = Tm.partitioning.blocksize + partC = C.partitioning.blocksize + + # The pormqr! kernels assume that the number of row tiles (index k) + # matches between the reflector matrix A and the target matrix C. + # Adjust C's block size accordingly but avoid reshaping A or Tm, + # as their chunking encodes the factorisation structure. + partC_new = partC + if side == 'L' + # Q * C or Q' * C: C's row blocking must match A's row blocking. + partC_new = (partA[1], partC[2]) + else + # C * Q or C * Q': Align C's column blocking with A's column blocking + # so that the k index traverses compatible tile widths. + partC_new = (partC[1], partA[2]) + end + + return Blocks(partA...), Blocks(partTm...), Blocks(partC_new...) +end + +function pormqr!(side::Char, trans::Char, A::DMatrix{T}, Tm::DMatrix{T}, C::DMatrix{T}; p::Int=_infer_caqr_p(A, Tm)) where {T<:Number} + if _use_irregular_qr_tiling(A) + return _pormqr_irregular!(side, trans, A, Tm, C) + end + + partA, partTm, partC = _repartition_pormqr(A, Tm, C, side, trans) + + return maybe_copy_buffered(A=>partA, Tm=>partTm, C=>partC) do A, Tm, C + return _pormqr_impl!(side, trans, A, Tm, C; p=p) + end +end + +function _pormqr_impl!(side::Char, trans::Char, A::DMatrix{T}, Tm::DMatrix{T}, C::DMatrix{T}; p::Int=1) where {T<:Number} + m, n = size(C) + Ac = A.chunks + Tc = Tm.chunks + Cc = C.chunks + + Amt, Ant = size(Ac) + Tmt, Tnt = size(Tc) + Cmt, Cnt = size(Cc) + minMT = min(Amt, Ant) + has_tree = Tnt > Ant # CAQR tree T present + Tc2_offset = has_tree ? Tnt ÷ 2 : 0 + mtd = p > 1 ? Amt ÷ p : Amt # tiles per domain + + # ordering rule: + # tree_first = (side == Left) == (trans == NoTrans) + # When tree_first, apply ttmqr (descend: root→leaves) BEFORE local ops; + # otherwise, apply local ops first, then ttmqr (ascend: leaves→root). + + spawn_datadeps() do + if side == 'L' + if trans == 'T' || trans == 'C' + # Left, ConjTrans: unmqr first, then ttmqr ascending. k forward. + for k in 1:minMT + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # Domain-local reflectors: unmqr then tsmqr within each domain + for pt in proot:p + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for n in 1:Cnt + Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Cc[ibeg,n])) + end + for m in ibeg+1:iend, n in 1:Cnt + Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[ibeg, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + + # Tree TT reduction (ascend: leaves→root) + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in 1:max_level + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in 1:Cnt + Dagger.@spawn NextLA.ttmqr!(side, trans, InOut(Cc[i1, n]), InOut(Cc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + end + end + if trans == 'N' + # Left, NoTrans: ttmqr descending first, then tsmqr+unmqr. k reverse. + for k in minMT:-1:1 + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # Tree TT reduction (descend: root→leaves) + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in max_level:-1:1 + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in 1:Cnt + Dagger.@spawn NextLA.ttmqr!(side, trans, InOut(Cc[i1, n]), InOut(Cc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + + # Domain-local reflectors: tsmqr then unmqr within each domain (reverse) + for pt in p:-1:proot + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for m in iend:-1:ibeg+1, n in 1:Cnt + Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[ibeg, n]), InOut(Cc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in 1:Cnt + Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Cc[ibeg, n])) + end + end + end + end + else + if side == 'R' + nmax = min(Cnt, Ant) + if trans == 'T' || trans == 'C' + # Right, ConjTrans: ttmqr descending first, then tsmqr+unmqr. k reverse. + for k in minMT:-1:1 + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # Tree TT reduction (descend: root→leaves) + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in max_level:-1:1 + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for m in 1:Cmt + Dagger.@spawn NextLA.ttmqr!(side, trans, InOut(Cc[m, i1]), InOut(Cc[m, i2]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + + # Domain-local reflectors: tsmqr then unmqr within each domain (reverse) + for pt in p:-1:proot + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for n in iend:-1:ibeg+1 + for m in 1:Cmt + Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[m, ibeg]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + end + for m in 1:Cmt + Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Cc[m, ibeg])) + end + end + end + end + if trans == 'N' + # Right, NoTrans: unmqr+tsmqr first, then ttmqr ascending. k forward. + for k in 1:minMT + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # Domain-local reflectors: unmqr then tsmqr within each domain + for pt in proot:p + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for m in 1:Cmt + Dagger.@spawn NextLA.unmqr!(side, trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Cc[m, ibeg])) + end + for n in ibeg+1:iend + for m in 1:Cmt + Dagger.@spawn NextLA.tsmqr!(side, trans, InOut(Cc[m, ibeg]), InOut(Cc[m, n]), In(Ac[n, k]), In(Tc[n, k])) + end + end + end + + # Tree TT reduction (ascend: leaves→root) + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in 1:max_level + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for m in 1:Cmt + Dagger.@spawn NextLA.ttmqr!(side, trans, InOut(Cc[m, i1]), InOut(Cc[m, i2]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + end + end + end + end + end + return C +end + + +function cageqrf!(A::DMatrix{T}, Tm::DMatrix{T}; p::Int=1) where {T<: Number} + if p == 1 + return geqrf!(A, Tm) + end + + _use_irregular_qr_tiling(A) && throw(ArgumentError("p > 1 requires uniform square tiling")) + + Ac = A.chunks + mt, nt = size(Ac) + @assert mt % p == 0 "Number of tiles must be divisible by the number of domains" + mtd = Int(mt/p) + Tc = Tm.chunks + # Tree reduction T matrices are stored in the second half of Tm's + # columns (columns nt+1 : 2*nt), + Tc2_offset = size(Tc, 2) ÷ 2 # == nt + proot = 1 + nxtmt = mtd + trans = T <: Complex ? 'C' : 'T' + spawn_datadeps() do + for k in 1:min(mt, nt) + if k > nxtmt + proot += 1 + nxtmt += mtd + end + for pt in proot:p + ibeg = 1 + (pt-1) * mtd + if pt == proot + ibeg = k + end + Dagger.@spawn NextLA.geqrt!(InOut(Ac[ibeg, k]), Out(Tc[ibeg,k])) + for n in k+1:nt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg,k]), InOut(Ac[ibeg, n])) + end + for m in ibeg+1:(pt * mtd) + Dagger.@spawn NextLA.tsqrt!(InOut(Ac[ibeg, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Ac[ibeg, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + for m in 1:ceil(Int64, log2(p-proot+1)) + p1 = proot + p2 = p1 + 2^(m-1) + while p2 ≤ p + i1 = 1 + (p1-1) * mtd + i2 = 1 + (p2-1) * mtd + if p1 == proot + i1 = k + end + Dagger.@spawn NextLA.ttqrt!(InOut(Ac[i1, k]), InOut(Ac[i2, k]), Out(Tc[i2, k + Tc2_offset])) + for n in k+1:nt + Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Ac[i1, n]), InOut(Ac[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m + p2 += 2^m + end + end + end + end +end + +function geqrf!(A::DMatrix{T}, Tm::DMatrix{T}) where {T<: Number} + if _use_irregular_qr_tiling(A) + return _geqrf_irregular!(A, Tm) + end + + Ac = A.chunks + mt, nt = size(Ac) + Tc = Tm.chunks + trans = T <: Complex ? 'C' : 'T' + + spawn_datadeps() do + for k in 1:min(mt, nt) + Dagger.@spawn NextLA.geqrt!(InOut(Ac[k, k]), Out(Tc[k,k])) + for n in k+1:nt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[k,k]), In(Tc[k,k]), InOut(Ac[k, n])) + end + for m in k+1:mt + Dagger.@spawn NextLA.tsqrt!(InOut(Ac[k, k]), InOut(Ac[m, k]), Out(Tc[m,k])) + for n in k+1:nt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Ac[k, n]), InOut(Ac[m, n]), In(Ac[m, k]), In(Tc[m,k])) + end + end + end + end +end + +function porgqr!(trans::Char, A::DMatrix{T}, Tm::DMatrix{T}, Q::DMatrix{T}; p::Int=_infer_caqr_p(A, Tm)) where {T<:Number} + if _use_irregular_qr_tiling(A) + return _porgqr_irregular!(trans, A, Tm, Q) + end + + Ac = A.chunks + Tc = Tm.chunks + Qc = Q.chunks + mt, nt = size(Ac) + qmt, qnt = size(Qc) + has_tree = size(Tc, 2) > nt # CAQR tree T present + Tc2_offset = has_tree ? size(Tc, 2) ÷ 2 : 0 + mtd = p > 1 ? mt ÷ p : mt # tiles per domain + + spawn_datadeps() do + if has_tree && p > 1 + if trans == 'N' + # Build Q with CAQR tree: + # 1) apply tree nodes (root->leaves), 2) apply per-domain local blocks. + for k in min(mt, nt):-1:1 + proot = ((k - 1) ÷ mtd) + 1 + + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in max_level:-1:1 + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in k:qnt + Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Qc[i1, n]), InOut(Qc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + + # Local reflectors are domain-local (not global k+1:mt). + # Column range must start at k (not ibeg) because the + # preceding tree ttmqr step already coupled rows from + # different domains, making Q[ibeg, k:ibeg-1] nonzero. + for pt in p:-1:proot + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for m in iend:-1:ibeg+1, n in k:qnt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[ibeg, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in k:qnt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Qc[ibeg, n])) + end + end + end + else + # Build Q' with CAQR tree: + # 1) apply per-domain local blocks, 2) apply tree nodes (leaves->root). + for k in 1:min(mt, nt) + proot = ((k - 1) ÷ mtd) + 1 + + for pt in proot:p + ibeg = pt == proot ? k : 1 + (pt - 1) * mtd + iend = pt * mtd + for n in 1:qnt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[ibeg, k]), In(Tc[ibeg, k]), InOut(Qc[ibeg, n])) + end + for m in ibeg+1:iend, n in 1:qnt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[ibeg, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + end + + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in 1:max_level + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in 1:qnt + Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Qc[i1, n]), InOut(Qc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + end + return + end + + if trans == 'N' + # Building Q: Left/NoTrans. ordering: ttmqr first (descend), + # then local tsmqr + unmqr. Sweep k in reverse. + for k in min(mt, nt):-1:1 + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # --- Tree TT reduction (descend: root→leaves) --- + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in max_level:-1:1 + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in k:qnt + Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Qc[i1, n]), InOut(Qc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + + # --- Local: tsmqr then unmqr (per domain, reverse) --- + for m in qmt:-1:k + 1, n in k:qnt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + for n in k:qnt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + end + else + # Building Q': Left/ConjTrans. ordering: local unmqr + tsmqr + # first, then ttmqr (ascend: leaves→root). Sweep k forward. + for k in 1:min(mt, nt) + proot = p > 1 ? ((k - 1) ÷ mtd) + 1 : 1 + + # --- Local: unmqr then tsmqr --- + for n in 1:qnt + Dagger.@spawn NextLA.unmqr!('L', trans, In(Ac[k, k]), + In(Tc[k, k]), InOut(Qc[k, n])) + end + for m in k+1:qmt, n in 1:qnt + Dagger.@spawn NextLA.tsmqr!('L', trans, InOut(Qc[k, n]), InOut(Qc[m, n]), In(Ac[m, k]), In(Tc[m, k])) + end + + # --- Tree TT reduction (ascend: leaves→root) --- + if has_tree && p > 1 + max_level = ceil(Int, log2(p - proot + 1)) + for m_level in 1:max_level + p1 = proot + p2 = p1 + 2^(m_level - 1) + while p2 ≤ p + i1 = 1 + (p1 - 1) * mtd + i2 = 1 + (p2 - 1) * mtd + if p1 == proot + i1 = k + end + for n in 1:qnt + Dagger.@spawn NextLA.ttmqr!('L', trans, InOut(Qc[i1, n]), InOut(Qc[i2, n]), In(Ac[i2, k]), In(Tc[i2, k + Tc2_offset])) + end + p1 += 2^m_level + p2 += 2^m_level + end + end + end + end + end + end +end + +function qr_measure_workspace(A::DMatrix{T}, ib::Int) where {T<: Number} + mb, nb = A.partitioning.blocksize + m, n = size(A) + MT = cld(m, mb) + NT = cld(n, nb) + lm = ib * MT + ln = nb * NT + lm, ln +end + +function _qr_impl!(A::DMatrix{T}; ib::Int=1, p::Int=1) where {T<:Number} + p >= 1 || throw(ArgumentError("p must be >= 1, got $p")) + lm, ln = qr_measure_workspace(A, ib) + nb = A.partitioning.blocksize[2] + irregular = _use_irregular_qr_tiling(A) + + Tm_cols = p > 1 ? 2 * ln : ln + Tm = DArray{T}(undef, Blocks(ib, nb), (lm, Tm_cols)) + + if irregular + # Hierarchical irregular path (no CAQR tree metadata). + _geqrf_irregular!(A, Tm) + elseif p == 1 + geqrf!(A, Tm) + else + _CAQR_P_MAP[Tm] = p + cageqrf!(A, Tm; p=p) + end + return QRCompactWY(A, Tm) +end + +function LinearAlgebra.qr!(A::DMatrix{T}; ib::Int=1, p::Int=1) where {T<:Number} + p >= 1 || throw(ArgumentError("p must be >= 1, got $p")) + + if !_use_irregular_qr_tiling(A) + return _qr_impl!(A; ib=ib, p=p) + end + + # Non-square or non-uniform tiling: try to repartition to uniform square + # tiles first. If impossible, run the irregular hierarchical path. + b = _largest_square_redistribution_block(A) + if b !== nothing + part = Blocks(b, b) + return maybe_copy_buffered(A=>part) do Abuf + _qr_impl!(Abuf; ib=ib, p=p) + end + end + + return _qr_impl!(A; ib=ib, p=1) +end diff --git a/test/Project.toml b/test/Project.toml index 72daddfb3..af2d30b14 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" diff --git a/test/array/linalg/qr.jl b/test/array/linalg/qr.jl new file mode 100644 index 000000000..ffa55f005 --- /dev/null +++ b/test/array/linalg/qr.jl @@ -0,0 +1,573 @@ +import Dagger: geqrf!, porgqr!, pormqr!, cageqrf! + +# ────────────────────────────────────────────────────────────────────── +# Helper: run residual checks on a QR factorization result +# ────────────────────────────────────────────────────────────────────── +function check_qr_residuals(A_col, DQ, DR; tol=2.0) + T = eltype(A_col) + eps_val = eps(real(T)) + + R = collect(DR) + QR = collect(DQ * DR) + CDI = collect(DQ' * DQ) + + res1 = opnorm(A_col - QR, 1) / (opnorm(A_col, 1) * max(size(A_col)...) * eps_val) + res2 = opnorm(I - CDI, 1) / (size(A_col, 2) * eps_val) + + @test res1 < tol # Factorization residual + @test res2 < tol # Orthogonality residual + @test triu(R) ≈ R # R is upper triangular +end + +# ====================================================================== +# 1. Basic qr(DA) — varying shapes and block sizes (p=1, ib=1) +# ====================================================================== +@testset "Tile QR: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + @testset "Square matrices" begin + @testset "blocks=$bs" for bs in [(32,32), (16,32), (32,16)] + A = rand(T, 128, 128) + DA = distribute(A, Blocks(bs...)) + DQ, DR = qr(DA) + check_qr_residuals(collect(DA), DQ, DR) + end + end + + @testset "Tall matrices" begin + @testset "blocks=$bs" for bs in [(32,32), (16,32), (32,16)] + A = rand(T, 128, 64) + DA = distribute(A, Blocks(bs...)) + DQ, DR = qr(DA) + check_qr_residuals(collect(DA), DQ, DR) + end + end + + @testset "Wide matrices" begin + @testset "blocks=$bs" for bs in [(32,32), (16,32), (32,16)] + A = rand(T, 64, 128) + DA = distribute(A, Blocks(bs...)) + DQ, DR = qr(DA) + check_qr_residuals(collect(DA), DQ, DR) + end + end +end + +# ====================================================================== +# 2. In-place qr! (ensures qr! path is exercised directly) +# ====================================================================== +@testset "In-place qr!: $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32,32)) + DA_copy = copy(DA) + F = qr!(DA_copy) + + # F is a QRCompactWY; extract Q and R + DQ_mat = DMatrix(F.Q) + DR = DArray{T}( undef, DA_copy.partitioning, size(DA_copy)) + # Extract upper triangle from the factored DA_copy into DR + R_local = triu(collect(DA_copy)) + DR = distribute(R_local, DA_copy.partitioning) + + QR = collect(DQ_mat) * R_local + eps_val = eps(real(T)) + res = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + @test res < 2.0 +end + +# ====================================================================== +# 3. CAQR with p > 1 (communication-avoiding QR) +# ====================================================================== +@testset "CAQR p=$p: $T" for T in (Float64, ComplexF64), + p in (2, 4) + # mt must be divisible by p; 128/32 = 4 tiles → p ∈ {2, 4} work + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + DQ, DR = F.Q, triu(collect(DA_copy)) + + DQ_mat = collect(DMatrix(DQ)) + QR = DQ_mat * DR + eps_val = eps(real(T)) + res1 = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + CDI = DQ_mat' * DQ_mat + res2 = opnorm(I - CDI, 1) / (size(A, 2) * eps_val) + + @test res1 < 2.0 + @test res2 < 2.0 + @test triu(DR) ≈ DR +end + +# ====================================================================== +# 4. pormqr! — left and right multiply by Q and Q' +# ====================================================================== +@testset "pormqr! side=$side trans=$trans: $T" for T in (Float64, ComplexF64), + side in ('L', 'R'), + trans in ('N', T <: Complex ? 'C' : 'T') + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32,32)) + F = qr(DA) + Q_dense = collect(F.Q) + + if side == 'L' + # Q * C or Q' * C + C = rand(T, 128, 64) + DC = distribute(C, Blocks(32, 32)) + pormqr!(side, trans, F.Q.factors, F.Q.T, DC) + Qt = trans == 'N' ? Q_dense : Q_dense' + expected = Qt * C + else + # C * Q or C * Q' + C = rand(T, 64, 128) + DC = distribute(C, Blocks(32, 32)) + pormqr!(side, trans, F.Q.factors, F.Q.T, DC) + Qt = trans == 'N' ? Q_dense : Q_dense' + expected = C * Qt + end + + eps_val = eps(real(T)) + res = opnorm(collect(DC) - expected, 1) / (opnorm(expected, 1) * max(size(C)...) * eps_val) + @test res < 2.0 +end + +# ====================================================================== +# 5. lmul! / rmul! dispatch (Q*DA, Q'*DA, DA*Q, DA*Q') +# ====================================================================== +@testset "lmul!/rmul!: $T" for T in (Float64, ComplexF64) + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32,32)) + F = qr(DA) + Q_dense = collect(F.Q) + trans = T <: Complex ? 'C' : 'T' + + # lmul! Q * B (DMatrix target) + @testset "lmul! Q*DMatrix" begin + B = rand(T, 128, 64) + DB = distribute(B, Blocks(32, 32)) + lmul!(F.Q, DB) + @test collect(DB) ≈ Q_dense * B + end + + # lmul! Q' * B (DMatrix target) + @testset "lmul! Q'*DMatrix" begin + B = rand(T, 128, 64) + DB = distribute(B, Blocks(32, 32)) + lmul!(F.Q', DB) + @test collect(DB) ≈ Q_dense' * B + end + + # rmul! B * Q (DMatrix target) + @testset "rmul! DMatrix*Q" begin + B = rand(T, 64, 128) + DB = distribute(B, Blocks(32, 32)) + rmul!(DB, F.Q) + @test collect(DB) ≈ B * Q_dense + end + + # rmul! B * Q' (DMatrix target) + @testset "rmul! DMatrix*Q'" begin + B = rand(T, 64, 128) + DB = distribute(B, Blocks(32, 32)) + rmul!(DB, F.Q') + @test collect(DB) ≈ B * Q_dense' + end + + # lmul! Q * dense (AbstractMatrix target) + @testset "lmul! Q*dense" begin + B = rand(T, 128, 64) + B_copy = copy(B) + lmul!(F.Q, B_copy) + @test B_copy ≈ Q_dense * B + end + + # lmul! Q' * dense (AbstractMatrix target) + @testset "lmul! Q'*dense" begin + B = rand(T, 128, 64) + B_copy = copy(B) + lmul!(F.Q', B_copy) + @test B_copy ≈ Q_dense' * B + end + + # rmul! dense * Q (AbstractMatrix target) + @testset "rmul! dense*Q" begin + B = rand(T, 64, 128) + B_copy = copy(B) + rmul!(B_copy, F.Q) + @test B_copy ≈ B * Q_dense + end + + # rmul! dense * Q' (AbstractMatrix target) + @testset "rmul! dense*Q'" begin + B = rand(T, 64, 128) + B_copy = copy(B) + rmul!(B_copy, F.Q') + @test B_copy ≈ B * Q_dense' + end +end + +# ====================================================================== +# 6. Scalar multiplication: Q*s, s*Q, Q'*s, s*Q' +# ====================================================================== +@testset "Scalar * Q: $T" for T in (Float64, ComplexF64) + A = rand(T, 64, 64) + DA = distribute(A, Blocks(32,32)) + F = qr(DA) + Q_dense = collect(F.Q) + s = T(3.0) + + @test collect(F.Q * s) ≈ Q_dense * s + @test collect(s * F.Q) ≈ Q_dense * s + @test collect(F.Q' * s) ≈ Q_dense' * s + @test collect(s * F.Q') ≈ Q_dense' * s +end + +# ====================================================================== +# 7. DMatrix(Q) and DMatrix(Q') explicit construction +# ====================================================================== +@testset "DMatrix(Q) / DMatrix(Q'): $T" for T in (Float64, ComplexF64) + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32,32)) + F = qr(DA) + + DQ = DMatrix(F.Q) + DQa = DMatrix(F.Q') + + Q_dense = collect(DQ) + Qa_dense = collect(DQa) + eps_val = eps(real(T)) + + # Q should be orthogonal + CDI = Q_dense' * Q_dense + res = opnorm(I - CDI, 1) / (size(A, 2) * eps_val) + @test res < 2.0 + + # Q' should equal adjoint of Q + @test Qa_dense ≈ Q_dense' +end + +# ====================================================================== +# 8. collect(Q) and collect(Q') +# ====================================================================== +@testset "collect(Q) / collect(Q'): $T" for T in (Float64, ComplexF64) + A = rand(T, 64, 64) + DA = distribute(A, Blocks(32,32)) + F = qr(DA) + + Q_col = collect(F.Q) + Qa_col = collect(F.Q') + + @test Q_col' ≈ Qa_col + @test size(Q_col) == size(A) +end + +# ====================================================================== +# 9. geqrf! / porgqr! low-level API +# ====================================================================== +@testset "geqrf! + porgqr!: $T" for T in (Float64, ComplexF64) + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32,32)) + DA_copy = copy(DA) + + nb = DA_copy.partitioning.blocksize[2] + ib = 1 + lm, ln = Dagger.qr_measure_workspace(DA_copy, ib) + Tm = DArray{T}(undef, Blocks(ib, nb), (lm, ln)) + + geqrf!(DA_copy, Tm) + + # Build Q explicitly via porgqr! + DQ = DMatrix(I*one(T), (size(A, 1), size(A, 1)), DA_copy.partitioning) + porgqr!('N', DA_copy, Tm, DQ) + + Q_dense = collect(DQ) + R = triu(collect(DA_copy)) + eps_val = eps(real(T)) + + QR = Q_dense * R + res1 = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + CDI = Q_dense' * Q_dense + res2 = opnorm(I - CDI, 1) / (size(A, 1) * eps_val) + + @test res1 < 2.0 + @test res2 < 2.0 +end + +# ====================================================================== +# 11. CAQR pormqr! — left and right multiply with p > 1 +# ====================================================================== +@testset "CAQR pormqr! p=$p side=$side trans=$trans: $T" for T in (Float64, ComplexF64), + p in (2, 4), + side in ('L', 'R'), + trans in ('N', T <: Complex ? 'C' : 'T') + # Factor with CAQR + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + Q_dense = collect(DMatrix(F.Q)) + + if side == 'L' + C = rand(T, 128, 64) + DC = distribute(C, Blocks(32, 32)) + pormqr!(side, trans, F.Q.factors, F.Q.T, DC; p=p) + Qt = trans == 'N' ? Q_dense : Q_dense' + expected = Qt * C + else + C = rand(T, 64, 128) + DC = distribute(C, Blocks(32, 32)) + pormqr!(side, trans, F.Q.factors, F.Q.T, DC; p=p) + Qt = trans == 'N' ? Q_dense : Q_dense' + expected = C * Qt + end + + eps_val = eps(real(T)) + res = opnorm(collect(DC) - expected, 1) / (opnorm(expected, 1) * max(size(C)...) * eps_val) + @test res < 2.0 +end + +# ====================================================================== +# 12. CAQR lmul!/rmul! with p > 1 +# ====================================================================== +@testset "CAQR lmul!/rmul! p=$p: $T" for T in (Float64, ComplexF64), + p in (2, 4) + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + Q_dense = collect(DMatrix(F.Q)) + + @testset "lmul! Q*DMatrix" begin + B = rand(T, 128, 64) + DB = distribute(B, Blocks(32, 32)) + lmul!(F.Q, DB) + @test collect(DB) ≈ Q_dense * B + end + + @testset "lmul! Q'*DMatrix" begin + B = rand(T, 128, 64) + DB = distribute(B, Blocks(32, 32)) + lmul!(F.Q', DB) + @test collect(DB) ≈ Q_dense' * B + end + + @testset "rmul! DMatrix*Q" begin + B = rand(T, 64, 128) + DB = distribute(B, Blocks(32, 32)) + rmul!(DB, F.Q) + @test collect(DB) ≈ B * Q_dense + end + + @testset "rmul! DMatrix*Q'" begin + B = rand(T, 64, 128) + DB = distribute(B, Blocks(32, 32)) + rmul!(DB, F.Q') + @test collect(DB) ≈ B * Q_dense' + end +end + +# ====================================================================== +# 13. CAQR with Float32/ComplexF32 +# ====================================================================== +@testset "CAQR single-precision p=$p: $T" for T in (Float32, ComplexF32), + p in (2, 4) + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + DQ, DR = F.Q, triu(collect(DA_copy)) + + DQ_mat = collect(DMatrix(DQ)) + QR = DQ_mat * DR + eps_val = eps(real(T)) + res1 = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + CDI = DQ_mat' * DQ_mat + res2 = opnorm(I - CDI, 1) / (size(A, 2) * eps_val) + + @test res1 < 2.0 + @test res2 < 2.0 + @test triu(DR) ≈ DR +end + +# ====================================================================== +# 14. CAQR with square and wide matrices +# ====================================================================== +@testset "CAQR shapes p=$p: $T" for T in (Float64,), p in (2, 4) + @testset "Square 128×128" begin + A = rand(T, 128, 128) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + DQ_mat = collect(DMatrix(F.Q)) + R = triu(collect(DA_copy)) + eps_val = eps(T) + res1 = opnorm(A - DQ_mat * R, 1) / (opnorm(A, 1) * 128 * eps_val) + res2 = opnorm(I - DQ_mat' * DQ_mat, 1) / (128 * eps_val) + @test res1 < 2.0 + @test res2 < 2.0 + end +end + +# ====================================================================== +# 15. CAQR extreme p = mt (one tile per domain) +# ====================================================================== +@testset "CAQR p=mt: $T" for T in (Float64, ComplexF64) + # 128/32 = 4 tiles → p = 4 = mt for a 128×32 tall matrix + A = rand(T, 128, 32) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + p = size(DA.chunks, 1) # p = mt + F = qr!(DA_copy; p=p) + DQ_mat = collect(DMatrix(F.Q)) + R = triu(collect(DA_copy)) + eps_val = eps(real(T)) + res1 = opnorm(A - DQ_mat * R, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + res2 = opnorm(I - DQ_mat' * DQ_mat, 1) / (size(A, 2) * eps_val) + @test res1 < 2.0 + @test res2 < 2.0 +end + +# ====================================================================== +# 16. CAQR porgqr! direct call with p > 1 (both trans='N' and trans='T'/'C') +# ====================================================================== +@testset "CAQR porgqr! direct p=$p trans=$trans: $T" for T in (Float64, ComplexF64), + p in (2, 4), + trans in ('N', T <: Complex ? 'C' : 'T') + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + + DQ = DMatrix(I*one(T), (size(A, 1), size(A, 1)), F.Q.factors.partitioning) + porgqr!(trans, F.Q.factors, F.Q.T, DQ; p=p) + + Q_dense = collect(DQ) + eps_val = eps(real(T)) + + if trans == 'N' + R = triu(collect(DA_copy)) + QR = Q_dense * R + res1 = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + @test res1 < 2.0 + end + + # Orthogonality: Q'*Q ≈ I regardless of trans + # For trans != 'N', Q_dense is already Q', so Q_dense * Q_dense' ≈ I + if trans == 'N' + CDI = Q_dense' * Q_dense + else + CDI = Q_dense * Q_dense' + end + res2 = opnorm(I - CDI, 1) / (size(A, 2) * eps_val) + @test res2 < 2.0 +end + +# ====================================================================== +# 17. cageqrf! direct call +# ====================================================================== +@testset "cageqrf! direct p=$p: $T" for T in (Float64, ComplexF64), + p in (2, 4) + A = rand(T, 128, 64) + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + + nb = DA_copy.partitioning.blocksize[2] + ib = 1 + lm, ln = Dagger.qr_measure_workspace(DA_copy, ib) + Tm = DArray{T}(undef, Blocks(ib, nb), (lm, 2 * ln)) # 2× for tree T + + cageqrf!(DA_copy, Tm; p=p) + + # Build Q + DQ = DMatrix(I*one(T), (size(A, 1), size(A, 1)), DA_copy.partitioning) + porgqr!('N', DA_copy, Tm, DQ; p=p) + + Q_dense = collect(DQ) + R = triu(collect(DA_copy)) + eps_val = eps(real(T)) + + QR = Q_dense * R + res1 = opnorm(A - QR, 1) / (opnorm(A, 1) * max(size(A)...) * eps_val) + CDI = Q_dense' * Q_dense + res2 = opnorm(I - CDI, 1) / (size(A, 1) * eps_val) + + @test res1 < 2.0 + @test res2 < 2.0 +end + +# ====================================================================== +# 18. cageqrf! error on irregular tiling with p > 1 +# ====================================================================== +@testset "cageqrf! rejects irregular tiling" begin + A = rand(Float64, 100, 60) + DA = distribute(A, Blocks(30, 20)) # non-square tiles + DA_copy = copy(DA) + + nb = 20 + ib = 1 + lm = ib * cld(100, 30) + ln = nb * cld(60, 20) + Tm = DArray{Float64}(undef, Blocks(ib, nb), (lm, 2 * ln)) + + @test_throws ArgumentError cageqrf!(DA_copy, Tm; p=2) +end + +# ====================================================================== +# 19. Edge case — single tile +# ====================================================================== +@testset "Single tile QR: $T" for T in (Float64, ComplexF64) + A = rand(T, 32, 32) + DA = distribute(A, Blocks(32, 32)) + DQ, DR = qr(DA) + check_qr_residuals(A, DQ, DR) +end + +# ====================================================================== +# 20. Edge case — matrix smaller than block size +# ====================================================================== +@testset "Smaller than block QR: $T" for T in (Float64, ComplexF64) + A = rand(T, 16, 16) + DA = distribute(A, Blocks(32, 32)) + DQ, DR = qr(DA) + check_qr_residuals(A, DQ, DR) +end + +# ====================================================================== +# 21. QR least-squares solve: A x ≈ b +# ====================================================================== +@testset "QR least-squares: $T" for T in (Float64, ComplexF64) + m, n = 128, 64 + A = rand(T, m, n) + x_true = rand(T, n) + b = A * x_true + + DA = distribute(A, Blocks(32, 32)) + F = qr(DA) + + # Solve via Q'b then back-substitution + Q_dense = collect(F.Q) + R = collect(F.R) + Qtb = Q_dense' * b + x_solved = R[1:n, 1:n] \ Qtb[1:n] + + @test x_solved ≈ x_true rtol=1e-8 +end + +# ====================================================================== +# 22. CAQR least-squares solve with p > 1 +# ====================================================================== +@testset "CAQR least-squares p=$p: $T" for T in (Float64,), p in (2, 4) + m, n = 128, 64 + A = rand(T, m, n) + x_true = rand(T, n) + b = A * x_true + + DA = distribute(A, Blocks(32, 32)) + DA_copy = copy(DA) + F = qr!(DA_copy; p=p) + + Q_dense = collect(DMatrix(F.Q)) + R = triu(collect(DA_copy)) + Qtb = Q_dense' * b + x_solved = R[1:n, 1:n] \ Qtb[1:n] + + @test x_solved ≈ x_true rtol=1e-8 +end diff --git a/test/runtests.jl b/test/runtests.jl index 7929dfdce..81c7c4f73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ tests = [ ("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"), ("Array - LinearAlgebra - LU", "array/linalg/lu.jl"), ("Array - LinearAlgebra - Solve", "array/linalg/solve.jl"), + ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"), ("Array - Random", "array/random.jl"), ("Array - Stencils", "array/stencil.jl"), ("Array - FFT", "array/fft.jl"),