Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.16"
Enzyme_jll = "0.0.216"
Enzyme_jll = "0.0.217"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
LLVM = "9.1"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/src/easyrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
push!(genexprs, Expr(:(=), :dΩ, :(cache[end])))
else
throw(AssertionError("Easy Rule should never be provided a constant reverse seed"))
push!(genexprs, Expr(Base.throw, AssertionError("Easy Rule should never be provided a constant reverse seed")))
end

actives = Union{Nothing, Expr}[$(actives...)]
Expand Down
95 changes: 49 additions & 46 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3130,6 +3130,7 @@ function create_abi_wrapper(

returnRoots = false
root_ty = nothing
tracked = nothing
if uses_sret
returnRoots = deserves_rooting(jltype)
if returnRoots
Expand Down Expand Up @@ -3192,6 +3193,7 @@ function create_abi_wrapper(
end
push!(parameter_attributes(llvm_f, 1), attr)
push!(parameter_attributes(llvm_f, 1), EnumAttribute("noalias"))
push!(parameter_attributes(llvm_f, 2), StringAttribute("enzymejl_returnRoots", string(Int(tracked.count))))
push!(parameter_attributes(llvm_f, 2), EnumAttribute("noalias"))
elseif jltype != T_void
sret = alloca!(builder, jltype)
Expand Down Expand Up @@ -3802,58 +3804,59 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
# aka bfs/etc
while length(todo) != 0
path, ty = popfirst!(todo)
if !any_jltypes(ty)
continue
end

if isa(ty, LLVM.PointerType)
if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
loc = inbounds_gep!(
builder,
root_ty,
rootRet,
to_llvm(Cuint[count]),
)
end

if direction == SRetPointerToRootPointer
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
outloc = load!(builder, ty, outloc)
store!(builder, outloc, loc)
elseif direction == SRetValueToRootPointer
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
store!(builder, outloc, loc)
elseif direction == RootPointerToSRetValue
loc = load!(builder, ty, loc)
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
elseif direction == NullifySRetValue
loc = unsafe_to_llvm(builder, nothing)
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
elseif direction == RootPointerToSRetPointer
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
loc = load!(builder, ty, loc)
push!(extracted, loc)
store!(builder, loc, outloc)
else
@assert false "Unhandled direction"
end

count += 1

if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
loc = inbounds_gep!(
builder,
root_ty,
rootRet,
to_llvm(Cuint[count]),
)
end

if direction == SRetPointerToRootPointer
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
outloc = load!(builder, ty, outloc)
store!(builder, outloc, loc)
elseif direction == SRetValueToRootPointer
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
store!(builder, outloc, loc)
elseif direction == RootPointerToSRetValue
loc = load!(builder, ty, loc)
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
elseif direction == NullifySRetValue
loc = unsafe_to_llvm(builder, nothing)
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
elseif direction == RootPointerToSRetPointer
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
loc = load!(builder, ty, loc)
push!(extracted, loc)
store!(builder, loc, outloc)
else
@assert false "Unhandled direction"
end

count += 1
continue
end
if isa(ty, LLVM.ArrayType)
if any_jltypes(ty)
for i = 1:length(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
for i = 1:length(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end
if isa(ty, LLVM.VectorType)
if any_jltypes(ty)
for i = 1:size(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
for i = 1:size(ty)
npath = copy(path)
push!(npath, i - 1)
push!(todo, (npath, eltype(ty)))
end
continue
end
Expand Down Expand Up @@ -6849,7 +6852,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
id = Base.reinterpret(Int, pointer(addr))
deferred_codegen_jobs[id] = job

code = Any[Core.Compiler.ReturnNode(reinterpret(Ptr{Cvoid}, id))]
code = Any[Core.Compiler.ReturnNode(reinterpret(UInt, id))]
ci = create_fresh_codeinfo(deferred_id_codegen, source, world, slotnames, code)

ci.edges = Any[mi]
Expand Down Expand Up @@ -6903,7 +6906,7 @@ end
@nospecialize(strongzero::Val)
)
id = deferred_id_codegen(fa, a, tt, mode, width, modifiedbetween, returnprimal, shadowinit, expectedtapetype, erriffuncwritten, runtimeactivity, strongzero)
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), id)
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (UInt,), id)
end

include("compiler/reflection.jl")
Expand Down
15 changes: 12 additions & 3 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)

run!(GCInvariantVerifierPass(strong=false), mod)

removeDeadArgs!(mod, tm)
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)

run!(GCInvariantVerifierPass(strong=false), mod)

detect_writeonly!(mod)
API.EnzymeDetectReadonlyOrThrow(mod)

run!(GCInvariantVerifierPass(strong=false), mod)

Expand Down Expand Up @@ -359,15 +359,24 @@ function addJuliaLegalizationPasses!(mpm::LLVM.NewPMPassManager, lower_intrinsic
end
end

const DumpPreCallConv = Ref(false)
const DumpPostCallConv = Ref(false)

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
addr13NoAlias(mod)
removeDeadArgs!(mod, tm)
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
if DumpPreCallConv[]
API.EnzymeDumpModuleRef(mod.ref)
end
for f in collect(functions(mod))
API.EnzymeFixupJuliaCallingConvention(f)
end
for f in collect(functions(mod))
API.EnzymeFixupBatchedJuliaCallingConvention(f)
end
if DumpPostCallConv[]
API.EnzymeDumpModuleRef(mod.ref)
end
for g in collect(globals(mod))
if startswith(LLVM.name(g), "ccall")
hasuse = false
Expand Down
5 changes: 3 additions & 2 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1169,8 +1169,9 @@ function julia_error(
print(io, msg)
println(io)
data2 = LLVM.Value(data2)
println(io, "Fn = ", string(LLVM.parent(data2::LLVM.Argument)))
println(io, "argFn = ", string(data2::LLVM.Argument))
fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
println(io, "Fn = ", string(fn))
println(io, "arg = ", string(data2::LLVM.Argument))
if data !== C_NULL
data = LLVM.Value(LLVM.API.LLVMValueRef(data))
println(io, "cur = ", string(data))
Expand Down
Loading
Loading