Skip to content

Commit 0b3955a

Browse files
authored
[oneMKL] Fix gesvd! (#485)
1 parent af10c9f commit 0b3955a

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

lib/mkl/wrappers_lapack.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -304,30 +304,31 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv
304304
jobvt::Char,
305305
A::oneStridedMatrix{$elty})
306306
m, n = size(A)
307+
k = min(m, n)
307308
lda = max(1, stride(A, 2))
308309

309310
U = if jobu === 'A'
310311
oneMatrix{$elty}(undef, m, m)
311-
elseif jobu == 'S' || jobu === 'O'
312-
oneMatrix{$elty}(undef, m, min(m, n))
313-
elseif jobu === 'N'
314-
oneMatrix{$elty}(undef, 0, 0) # Equivalence of CU_NULL?
312+
elseif jobu === 'S'
313+
oneMatrix{$elty}(undef, m, k)
314+
elseif jobu === 'N' || jobu === 'O'
315+
ZE_NULL
315316
else
316317
error("jobu must be one of 'A', 'S', 'O', or 'N'")
317318
end
318-
ldu = U == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(U, 2))
319-
S = oneVector{$relty}(undef, min(m, n))
319+
ldu = U == ZE_NULL ? 1 : max(1, stride(U, 2))
320+
S = oneVector{$relty}(undef, k)
320321

321322
Vt = if jobvt === 'A'
322323
oneMatrix{$elty}(undef, n, n)
323-
elseif jobvt === 'S' || jobvt === 'O'
324-
oneMatrix{$elty}(undef, min(m, n), n)
325-
elseif jobvt === 'N'
326-
oneMatrix{$elty}(undef, 0, 0)
324+
elseif jobvt === 'S'
325+
oneMatrix{$elty}(undef, k, n)
326+
elseif jobvt === 'N' || jobvt === 'O'
327+
ZE_NULL
327328
else
328329
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
329330
end
330-
ldvt = Vt == oneArray{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2))
331+
ldvt = Vt == ZE_NULL ? 1 : max(1, stride(Vt, 2))
331332

332333
queue = global_queue(context(A), device())
333334
scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt)

test/onemkl.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,16 @@ end
14211421
d_A = oneMatrix(A)
14221422
U, Σ, Vt = oneMKL.gesvd!('A', 'A', d_A)
14231423
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)
1424+
1425+
for jobu in ('A', 'S', 'N', 'O')
1426+
for jobvt in ('A', 'S', 'N', 'O')
1427+
(jobu == 'A') && (jobvt == 'A') && continue
1428+
(jobu == 'O') && (jobvt == 'O') && continue
1429+
d_A = oneMatrix(A)
1430+
U2, Σ2, Vt2 = oneMKL.gesvd!(jobu, jobvt, d_A)
1431+
@test Σ Σ2
1432+
end
1433+
end
14241434
end
14251435

14261436
@testset "syevd! -- heevd!" begin

0 commit comments

Comments
 (0)