Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
@enum ActivityState begin
AnyState = 0
ActiveState = 1
DupState = 2
MixedState = 3
end

@inline function Base.:|(a1::ActivityState, a2::ActivityState)
ActivityState(Int(a1) | Int(a2))
end

@inline element(::Val{T}) where {T} = T

@inline ptreltype(::Type{Ptr{T}}) where {T} = T
Expand Down Expand Up @@ -393,6 +382,14 @@ Base.@nospecializeinfer @inline function active_reg_inner(
return ty
end

function active_reg_cached(ctx::EnzymeContext, @nospecialize(ST::Type); justActive=false, UnionSret = false, AbstractIsMixed = false)
key = (ST, justActive, UnionSret, AbstractIsMixed)
get!(ctx.activity_cache, key) do
set = Base.IdSet{Type}()
active_reg_inner(ST, set, ctx.world, justActive, UnionSret, AbstractIsMixed)
end
end

Base.@nospecializeinfer @inline function active_reg(@nospecialize(ST::Type), world::UInt; justActive=false, UnionSret = false, AbstractIsMixed = false)
set = Base.IdSet{Type}()
return active_reg_inner(ST, set, world, justActive, UnionSret, AbstractIsMixed)
Expand Down
31 changes: 25 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ import LLVM: Target, TargetMachine
import SparseArrays
using Printf

@enum ActivityState begin
AnyState = 0
ActiveState = 1
DupState = 2
MixedState = 3
end

@inline function Base.:|(a1::ActivityState, a2::ActivityState)
ActivityState(Int(a1) | Int(a2))
end

mutable struct EnzymeContext
world::UInt
activity_cache::Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}
function EnzymeContext(world)
new(world, Dict{Tuple{Type,Bool,Bool,Bool},ActivityState}())
end
end

using Preferences

bitcode_replacement() = parse(Bool, @load_preference("bitcode_replacement", "true"))
Expand Down Expand Up @@ -3282,7 +3301,7 @@ function create_abi_wrapper(
# 3 is index of shadow
if existed[3] != 0 &&
sret_union &&
active_reg(pactualRetType, world; justActive=true, UnionSret=true) == ActiveState
active_reg_cached(interp.context, pactualRetType; justActive=true, UnionSret=true) == ActiveState
rewrite_union_returns_as_ref(enzymefn, data[3], world, width)
end
returnNum = 0
Expand Down Expand Up @@ -4951,7 +4970,7 @@ end
if params.err_if_func_written
FT = TT.parameters[1]
Ty = eltype(FT)
reg = active_reg(Ty, job.world)
reg = active_reg_cached(interp.context, Ty)
if reg == DupState || reg == MixedState
swiftself = has_swiftself(primalf)
todo = LLVM.Value[parameters(primalf)[1+swiftself]]
Expand All @@ -4975,7 +4994,7 @@ end
if !mayWriteToMemory(user)
slegal, foundv, byref = abs_typeof(user)
if slegal
reg2 = active_reg(foundv, job.world)
reg2 = active_reg_cached(interp.context, foundv)
if reg2 == ActiveState || reg2 == AnyState
continue
end
Expand Down Expand Up @@ -5003,7 +5022,7 @@ end
if operands(user)[2] == cur
slegal, foundv, byref = abs_typeof(operands(user)[1])
if slegal
reg2 = active_reg(foundv, job.world)
reg2 = active_reg_cached(interp.context, foundv)
if reg2 == AnyState
continue
end
Expand Down Expand Up @@ -5037,7 +5056,7 @@ end
if is_readonly(called)
slegal, foundv, byref = abs_typeof(user)
if slegal
reg2 = active_reg(foundv, job.world)
reg2 = active_reg_cached(interp.context, foundv)
if reg2 == ActiveState || reg2 == AnyState
continue
end
Expand All @@ -5055,7 +5074,7 @@ end
end
slegal, foundv, byref = abs_typeof(user)
if slegal
reg2 = active_reg(foundv, job.world)
reg2 = active_reg_cached(interp.context, foundv)
if reg2 == ActiveState || reg2 == AnyState
continue
end
Expand Down
12 changes: 9 additions & 3 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
within_autodiff_rewrite::Bool

handler::T

context::Enzyme.Compiler.EnzymeContext
end

const SigCache = Dict{Tuple, Dict{UInt, Base.IdSet{Type}}}()
Expand Down Expand Up @@ -247,7 +249,8 @@ function EnzymeInterpreter(
inactive_rules::Bool,
broadcast_rewrite::Bool,
within_autodiff_rewrite::Bool,
handler
handler,
Enzyme.Compiler.EnzymeContext(world)
)
end

Expand Down Expand Up @@ -278,7 +281,9 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
within_autodiff_rewrite = interp.within_autodiff_rewrite,
handler = interp.handler)
handler = interp.handler,
context = interp.context,)
@assert context.world == world
return EnzymeInterpreter(
cache_or_token,
mt,
Expand All @@ -291,7 +296,8 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
inactive_rules,
broadcast_rewrite,
within_autodiff_rewrite,
handler
handler,
context
)
end

Expand Down
Loading