Skip to content

Commit 541f2f4

Browse files
authored
[Fix] makeblascontractable should return blascontractable objects. (#210)
1 parent 390aed8 commit 541f2f4

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

.github/workflows/FormatCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
#
3333
# julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
3434
run: |
35-
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
35+
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="1"))'
3636
julia -e 'using JuliaFormatter; format(".", verbose=true)'
3737
- name: Format check
3838
run: |

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "5.2.0"
4+
version = "5.2.1"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/implementation/blascontract.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function makeblascontractable(A, pA, TC, backend, allocator)
8585
flagA = isblascontractable(A, pA) && eltype(A) == TC
8686
if !flagA
8787
A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator)
88-
Anew = SV(A_, size(A_), strides(A_), 0, A.op)
88+
Anew = SV(A_)
8989
Anew = tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
9090
pAnew = trivialpermutation(pA)
9191
else

test/tensor.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,4 +528,45 @@ end
528528
C3 = ncon([A], [[-1, -3, -4, -2]], [true])
529529
@test C3 conj(C)
530530
end
531+
532+
@testset "methods for StridedBLAS" begin
533+
using TensorOperations: isblascontractable, isblasdestination, makeblascontractable,
534+
StridedBLAS, DefaultAllocator
535+
using Strided
536+
backend = StridedBLAS()
537+
allocator = DefaultAllocator()
538+
539+
A = StridedView(rand(ComplexF64, 4, 4, 4, 4))
540+
541+
p = ((1, 2), (3, 4))
542+
@test isblascontractable(A, p)
543+
Anew, pnew, flag = makeblascontractable(A, p, ComplexF64, backend, allocator)
544+
@test isblascontractable(Anew, pnew)
545+
@test Anew === A
546+
547+
@test !isblascontractable(conj(A), p)
548+
Anew, pnew, flag = makeblascontractable(A, p, ComplexF64, backend, allocator)
549+
@test isblascontractable(Anew, pnew)
550+
551+
for p in (((2, 1), (3, 4)), ((1,), (3, 2, 4)), ((2, 1, 4), (3,))),
552+
op in (identity, conj)
553+
554+
@test !isblascontractable(op(A), p)
555+
Anew, pnew, flag = makeblascontractable(op(A), p, ComplexF64, backend,
556+
allocator)
557+
@test isblascontractable(Anew, pnew)
558+
end
559+
560+
vA = view(A, 1:2, 1:2, 1:2, 1:2)
561+
p = ((1, 2), (3, 4))
562+
@test !isblascontractable(vA, p)
563+
Anew, pnew, flag = makeblascontractable(vA, p, ComplexF64, backend, allocator)
564+
@test !flag
565+
@test isblascontractable(Anew, pnew)
566+
567+
pA = permutedims(A, (3, 4, 1, 2))
568+
p = ((1, 2), (3, 4))
569+
@test isblascontractable(pA, p)
570+
@test isblascontractable(conj(pA), p)
571+
end
531572
end

0 commit comments

Comments
 (0)