Skip to content

Commit 70eef9d

Browse files
authored
More root pathing fixups (#2793)
* More root pathing fixups * fix * fix * fix * more * fix * Update Enzyme_jll version to 0.0.217
1 parent cd6eb98 commit 70eef9d

File tree

8 files changed

+119
-179
lines changed

8 files changed

+119
-179
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
4242
CEnum = "0.4, 0.5"
4343
ChainRulesCore = "1"
4444
EnzymeCore = "0.8.16"
45-
Enzyme_jll = "0.0.216"
45+
Enzyme_jll = "0.0.217"
4646
GPUArraysCore = "0.1.6, 0.2"
4747
GPUCompiler = "1.6.2"
4848
LLVM = "9.1"

lib/EnzymeCore/src/easyrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, input_names
559559
elseif RTA <: Type{<:Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated, EnzymeCore.BatchDuplicatedNoNeed}}
560560
push!(genexprs, Expr(:(=), :dΩ, :(cache[end])))
561561
else
562-
throw(AssertionError("Easy Rule should never be provided a constant reverse seed"))
562+
push!(genexprs, Expr(Base.throw, AssertionError("Easy Rule should never be provided a constant reverse seed")))
563563
end
564564

565565
actives = Union{Nothing, Expr}[$(actives...)]

src/compiler.jl

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,6 +3130,7 @@ function create_abi_wrapper(
31303130

31313131
returnRoots = false
31323132
root_ty = nothing
3133+
tracked = nothing
31333134
if uses_sret
31343135
returnRoots = deserves_rooting(jltype)
31353136
if returnRoots
@@ -3192,6 +3193,7 @@ function create_abi_wrapper(
31923193
end
31933194
push!(parameter_attributes(llvm_f, 1), attr)
31943195
push!(parameter_attributes(llvm_f, 1), EnumAttribute("noalias"))
3196+
push!(parameter_attributes(llvm_f, 2), StringAttribute("enzymejl_returnRoots", string(Int(tracked.count))))
31953197
push!(parameter_attributes(llvm_f, 2), EnumAttribute("noalias"))
31963198
elseif jltype != T_void
31973199
sret = alloca!(builder, jltype)
@@ -3802,58 +3804,59 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
38023804
# aka bfs/etc
38033805
while length(todo) != 0
38043806
path, ty = popfirst!(todo)
3807+
if !any_jltypes(ty)
3808+
continue
3809+
end
3810+
38053811
if isa(ty, LLVM.PointerType)
3806-
if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3807-
loc = inbounds_gep!(
3808-
builder,
3809-
root_ty,
3810-
rootRet,
3811-
to_llvm(Cuint[count]),
3812-
)
3813-
end
3814-
3815-
if direction == SRetPointerToRootPointer
3816-
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3817-
outloc = load!(builder, ty, outloc)
3818-
store!(builder, outloc, loc)
3819-
elseif direction == SRetValueToRootPointer
3820-
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
3821-
store!(builder, outloc, loc)
3822-
elseif direction == RootPointerToSRetValue
3823-
loc = load!(builder, ty, loc)
3824-
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3825-
elseif direction == NullifySRetValue
3826-
loc = unsafe_to_llvm(builder, nothing)
3827-
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3828-
elseif direction == RootPointerToSRetPointer
3829-
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3830-
loc = load!(builder, ty, loc)
3831-
push!(extracted, loc)
3832-
store!(builder, loc, outloc)
3833-
else
3834-
@assert false "Unhandled direction"
3835-
end
3836-
3837-
count += 1
3812+
3813+
if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3814+
loc = inbounds_gep!(
3815+
builder,
3816+
root_ty,
3817+
rootRet,
3818+
to_llvm(Cuint[count]),
3819+
)
3820+
end
3821+
3822+
if direction == SRetPointerToRootPointer
3823+
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3824+
outloc = load!(builder, ty, outloc)
3825+
store!(builder, outloc, loc)
3826+
elseif direction == SRetValueToRootPointer
3827+
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
3828+
store!(builder, outloc, loc)
3829+
elseif direction == RootPointerToSRetValue
3830+
loc = load!(builder, ty, loc)
3831+
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3832+
elseif direction == NullifySRetValue
3833+
loc = unsafe_to_llvm(builder, nothing)
3834+
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3835+
elseif direction == RootPointerToSRetPointer
3836+
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3837+
loc = load!(builder, ty, loc)
3838+
push!(extracted, loc)
3839+
store!(builder, loc, outloc)
3840+
else
3841+
@assert false "Unhandled direction"
3842+
end
3843+
3844+
count += 1
38383845
continue
38393846
end
38403847
if isa(ty, LLVM.ArrayType)
3841-
if any_jltypes(ty)
3842-
for i = 1:length(ty)
3843-
npath = copy(path)
3844-
push!(npath, i - 1)
3845-
push!(todo, (npath, eltype(ty)))
3846-
end
3848+
for i = 1:length(ty)
3849+
npath = copy(path)
3850+
push!(npath, i - 1)
3851+
push!(todo, (npath, eltype(ty)))
38473852
end
38483853
continue
38493854
end
38503855
if isa(ty, LLVM.VectorType)
3851-
if any_jltypes(ty)
3852-
for i = 1:size(ty)
3853-
npath = copy(path)
3854-
push!(npath, i - 1)
3855-
push!(todo, (npath, eltype(ty)))
3856-
end
3856+
for i = 1:size(ty)
3857+
npath = copy(path)
3858+
push!(npath, i - 1)
3859+
push!(todo, (npath, eltype(ty)))
38573860
end
38583861
continue
38593862
end
@@ -6849,7 +6852,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
68496852
id = Base.reinterpret(Int, pointer(addr))
68506853
deferred_codegen_jobs[id] = job
68516854

6852-
code = Any[Core.Compiler.ReturnNode(reinterpret(Ptr{Cvoid}, id))]
6855+
code = Any[Core.Compiler.ReturnNode(reinterpret(UInt, id))]
68536856
ci = create_fresh_codeinfo(deferred_id_codegen, source, world, slotnames, code)
68546857

68556858
ci.edges = Any[mi]
@@ -6903,7 +6906,7 @@ end
69036906
@nospecialize(strongzero::Val)
69046907
)
69056908
id = deferred_id_codegen(fa, a, tt, mode, width, modifiedbetween, returnprimal, shadowinit, expectedtapetype, erriffuncwritten, runtimeactivity, strongzero)
6906-
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), id)
6909+
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (UInt,), id)
69076910
end
69086911

