@@ -113,7 +113,17 @@ function tensorify(ex::Expr)
113113 ExistingTensor)
114114 end
115115 else
116- return instantiate (dst, false , rhs, true , leftind, rightind, NewTensor)
116+ # deal with the case that dst can be an existing variable while also a new value is assigned to it
117+ if dst ∈ getinputtensorobjects (rhs)
118+ dst2 = gensym (dst)
119+ return quote
120+ $ dst2 = $ dst
121+ $ (instantiate (dst2, false , rhs, true , leftind, rightind, NewTensor))
122+ $ dst = $ dst2
123+ end
124+ else
125+ return instantiate (dst, false , rhs, true , leftind, rightind, NewTensor)
126+ end
117127 end
118128 elseif isassignment (ex) && isscalarexpr (lhs)
119129 if istensorexpr (rhs) && isempty (getindices (rhs))
@@ -136,12 +146,6 @@ function tensorify(ex::Expr)
136146 if ex. head == :block # @tensor begin ... end
137147 return Expr (ex. head, map (tensorify, ex. args)... )
138148 end
139- if ex. head == :for # @tensor for ... end
140- return Expr (ex. head, ex. args[1 ], tensorify (ex. args[2 ]))
141- end
142- if ex. head == :function # @tensor function ... end
143- return Expr (ex. head, ex. args[1 ], tensorify (ex. args[2 ]))
144- end
145149 # constructions of the form: a = @tensor ...
146150 if isscalarexpr (ex)
147151 return instantiate_scalar (ex)
0 commit comments