-
Notifications
You must be signed in to change notification settings - Fork 82
More root pathing fixups #2793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More root pathing fixups #2793
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/compiler.jl b/src/compiler.jl
index 45d9df78..75b1a67a 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -3804,59 +3804,59 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
# aka bfs/etc
while length(todo) != 0
path, ty = popfirst!(todo)
- if !any_jltypes(ty)
- continue
- end
+ if !any_jltypes(ty)
+ continue
+ end
if isa(ty, LLVM.PointerType)
- if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
- loc = inbounds_gep!(
- builder,
- root_ty,
- rootRet,
- to_llvm(Cuint[count]),
- )
- end
-
- if direction == SRetPointerToRootPointer
- outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
- outloc = load!(builder, ty, outloc)
- store!(builder, outloc, loc)
- elseif direction == SRetValueToRootPointer
- outloc = Enzyme.API.e_extract_value!(builder, sret, path)
- store!(builder, outloc, loc)
- elseif direction == RootPointerToSRetValue
- loc = load!(builder, ty, loc)
- val = Enzyme.API.e_insert_value!(builder, val, loc, path)
- elseif direction == NullifySRetValue
- loc = unsafe_to_llvm(builder, nothing)
- val = Enzyme.API.e_insert_value!(builder, val, loc, path)
- elseif direction == RootPointerToSRetPointer
- outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
- loc = load!(builder, ty, loc)
- push!(extracted, loc)
- store!(builder, loc, outloc)
- else
- @assert false "Unhandled direction"
- end
-
- count += 1
+ if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
+ loc = inbounds_gep!(
+ builder,
+ root_ty,
+ rootRet,
+ to_llvm(Cuint[count]),
+ )
+ end
+
+ if direction == SRetPointerToRootPointer
+ outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+ outloc = load!(builder, ty, outloc)
+ store!(builder, outloc, loc)
+ elseif direction == SRetValueToRootPointer
+ outloc = Enzyme.API.e_extract_value!(builder, sret, path)
+ store!(builder, outloc, loc)
+ elseif direction == RootPointerToSRetValue
+ loc = load!(builder, ty, loc)
+ val = Enzyme.API.e_insert_value!(builder, val, loc, path)
+ elseif direction == NullifySRetValue
+ loc = unsafe_to_llvm(builder, nothing)
+ val = Enzyme.API.e_insert_value!(builder, val, loc, path)
+ elseif direction == RootPointerToSRetPointer
+ outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
+ loc = load!(builder, ty, loc)
+ push!(extracted, loc)
+ store!(builder, loc, outloc)
+ else
+ @assert false "Unhandled direction"
+ end
+
+ count += 1
continue
end
if isa(ty, LLVM.ArrayType)
- for i = 1:length(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
+ for i in 1:length(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
continue
end
if isa(ty, LLVM.VectorType)
- for i = 1:size(ty)
- npath = copy(path)
- push!(npath, i - 1)
- push!(todo, (npath, eltype(ty)))
+ for i in 1:size(ty)
+ npath = copy(path)
+ push!(npath, i - 1)
+ push!(todo, (npath, eltype(ty)))
end
continue
end
@@ -6906,7 +6906,7 @@ end
@nospecialize(strongzero::Val)
)
id = deferred_id_codegen(fa, a, tt, mode, width, modifiedbetween, returnprimal, shadowinit, expectedtapetype, erriffuncwritten, runtimeactivity, strongzero)
- ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (UInt,), id)
+ return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (UInt,), id)
end
include("compiler/reflection.jl")
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 6ff32c99..99d89625 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -188,7 +188,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
run!(GCInvariantVerifierPass(strong=false), mod)
- removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
+ removeDeadArgs!(mod, tm, #=post_gc_fixup=# false)
run!(GCInvariantVerifierPass(strong=false), mod)
@@ -364,9 +364,9 @@ const DumpPostCallConv = Ref(false)
function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
addr13NoAlias(mod)
- removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
+ removeDeadArgs!(mod, tm, #=post_gc_fixup=# false)
if DumpPreCallConv[]
- API.EnzymeDumpModuleRef(mod.ref)
+ API.EnzymeDumpModuleRef(mod.ref)
end
for f in collect(functions(mod))
API.EnzymeFixupJuliaCallingConvention(f)
@@ -375,7 +375,7 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
API.EnzymeFixupBatchedJuliaCallingConvention(f)
end
if DumpPostCallConv[]
- API.EnzymeDumpModuleRef(mod.ref)
+ API.EnzymeDumpModuleRef(mod.ref)
end
for g in collect(globals(mod))
if startswith(LLVM.name(g), "ccall")
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 4d43c651..71f26fef 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -776,11 +776,11 @@ function nodecayed_phis!(mod::LLVM.Module)
end
if addrspace(value_type(v2)) == 0
if addr == 11
- PT = if LLVM.is_opaque(value_type(v))
- LLVM.PointerType(10)
- else
- LLVM.PointerType(eltype(value_type(v)), 10)
- end
+ PT = if LLVM.is_opaque(value_type(v))
+ LLVM.PointerType(10)
+ else
+ LLVM.PointerType(eltype(value_type(v)), 10)
+ end
v2 = const_addrspacecast(
v2,
PT
@@ -789,11 +789,11 @@ function nodecayed_phis!(mod::LLVM.Module)
end
end
if LLVM.isnull(v2)
- PT = if LLVM.is_opaque(value_type(v))
- LLVM.PointerType(10)
- else
- LLVM.PointerType(eltype(value_type(v)), 10)
- end
+ PT = if LLVM.is_opaque(value_type(v))
+ LLVM.PointerType(10)
+ else
+ LLVM.PointerType(eltype(value_type(v)), 10)
+ end
v2 = const_addrspacecast(
v2,
PT
@@ -826,7 +826,7 @@ function nodecayed_phis!(mod::LLVM.Module)
offset,
API.EnzymeComputeByteOffsetOfGEP(b, v, offty),
)
- if !LLVM.is_opaque(value_type(v))
+ if !LLVM.is_opaque(value_type(v))
v2 = const_bitcast(
v2,
LLVM.PointerType(
@@ -835,7 +835,7 @@ function nodecayed_phis!(mod::LLVM.Module)
),
)
@assert eltype(value_type(v2)) == eltype(value_type(v))
- end
+ end
return v2, offset, skipload
end
@@ -843,11 +843,11 @@ function nodecayed_phis!(mod::LLVM.Module)
if isa(v, LLVM.AddrSpaceCastInst)
if addrspace(value_type(operands(v)[1])) == 0
- PT = if LLVM.is_opaque(value_type(v))
- LLVM.PointerType(10)
- else
- LLVM.PointerType(eltype(value_type(v)), 10)
- end
+ PT = if LLVM.is_opaque(value_type(v))
+ LLVM.PointerType(10)
+ else
+ LLVM.PointerType(eltype(value_type(v)), 10)
+ end
v2 = addrspacecast!(
b,
operands(v)[1],
@@ -895,16 +895,16 @@ function nodecayed_phis!(mod::LLVM.Module)
)
v2, offset, skipload =
getparent(b, operands(v)[1], offset, hasload, phicache)
- if !LLVM.is_opaque(value_type(v))
- v2 = bitcast!(
- b,
- v2,
- LLVM.PointerType(
- eltype(value_type(v)),
- addrspace(value_type(v2)),
- ),
- )
- end
+ if !LLVM.is_opaque(value_type(v))
+ v2 = bitcast!(
+ b,
+ v2,
+ LLVM.PointerType(
+ eltype(value_type(v)),
+ addrspace(value_type(v2)),
+ ),
+ )
+ end
@assert eltype(value_type(v2)) == eltype(value_type(v))
return v2, offset, skipload
end
@@ -1286,7 +1286,7 @@ function fix_decayaddr!(mod::LLVM.Module)
t_sret = true
end
if kind(a) == kind(StringAttribute("enzymejl_returnRoots"))
- sret_elty = sret_ty(fop, i)
+ sret_elty = sret_ty(fop, i)
t_sret = true
end
# if kind(a) == kind(StringAttribute("enzyme_sret_v"))
diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl
index 5d8533db..6ec6335e 100644
--- a/src/rules/llvmrules.jl
+++ b/src/rules/llvmrules.jl
@@ -1953,12 +1953,12 @@ end
API.moveBefore(newo, err, B)
if unsafe_load(shadowR) != C_NULL
- valTys = API.CValueType[]
- args = LLVM.Value[]
- for i in 1:(length(operands(orig))-1)
- push!(valTys, API.VT_Primal)
- push!(args, new_from_original(gutils, operands(orig)[i]))
- end
+ valTys = API.CValueType[]
+ args = LLVM.Value[]
+ for i in 1:(length(operands(orig)) - 1)
+ push!(valTys, API.VT_Primal)
+ push!(args, new_from_original(gutils, operands(orig)[i]))
+ end
normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=#
width = get_width(gutils)
if width == 1
@@ -1988,13 +1988,13 @@ end
newo = new_from_original(gutils, orig)
API.moveBefore(newo, err, B)
if unsafe_load(shadowR) != C_NULL
- valTys = API.CValueType[]
- args = LLVM.Value[]
- for i in 1:(length(operands(orig))-1)
- push!(valTys, API.VT_Primal)
- push!(args, new_from_original(gutils, operands(orig)[i]))
- end
- normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=#
+ valTys = API.CValueType[]
+ args = LLVM.Value[]
+ for i in 1:(length(operands(orig)) - 1)
+ push!(valTys, API.VT_Primal)
+ push!(args, new_from_original(gutils, operands(orig)[i]))
+ end
+ normal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=#
width = get_width(gutils)
if width == 1
shadowres = normal
@@ -2393,9 +2393,10 @@ end
@fwdfunc(new_structt_fwd),
)
register_handler!(
- ("jl_get_binding_or_error", "ijl_get_binding_or_error",
- "jl_get_binding_value_seqcst", "ijl_get_binding_value_seqcst",
- ),
+ (
+ "jl_get_binding_or_error", "ijl_get_binding_or_error",
+ "jl_get_binding_value_seqcst", "ijl_get_binding_value_seqcst",
+ ),
@augfunc(get_binding_or_error_augfwd),
@revfunc(get_binding_or_error_rev),
@fwdfunc(get_binding_or_error_fwd), |
Benchmark Results
Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/19544687048/artifacts/4630531820. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2793 +/- ##
==========================================
- Coverage 68.28% 68.06% -0.22%
==========================================
Files 58 58
Lines 20521 20495 -26
==========================================
- Hits 14012 13950 -62
- Misses 6509 6545 +36 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
No description provided.