@@ -10,10 +10,15 @@ function tensorcontract!(C::AbstractArray,
1010 argcheck_tensorcontract (C, A, pA, B, pB, pAB)
1111 dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
1212
13- _diagtensorcontract! (StridedView (C),
14- StridedView (A), pA, conjA,
15- StridedView (B. diag), pB, conjB,
16- pAB, α, β)
13+ if conjA && conjB
14+ _diagtensorcontract! (SV (C), conj (SV (A)), pA, conj (SV (B. diag)), pB, pAB, α, β)
15+ elseif conjA
16+ _diagtensorcontract! (SV (C), conj (SV (A)), pA, SV (B. diag), pB, pAB, α, β)
17+ elseif conjB
18+ _diagtensorcontract! (SV (C), SV (A), pA, conj (SV (B. diag)), pB, pAB, α, β)
19+ else
20+ _diagtensorcontract! (SV (C), SV (A), pA, SV (B. diag), pB, pAB, α, β)
21+ end
1722 return C
1823end
1924
@@ -35,10 +40,15 @@ function tensorcontract!(C::AbstractArray,
3540 rpAB = (TupleTools. getindices (indCinoBA, tpAB[1 ]),
3641 TupleTools. getindices (indCinoBA, tpAB[2 ]))
3742
38- _diagtensorcontract! (StridedView (C),
39- StridedView (B), rpB, conjB,
40- StridedView (A. diag), rpA, conjA,
41- rpAB, α, β)
43+ if conjA && conjB
44+ _diagtensorcontract! (SV (C), conj (SV (B)), rpB, conj (SV (A. diag)), rpA, rpAB, α, β)
45+ elseif conjA
46+ _diagtensorcontract! (SV (C), SV (B), rpB, conj (SV (A. diag)), rpA, rpAB, α, β)
47+ elseif conjB
48+ _diagtensorcontract! (SV (C), conj (SV (B)), rpB, SV (A. diag), rpA, rpAB, α, β)
49+ else
50+ _diagtensorcontract! (SV (C), SV (B), rpB, SV (A. diag), rpA, rpAB, α, β)
51+ end
4252 return C
4353end
4454
@@ -50,40 +60,16 @@ function tensorcontract!(C::AbstractArray,
5060 :: StridedNative , allocator= DefaultAllocator ())
5161 argcheck_tensorcontract (C, A, pA, B, pB, pAB)
5262 dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
53- if numin (pA) == 1 # matrix multiplication
54- scale! (C, β)
55- β = one (β)
56-
57- A2 = sreshape (flag2op (conjA)(StridedView (A. diag)), (length (A. diag), 1 ))
58- B2 = sreshape (flag2op (conjB)(StridedView (B. diag)), (length (B. diag), 1 ))
59- # take a view of the diagonal elements of C, having strides 1 + length(diag)
60- totsize = (length (A. diag),)
61- C2 = StridedView (C, totsize, (sum (strides (C)),))
62-
63- elseif numin (pA) == 2 # trace
64- A2 = flag2op (conjA)(StridedView (A. diag, (length (A. diag),)))
65- B2 = flag2op (conjB)(StridedView (B. diag, (length (B. diag),)))
66- totsize = (length (A. diag),)
67- C2 = sreshape (StridedView (C), (1 ,))
68-
69- else # outer product
70- scale! (C, β)
71- β = one (β)
7263
73- A2 = sreshape ( StridedView (A . diag), ( length (A . diag), 1 ))
74- B2 = sreshape ( StridedView (B . diag), ( 1 , length (A . diag)))
75-
76- C3 = permutedims ( StridedView (C), invperm ( linearize (pAB)) )
77- strC = strides (C3)
78- newstrides = (strC[ 1 ] + strC[ 2 ], strC[ 3 ] + strC[ 4 ] )
79- totsize = ( length (A2), length (B2))
80- C2 = StridedView (C3 . parent, totsize, newstrides, C3 . offset, C3 . op )
64+ if conjA && conjB
65+ _diagdiagcontract! ( SV (C), conj ( SV (A . diag)), pA, conj ( SV (B . diag)), pB, pAB, α, β )
66+ elseif conjA
67+ _diagdiagcontract! ( SV (C), conj ( SV (A . diag)), pA, SV (B . diag), pB, pAB, α, β )
68+ elseif conjB
69+ _diagdiagcontract! ( SV (C), SV (A . diag), pA, conj ( SV (B . diag)), pB, pAB, α, β )
70+ else
71+ _diagdiagcontract! ( SV (C), SV (A . diag), pA, SV (B . diag), pB, pAB, α, β )
8172 end
82-
83- op1 = Base. Fix2 (scale, α) ∘ *
84- op2 = Base. Fix2 (scale, β)
85- Strided. _mapreducedim! (op1, + , op2, totsize, (C2, A2, B2))
86-
8773 return C
8874end
8975
@@ -96,41 +82,49 @@ function tensorcontract!(C::Diagonal,
9682 argcheck_tensorcontract (C, A, pA, B, pB, pAB)
9783 dimcheck_tensorcontract (C, A, pA, B, pB, pAB)
9884
99- A2 = flag2op (conjA)( StridedView (A. diag) )
100- B2 = flag2op (conjB)( StridedView (B. diag) )
85+ A2 = StridedView (A. diag)
86+ B2 = StridedView (B. diag)
10187 C2 = StridedView (C. diag)
10288
103- C2 .= C2 .* β .+ A2 .* B2 .* α
89+ if conjA && conjB
90+ C2 .= C2 .* β .+ conj .(A2 .* B2) .* α
91+ elseif conjA
92+ C2 .= C2 .* β .+ conj .(A2) .* B2 .* α
93+ elseif conjB
94+ C2 .= C2 .* β .+ A2 .* conj .(B2) .* α
95+ else
96+ C2 .= C2 .* β .+ A2 .* B2 .* α
97+ end
10498 return C
10599end
106100
107101function _diagtensorcontract! (C:: StridedView ,
108- A:: StridedView , pA:: Index2Tuple , conjA :: Bool ,
109- Bdiag:: StridedView , pB:: Index2Tuple , conjB :: Bool ,
102+ A:: StridedView , pA:: Index2Tuple ,
103+ Bdiag:: StridedView , pB:: Index2Tuple ,
110104 pAB:: Index2Tuple , α:: Number , β:: Number )
111105 sizeA = i -> size (A, i)
112106 csizeA = sizeA .(pA[2 ])
113107 osizeA = sizeA .(pA[1 ])
114108
115109 if numin (pB) == 1 # => numin(A) == numout(B) == 1
116110 totsize = (osizeA... , csizeA... )
117- A2 = flag2op (conjA)( permutedims (A, linearize (pA) ))
118- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , csizeA... ) ))
111+ A2 = permutedims (A, linearize (pA))
112+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , csizeA... ))
119113 C2 = permutedims (C, invperm (linearize (pAB)))
120114
121115 elseif numin (pB) == 0
122116 strideA = i -> stride (A, i)
123117 newstrides = (strideA .(pA[1 ])... , strideA (pA[2 ][1 ]) + strideA (pA[2 ][2 ]))
124118 totsize = (osizeA... , csizeA[1 ])
125- A2 = flag2op (conjA)( StridedView (A. parent, totsize, newstrides, A. offset, A. op) )
126- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , csizeA[1 ]) ))
119+ A2 = StridedView (A. parent, totsize, newstrides, A. offset, A. op)
120+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , csizeA[1 ]))
127121 C2 = permutedims (C, invperm (linearize (pAB)))
128122
129123 else # numout(pB) == 2 # direct product
130124 scale! (C, β)
131125 β = one (β)
132- A2 = flag2op (conjA)( sreshape (permutedims (A, linearize (pA)), (osizeA... , 1 ) ))
133- B2 = flag2op (conjB)( sreshape (Bdiag, ((one .(osizeA)). .. , length (Bdiag) )))
126+ A2 = sreshape (permutedims (A, linearize (pA)), (osizeA... , 1 ))
127+ B2 = sreshape (Bdiag, ((one .(osizeA)). .. , length (Bdiag)))
134128
135129 C3 = permutedims (C, invperm (linearize (pAB)))
136130 sC = strides (C3)
@@ -145,3 +139,44 @@ function _diagtensorcontract!(C::StridedView,
145139
146140 return C
147141end
142+
143+ function _diagdiagcontract! (C:: StridedView ,
144+ Adiag:: StridedView , pA:: Index2Tuple ,
145+ Bdiag:: StridedView , pB:: Index2Tuple ,
146+ pAB:: Index2Tuple , α:: Number , β:: Number )
147+ if numin (pA) == 1 # matrix multiplication
148+ scale! (C, β)
149+ β = one (β)
150+
151+ A2 = sreshape (Adiag, (length (Adiag), 1 ))
152+ B2 = sreshape (Bdiag, (length (Bdiag), 1 ))
153+ # take a view of the diagonal elements of C, having strides 1 + length(diag)
154+ totsize = (length (Adiag),)
155+ C2 = StridedView (C. parent, totsize, (sum (strides (C)),))
156+
157+ elseif numin (pA) == 2 # trace
158+ A2 = Adiag
159+ B2 = Bdiag
160+ totsize = (length (Adiag),)
161+ C2 = sreshape (C, (1 ,))
162+
163+ else # outer product
164+ scale! (C, β)
165+ β = one (β)
166+
167+ A2 = sreshape (Adiag, (length (Adiag), 1 ))
168+ B2 = sreshape (Bdiag, (1 , length (Adiag)))
169+
170+ C3 = permutedims (C, invperm (linearize (pAB)))
171+ strC = strides (C3)
172+ newstrides = (strC[1 ] + strC[2 ], strC[3 ] + strC[4 ])
173+ totsize = (length (A2), length (B2))
174+ C2 = StridedView (C3. parent, totsize, newstrides, C3. offset, C3. op)
175+ end
176+
177+ op1 = Base. Fix2 (scale, α) ∘ *
178+ op2 = Base. Fix2 (scale, β)
179+ Strided. _mapreducedim! (op1, + , op2, totsize, (C2, A2, B2))
180+
181+ return C
182+ end
0 commit comments