Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ if VERSION >= v"1.11.0-DEV.1552"
Interpreter.EnzymeInterpreter(
GPUCompiler.ci_cache_token(job),
GPUCompiler.method_table(job),
GPUCompiler.inference_params(job),
GPUCompiler.optimization_params(job),
job.world,
job.config.params.mode,
true
Expand All @@ -211,6 +213,8 @@ else
Interpreter.EnzymeInterpreter(
enzyme_ci_cache(job),
GPUCompiler.method_table(job),
GPUCompiler.inference_params(job),
GPUCompiler.optimization_params(job),
job.world,
job.config.params.mode,
true
Expand Down Expand Up @@ -6393,7 +6397,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode

target = EnzymeTarget()
rt2 = if A isa UnionAll
rrt = primal_return_type_world(Mode == API.DEM_ForwardMode ? Forward : Reverse, world, mi)
rrt = primal_return_type_world(Mode, world, mi)

# Don't error here but default to nothing return since in cuda context we don't use the device overrides
if rrt == Union{}
Expand Down
19 changes: 11 additions & 8 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ function EnzymeInterpreter(
cache_or_token,
mt::Union{Nothing,Core.MethodTable},
world::UInt,
inf_params::InferenceParams,
opt_params::OptimizationParams,
forward_rules::Bool,
reverse_rules::Bool,
inactive_rules::Bool,
Expand All @@ -133,11 +135,11 @@ function EnzymeInterpreter(
)
@assert world <= Base.get_world_counter()

parms = @static if VERSION >= v"1.12.0-DEV.1017"
InferenceParams()
else
InferenceParams(; unoptimize_throw_blocks=false)
end
# parms = @static if VERSION >= v"1.12.0-DEV.1017"
# InferenceParams()
# else
# InferenceParams(; unoptimize_throw_blocks=false)
# end

@static if HAS_INTEGRATED_CACHE

Expand Down Expand Up @@ -171,10 +173,11 @@ function EnzymeInterpreter(
Base.empty!(cache_or_token)
end
end
method_table = mt == nothing ? Core.Compiler.InternalMethodTable(world) : Core.Compiler.OverlayMethodTable(world, mt),

return EnzymeInterpreter(
cache_or_token,
mt == nothing ? Core.Compiler.InternalMethodTable(world) : Core.Compiler.OverlayMethodTable(world, mt),
method_table,

# Initially empty cache
Vector{InferenceResult}(),
Expand All @@ -183,8 +186,8 @@ function EnzymeInterpreter(
world,

# parameters for inference and optimization
parms,
OptimizationParams(),
inf_params,
opt_params,
forward_rules::Bool,
reverse_rules::Bool,
inactive_rules::Bool,
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function get_job(
end

primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt, world)
rt = Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt)
rt = Compiler.primal_return_type_world(mode, world, primal)

@assert primal !== nothing
rt = A{rt}
Expand Down
25 changes: 4 additions & 21 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,10 @@ end
using InteractiveUtils

function code_typed_helper(mi::Core.MethodInstance, world::UInt, mode::Enzyme.API.CDerivativeMode = Enzyme.API.DEM_ReverseModeCombined; interactive::Bool=false, kwargs...)
CT = @static if VERSION >= v"1.11.0-DEV.1552"
EnzymeCacheToken(
typeof(DefaultCompilerTarget()),
false,
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
EnzymeCompilerParams,
world,
mode == API.DEM_ForwardMode,
mode != API.DEM_ForwardMode,
true
)
else
if mode == API.DEM_ForwardMode
GLOBAL_FWD_CACHE
else
GLOBAL_REV_CACHE
end
end

interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)

target = Compiler.DefaultCompilerTarget()
params = PrimalCompilerParams(mode)
job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params))
interp = GPUCompiler.get_interpreter(job)
sig = mi.specTypes # XXX: can we just use the method instance?
if interactive
# call Cthulhu without introducing a dependency on Cthulhu
Expand Down
78 changes: 8 additions & 70 deletions src/typeutils/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,82 +13,20 @@ function return_type(interp::Core.Compiler.AbstractInterpreter, mi::Core.MethodI
end
end

function primal_interp_world(
@nospecialize(::ReverseMode),
world::UInt
)
mode = Enzyme.API.DEM_ReverseModeCombined

CT = @static if VERSION >= v"1.11.0-DEV.1552"
EnzymeCacheToken(
typeof(DefaultCompilerTarget()),
false,
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
EnzymeCompilerParams,
world,
false,
true,
true
)
else
Enzyme.Compiler.GLOBAL_REV_CACHE
end

Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
function primal_interp_world(mode::Enzyme.API.CDerivativeMode, world, mi)
target = Compiler.DefaultCompilerTarget()
params = PrimalCompilerParams(mode)
job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params), world)
return GPUCompiler.get_interpreter(job)
end

function primal_interp_world(
@nospecialize(::ForwardMode),
world::UInt
)
mode = Enzyme.API.DEM_ForwardMode

CT = @static if VERSION >= v"1.11.0-DEV.1552"
EnzymeCacheToken(
typeof(DefaultCompilerTarget()),
false,
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
EnzymeCompilerParams,
world,
true,
false,
true
)
else
Enzyme.Compiler.GLOBAL_FWD_CACHE
end
primal_interp_world(mode::Mode, world, mi) = primal_interp_world(convert(Enzyme.API.CDerivativeMode, mode), world, mi)

Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
end

@inline primal_interp_world(
@nospecialize(::ReverseModeSplit),
world::UInt) = primal_interp_world(Reverse, world)

function primal_return_type_world(
@nospecialize(mode::Mode),
world::UInt,
@nospecialize(TT::Type),
)
Core.Compiler._return_type(primal_interp_world(mode, world), TT)
end

function primal_return_type_world(
@nospecialize(mode::Mode),
world::UInt,
mi::Core.MethodInstance,
)
interp = primal_interp_world(mode, world)
function primal_return_type_world(mode, world, mi)
interp = primal_interp_world(mode, world, mi)
return_type(interp, mi)
end

primal_return_type_world(
@nospecialize(mode::Mode),
world::UInt,
@nospecialize(FT::Type),
@nospecialize(TT::Type),
) = primal_return_type_world(mode, world, Tuple{FT, TT.parameters...})

function primal_return_type end

function primal_return_type_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type))
Expand Down
Loading