Skip to content

Commit 8e4db13

Browse files
committed
support blocks, better precompilation / less compilation
1 parent 1d32e9e commit 8e4db13

File tree

4 files changed

+116
-50
lines changed

4 files changed

+116
-50
lines changed

src/TensorOperations.jl

Lines changed: 101 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -109,45 +109,108 @@ cachesize() = cache.currentsize
109109
# Some precompile statements
110110
#----------------------------
111111
function _precompile_()
112+
AVector = Vector{Any}
112113
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
113-
precompile(Tuple{typeof(_findfirst), typeof(identity), Int})
114-
precompile(Tuple{typeof(_findnext), typeof(identity), Int})
115-
precompile(Tuple{typeof(_ncontree!), Int, Int})
116-
precompile(Tuple{typeof(conjexpr), Expr})
117-
precompile(Tuple{typeof(deindexify_contraction), Int, Int, Expr, Int, Vector, Vector, Int})
118-
precompile(Tuple{typeof(deindexify_generaltensor), Int, Int, Expr, Int, Vector, Vector, Int})
119-
precompile(Tuple{typeof(deindexify_linearcombination), Int, Int, Expr, Int, Vector, Vector, Int})
120-
precompile(Tuple{typeof(deindexify), Int, Int, Expr, Int, Vector, Vector, Int})
121-
precompile(Tuple{typeof(deindexify), Int, Int, Expr, Int, Vector, Vector})
122-
precompile(Tuple{typeof(expandconj), Expr})
123-
precompile(Tuple{typeof(expandconj), Int})
124-
precompile(Tuple{typeof(getallindices), Expr})
125-
precompile(Tuple{typeof(getallindices), Int})
126-
precompile(Tuple{typeof(geteltype), Expr})
127-
precompile(Tuple{typeof(geteltype), Int})
128-
precompile(Tuple{typeof(getindices), Expr})
129-
precompile(Tuple{typeof(hastraceindices), Int})
130-
precompile(Tuple{typeof(isgeneraltensor), Expr})
131-
precompile(Tuple{typeof(isindex), Int})
132-
precompile(Tuple{typeof(isnconstyle), Vector})
133-
precompile(Tuple{typeof(isscalarexpr), Expr})
134-
precompile(Tuple{typeof(istensor), Expr})
135-
precompile(Tuple{typeof(istensorexpr), Int})
136-
precompile(Tuple{typeof(makegeneraltensor), Int})
137-
precompile(Tuple{typeof(makeindex), Int})
138-
precompile(Tuple{typeof(makescalar), Expr})
139-
precompile(Tuple{typeof(maketensor), Int})
140-
precompile(Tuple{typeof(ncontree), Vector})
141-
precompile(Tuple{typeof(processcontractorder), Expr, Int})
142-
precompile(Tuple{typeof(processcontractorder), Int, Int})
143-
precompile(Tuple{typeof(tensorify), Expr, Int})
144-
precompile(Tuple{typeof(tensorify), Expr})
145-
precompile(Tuple{typeof(tensorify), Int, Int})
146-
precompile(Tuple{typeof(tree2expr), Int, Int})
147-
precompile(Tuple{typeof(unique2), Array{Any, 1}})
148-
precompile(Tuple{typeof(unique2), Array{Int64, 1}})
149-
precompile(Tuple{typeof(use_blas)})
150-
precompile(Tuple{typeof(use_cache)})
114+
@assert precompile(Tuple{typeof(_findfirst), Base.Fix2{typeof(Base.isequal), Symbol}, AVector})
115+
@assert precompile(Tuple{typeof(_findnext), Base.Fix2{typeof(Base.isequal), Int64}, AVector, Int64})
116+
@assert precompile(Tuple{typeof(_intersect), Base.BitArray{1}, Base.BitArray{1}})
117+
@assert precompile(Tuple{typeof(_intersect), Base.BitSet, Base.BitSet})
118+
@assert precompile(Tuple{typeof(_intersect), UInt128, UInt128})
119+
@assert precompile(Tuple{typeof(_intersect), UInt32, UInt32})
120+
@assert precompile(Tuple{typeof(_intersect), UInt64, UInt64})
121+
@assert precompile(Tuple{typeof(_isemptyset), Base.BitArray{1}})
122+
@assert precompile(Tuple{typeof(_isemptyset), Base.BitSet})
123+
@assert precompile(Tuple{typeof(_isemptyset), UInt128})
124+
@assert precompile(Tuple{typeof(_isemptyset), UInt32})
125+
@assert precompile(Tuple{typeof(_isemptyset), UInt64})
126+
@assert precompile(Tuple{typeof(_ncontree!), AVector, Vector{Vector{Int64}}})
127+
@assert precompile(Tuple{typeof(_setdiff), Base.BitArray{1}, Base.BitArray{1}})
128+
@assert precompile(Tuple{typeof(_setdiff), Base.BitSet, Base.BitSet})
129+
@assert precompile(Tuple{typeof(_setdiff), UInt128, UInt128})
130+
@assert precompile(Tuple{typeof(_setdiff), UInt32, UInt32})
131+
@assert precompile(Tuple{typeof(_setdiff), UInt64, UInt64})
132+
@assert precompile(Tuple{typeof(_union), Base.BitArray{1}, Base.BitArray{1}})
133+
@assert precompile(Tuple{typeof(_union), Base.BitSet, Base.BitSet})
134+
@assert precompile(Tuple{typeof(_union), UInt128, UInt128})
135+
@assert precompile(Tuple{typeof(_union), UInt32, UInt32})
136+
@assert precompile(Tuple{typeof(_union), UInt64, UInt64})
137+
@assert precompile(Tuple{typeof(addcost), Power{, Int64}, Power{, Int64}})
138+
@assert precompile(Tuple{typeof(degree), Power{:x, Int64}})
139+
@assert precompile(Tuple{typeof(deindexify_contraction), Int, Int, Expr, Int, AVector, AVector, Int})
140+
@assert precompile(Tuple{typeof(deindexify_generaltensor), Int, Int, Expr, Int, AVector, AVector, Int})
141+
@assert precompile(Tuple{typeof(deindexify_linearcombination), Int, Int, Expr, Int, AVector, AVector, Int})
142+
@assert precompile(Tuple{typeof(deindexify), Int, Int, Expr, Int, AVector, AVector, Int})
143+
@assert precompile(Tuple{typeof(deindexify), Int, Int, Expr, Int, AVector, AVector})
144+
@assert precompile(Tuple{typeof(deindexify), Expr, Bool, Expr, Int64, AVector, AVector})
145+
@assert precompile(Tuple{typeof(deindexify), Nothing, Bool, Expr, Bool, AVector, AVector, Bool})
146+
@assert precompile(Tuple{typeof(deindexify), Symbol, Bool, Expr, Bool, AVector, AVector})
147+
@assert precompile(Tuple{typeof(disable_blas)})
148+
@assert precompile(Tuple{typeof(disable_cache)})
149+
@assert precompile(Tuple{typeof(enable_blas)})
150+
@assert precompile(Tuple{typeof(enable_cache)})
151+
@assert precompile(Tuple{typeof(expandconj), Expr})
152+
@assert precompile(Tuple{typeof(expandconj), Symbol})
153+
@assert precompile(Tuple{typeof(getallindices), Expr})
154+
@assert precompile(Tuple{typeof(getallindices), Int})
155+
@assert precompile(Tuple{typeof(getallindices), Symbol})
156+
@assert precompile(Tuple{typeof(geteltype), Expr})
157+
@assert precompile(Tuple{typeof(getindices), Symbol})
158+
@assert precompile(Tuple{typeof(getindices), Expr})
159+
@assert precompile(Tuple{typeof(getlhsrhs), Expr})
160+
@assert precompile(Tuple{typeof(hastraceindices), Expr})
161+
@assert precompile(Tuple{typeof(isassignment), Expr})
162+
@assert precompile(Tuple{typeof(isdefinition), Expr})
163+
@assert precompile(Tuple{typeof(isgeneraltensor), Expr})
164+
@assert precompile(Tuple{typeof(isindex), Symbol})
165+
@assert precompile(Tuple{typeof(isindex), Int})
166+
@assert precompile(Tuple{typeof(isnconstyle), Array{AVector, 1}})
167+
@assert precompile(Tuple{typeof(isscalarexpr), Expr})
168+
@assert precompile(Tuple{typeof(isscalarexpr), Float64})
169+
@assert precompile(Tuple{typeof(isscalarexpr), LineNumberNode})
170+
@assert precompile(Tuple{typeof(isscalarexpr), Symbol})
171+
@assert precompile(Tuple{typeof(istensor), Expr})
172+
@assert precompile(Tuple{typeof(istensorexpr), Expr})
173+
@assert precompile(Tuple{typeof(makegeneraltensor), Expr})
174+
@assert precompile(Tuple{typeof(makeindex), Int})
175+
@assert precompile(Tuple{typeof(makeindex), Symbol})
176+
@assert precompile(Tuple{typeof(makescalar), Expr})
177+
@assert precompile(Tuple{typeof(makescalar), Float64})
178+
@assert precompile(Tuple{typeof(maketensor), Expr})
179+
@assert precompile(Tuple{typeof(mulcost), Power{, Int64}, Power{, Int64}})
180+
@assert precompile(Tuple{typeof(ncontree), Vector{AVector}})
181+
@assert precompile(Tuple{typeof(optdata), Expr})
182+
@assert precompile(Tuple{typeof(optimaltree), Vector{AVector}, Base.Dict{Any, Power{, Int64}}})
183+
@assert precompile(Tuple{typeof(parsecost), Expr})
184+
@assert precompile(Tuple{typeof(parsecost), Int64})
185+
@assert precompile(Tuple{typeof(parsecost), Symbol})
186+
@assert precompile(Tuple{typeof(processcontractorder), Expr, Nothing})
187+
@assert precompile(Tuple{typeof(processcontractorder), Expr, Int})
188+
@assert precompile(Tuple{typeof(processcontractorder), Int, Int})
189+
@assert precompile(Tuple{typeof(processcontractorder), Symbol, Nothing})
190+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitArray{1}}, AVector, Int64})
191+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitArray{1}}, Array{Int64, 1}, Int64})
192+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitArray{1}}, Base.Set{Int64}, Int64})
193+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitSet}, AVector, Int64})
194+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitSet}, Array{Int64, 1}, Int64})
195+
@assert precompile(Tuple{typeof(storeset), Type{Base.BitSet}, Base.Set{Int64}, Int64})
196+
@assert precompile(Tuple{typeof(storeset), Type{UInt128}, AVector, Int64})
197+
@assert precompile(Tuple{typeof(storeset), Type{UInt128}, Array{Int64, 1}, Int64})
198+
@assert precompile(Tuple{typeof(storeset), Type{UInt128}, Base.Set{Int64}, Int64})
199+
@assert precompile(Tuple{typeof(storeset), Type{UInt32}, AVector, Int64})
200+
@assert precompile(Tuple{typeof(storeset), Type{UInt32}, Array{Int64, 1}, Int64})
201+
@assert precompile(Tuple{typeof(storeset), Type{UInt32}, Base.Set{Int64}, Int64})
202+
@assert precompile(Tuple{typeof(storeset), Type{UInt64}, AVector, Int64})
203+
@assert precompile(Tuple{typeof(storeset), Type{UInt64}, Array{Int64, 1}, Int64})
204+
@assert precompile(Tuple{typeof(storeset), Type{UInt64}, Base.Set{Int64}, Int64})
205+
@assert precompile(Tuple{typeof(tensorify), Expr, Nothing})
206+
@assert precompile(Tuple{typeof(tensorify), Expr, Int})
207+
@assert precompile(Tuple{typeof(tensorify), Expr})
208+
@assert precompile(Tuple{typeof(tensorify), Int, Int})
209+
@assert precompile(Tuple{typeof(tree2expr), Int, Int})
210+
@assert precompile(Tuple{typeof(unique2), AVector})
211+
@assert precompile(Tuple{typeof(unique2), Array{Int64, 1}})
212+
@assert precompile(Tuple{typeof(use_blas)})
213+
@assert precompile(Tuple{typeof(use_cache)})
151214
end
152215
_precompile_()
153216