69096912
include("compiler/reflection.jl")

src/compiler/optimize.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,11 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
188188

189189
run!(GCInvariantVerifierPass(strong=false), mod)
190190

191-
removeDeadArgs!(mod, tm)
191+
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
192192

193193
run!(GCInvariantVerifierPass(strong=false), mod)
194194

195-
detect_writeonly!(mod)
195+
API.EnzymeDetectReadonlyOrThrow(mod)
196196

197197
run!(GCInvariantVerifierPass(strong=false), mod)
198198

@@ -359,15 +359,24 @@ function addJuliaLegalizationPasses!(mpm::LLVM.NewPMPassManager, lower_intrinsic
359359
end
360360
end
361361

362+
const DumpPreCallConv = Ref(false)
363+
const DumpPostCallConv = Ref(false)
364+
362365
function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
363366
addr13NoAlias(mod)
364-
removeDeadArgs!(mod, tm)
367+
removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
368+
if DumpPreCallConv[]
369+
API.EnzymeDumpModuleRef(mod.ref)
370+
end
365371
for f in collect(functions(mod))
366372
API.EnzymeFixupJuliaCallingConvention(f)
367373
end
368374
for f in collect(functions(mod))
369375
API.EnzymeFixupBatchedJuliaCallingConvention(f)
370376
end
377+
if DumpPostCallConv[]
378+
API.EnzymeDumpModuleRef(mod.ref)
379+
end
371380
for g in collect(globals(mod))
372381
if startswith(LLVM.name(g), "ccall")
373382
hasuse = false

0 commit comments

Comments
 (0)