Skip to content

Commit 5bd2288

Browse files
committed
important bug fixes
1 parent 0a92fe2 commit 5bd2288

File tree

5 files changed

+27
-13
lines changed

5 files changed

+27
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
4-
version = "4.0.1"
4+
version = "4.0.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/indexnotation/parser.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,17 @@ function tensorify(ex::Expr)
113113
ExistingTensor)
114114
end
115115
else
116-
return instantiate(dst, false, rhs, true, leftind, rightind, NewTensor)
116+
# deal with the case that dst can be an existing variable while also a new value is assigned to it
117+
if dst getinputtensorobjects(rhs)
118+
dst2 = gensym(dst)
119+
return quote
120+
$dst2 = $dst
121+
$(instantiate(dst2, false, rhs, true, leftind, rightind, NewTensor))
122+
$dst = $dst2
123+
end
124+
else
125+
return instantiate(dst, false, rhs, true, leftind, rightind, NewTensor)
126+
end
117127
end
118128
elseif isassignment(ex) && isscalarexpr(lhs)
119129
if istensorexpr(rhs) && isempty(getindices(rhs))
@@ -136,12 +146,6 @@ function tensorify(ex::Expr)
136146
if ex.head == :block # @tensor begin ... end
137147
return Expr(ex.head, map(tensorify, ex.args)...)
138148
end
139-
if ex.head == :for # @tensor for ... end
140-
return Expr(ex.head, ex.args[1], tensorify(ex.args[2]))
141-
end
142-
if ex.head == :function # @tensor function ... end
143-
return Expr(ex.head, ex.args[1], tensorify(ex.args[2]))
144-
end
145149
# constructions of the form: a = @tensor ...
146150
if isscalarexpr(ex)
147151
return instantiate_scalar(ex)

src/indexnotation/preprocessors.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,9 @@ end
119119
Group all scalar factors of a tensor expression into a single scalar factor at the start of the expression.
120120
"""
121121
function groupscalarfactors(ex)
122-
if isa(ex, Expr) # prewalk
123-
ex = Expr(ex.head, map(groupscalarfactors, ex.args)...)
124-
end
125-
if istensorexpr(ex) && ex.args[1] == :*
122+
if istensor(ex) || (isexpr(ex, :macrocall) && ex.args[1] == Symbol("@notensor"))
123+
return ex
124+
elseif istensorexpr(ex) && ex.args[1] == :*
126125
args = ex.args[2:end]
127126
scalarpos = findall(isscalarexpr, args)
128127
length(scalarpos) == 0 && return ex
@@ -133,6 +132,8 @@ function groupscalarfactors(ex)
133132
scalar = Expr(:call, :*, args[scalarpos]...)
134133
end
135134
return Expr(:call, :*, scalar, args[tensorpos]...)
135+
elseif isa(ex, Expr)
136+
return Expr(ex.head, map(groupscalarfactors, ex.args)...)
136137
end
137138
return ex
138139
end

src/indexnotation/verifiers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function isscalarexpr(ex)
125125
elseif isexpr(ex, :call)
126126
return all(isscalarexpr, ex.args[2:end])
127127
else
128-
return false
128+
return true # assume everything else is valid scalar code
129129
end
130130
end
131131

test/tensor.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ using LinearAlgebra
418418
@test res -(mat1 * vec + mat2 * vec)
419419
end
420420

421+
@testset "Handling definitions" begin
422+
vec = rand(2)
423+
veccopy = copy(vec)
424+
mat1 = rand(2, 2)
425+
mat2 = rand(2, 2)
426+
@tensor vec[a] := (mat1[a, b] * mat2[b, c]) * vec[c]
427+
@test vec mat1 * mat2 * veccopy
428+
end
429+
421430
# @testset "Issue 136" begin
422431
# A = rand(2, 2)
423432
# B = rand(2, 2)

0 commit comments

Comments
 (0)