src/indexnotation/tensorexpressions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,21 +148,21 @@ end
148148
function maketensor(ex)
149149
if isa(ex, Expr) && (ex.head == :ref || ex.head == :typed_hcat)
150150
object = esc(ex.args[1])
151-
leftind = map(makeindex, ex.args[2:end])
151+
leftind = Any[makeindex(x) for x in ex.args[2:end]]
152152
rightind = Any[]
153153
return (object, leftind, rightind)
154154
elseif isa(ex, Expr) && ex.head == :typed_vcat
155155
length(ex.args) <= 3 || throw(ArgumentError("invalid tensor index expression: $ex"))
156156
object = esc(ex.args[1])
157157
if isa(ex.args[2], Expr) && ex.args[2].head == :row
158-
leftind = map(makeindex, ex.args[2].args)
158+
leftind = Any[makeindex(x) for x in ex.args[2].args]
159159
elseif ex.args[2] == :_
160160
leftind = Any[]
161161
else
162162
leftind = Any[makeindex(ex.args[2])]
163163
end
164164
if length(ex.args) > 2 && isa(ex.args[3], Expr) && ex.args[3].head == :row
165-
rightind = map(makeindex, ex.args[3].args)
165+
rightind = Any[makeindex(x) for x in ex.args[3].args]
166166
elseif length(ex.args) == 2 || ex.args[3] == :_
167167
rightind = Any[]
168168
else

src/indexnotation/tensormacro.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ function tensorify(ex::Expr, optdata = nothing)
203203
if ex.head == :block
204204
return Expr(ex.head, map(x->tensorify(x, optdata), ex.args)...)
205205
end
206+
if ex.head == :for
207+
return Expr(ex.head, esc(ex.args[1]), tensorify(ex.args[2], optdata))
208+
end
206209
# constructions of the form: a = @tensor ...
207210
if isscalarexpr(ex)
208211
return makescalar(ex)
@@ -288,7 +291,7 @@ function tree2expr(args, tree)
288291
end
289292

290293
# deindexify: parse tensor operations
291-
function deindexify(dst, β, ex::Expr, α, leftind::Vector, rightind::Vector, istemporary = false)
294+
function deindexify(dst, β, ex::Expr, α, leftind::Vector{Any}, rightind::Vector{Any}, istemporary = false)
292295
if isgeneraltensor(ex)
293296
return deindexify_generaltensor(dst, β, ex, α, leftind, rightind, istemporary)
294297
elseif ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :-) # linear combination
@@ -309,7 +312,7 @@ function deindexify(dst, β, ex::Expr, α, leftind::Vector, rightind::Vector, is
309312
throw(ArgumentError("problem with parsing $ex"))
310313
end
311314

312-
function deindexify_generaltensor(dst, β, ex::Expr, α, leftind::Vector, rightind::Vector, istemporary = false)
315+
function deindexify_generaltensor(dst, β, ex::Expr, α, leftind::Vector{Any}, rightind::Vector{Any}, istemporary = false)
313316
src, srcleftind, srcrightind, α2, conj = makegeneraltensor(ex)
314317
srcind = vcat(srcleftind, srcrightind)
315318
conjarg = conj ? :(:C) : :(:N)
@@ -362,7 +365,7 @@ function deindexify_generaltensor(dst, β, ex::Expr, α, leftind::Vector, righti
362365
end
363366
end
364367
end
365-
function deindexify_linearcombination(dst, β, ex::Expr, α, leftind::Vector, rightind::Vector, istemporary = false)
368+
function deindexify_linearcombination(dst, β, ex::Expr, α, leftind::Vector{Any}, rightind::Vector{Any}, istemporary = false)
366369
if ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :-) # addition: add one by one
367370
if dst === nothing
368371
αnew = Expr(:call, :*, α, Expr(:call, :one, geteltype(ex)))
@@ -388,7 +391,7 @@ function deindexify_linearcombination(dst, β, ex::Expr, α, leftind::Vector, ri
388391
throw(ArgumentError("unable to deindexify linear combination: $ex"))
389392
end
390393
end
391-
function deindexify_contraction(dst, β, ex::Expr, α, leftind::Vector, rightind::Vector, istemporary = false)
394+
function deindexify_contraction(dst, β, ex::Expr, α, leftind::Vector{Any}, rightind::Vector{Any}, istemporary = false)
392395
@assert ex.head == :call && ex.args[1] == :* && length(ex.args) == 3 &&
393396
istensorexpr(ex.args[2]) && istensorexpr(ex.args[3])
394397
exA = ex.args[2]

test/auxiliary.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@
4646
@test !isindex(:(5.1))
4747

4848
@test istensor(:(a[1,2,3]))
49-
@test maketensor(:(a[1,2,3])) == (esc(:a), [1,2,3], [])
49+
@test maketensor(:(a[1,2,3])) == (esc(:a), Any[1,2,3], Any[])
5050
@test istensor(:(a[5][a b c]))
51-
@test maketensor(:(a[5][a b c])) == (esc(:(a[5])), [:a,:b,:c], [])
51+
@test maketensor(:(a[5][a b c])) == (esc(:(a[5])), Any[:a,:b,:c], Any[])
5252
@test istensor(:(cos(y)[a b c; 1 2 3]))
53-
@test maketensor(:(cos(y)[a b c; 1 2 3])) == (esc(:(cos(y))), [:a,:b,:c], [1,2,3])
53+
@test maketensor(:(cos(y)[a b c; 1 2 3])) == (esc(:(cos(y))), Any[:a,:b,:c], Any[1,2,3])
5454
@test istensor(:(x[_; 1 2 3]))
55-
@test maketensor(:(x[_; 1 2 3])) == (esc(:x), [], [1,2,3])
55+
@test maketensor(:(x[_; 1 2 3])) == (esc(:x), Any[], Any[1,2,3])
5656
@test !istensor(:(2*a[1,2,3]))
5757
@test !istensor(:(a[1 2 3; 4 5 6; 7 8 9]))
5858
@test !istensor(:(conj(a[5][a b c])))
@@ -64,7 +64,7 @@
6464
@test isgeneraltensor(:(x*a[5][a b c]))
6565
@test makegeneraltensor(:(x*a[5][a b c])) == (esc(:(a[5])), [:a,:b,:c], [], makescalar(:(x*1)), false)
6666
@test isgeneraltensor(:(3*conj(a*cos(y)[a b c; 1 2 3])))
67-
@test makegeneraltensor(:(3*conj(a*cos(y)[a b c; 1 2 3]))) == (esc(:(cos(y))), [:a,:b,:c], [1,2,3], makescalar(:(3*conj(a*1))), true)
67+
@test makegeneraltensor(:(3*conj(a*cos(y)[a b c; 1 2 3]))) == (esc(:(cos(y))), Any[:a,:b,:c], Any[1,2,3], makescalar(:(3*conj(a*1))), true)
6868
@test !isgeneraltensor(:(1/a[1,2,3]))
6969
@test !isgeneraltensor(:(a[1 2 3; 4 5 6]\x))
7070
@test !isgeneraltensor(:(cos(y)[a b c; 1 2 3]*b[4,5]))

0 commit comments

Comments
 (0)