Skip to content

Commit 6881fab

Browse files
committed
More root pathing fixups
1 parent a54bfbd commit 6881fab

File tree

2 files changed

+46
-45
lines changed

2 files changed

+46
-45
lines changed

src/compiler.jl

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3802,58 +3802,59 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
38023802
# aka bfs/etc
38033803
while length(todo) != 0
38043804
path, ty = popfirst!(todo)
3805+
if !any_jltypes(ty)
3806+
continue
3807+
end
3808+
38053809
if isa(ty, LLVM.PointerType)
3806-
if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3807-
loc = inbounds_gep!(
3808-
builder,
3809-
root_ty,
3810-
rootRet,
3811-
to_llvm(Cuint[count]),
3812-
)
3813-
end
3814-
3815-
if direction == SRetPointerToRootPointer
3816-
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3817-
outloc = load!(builder, ty, outloc)
3818-
store!(builder, outloc, loc)
3819-
elseif direction == SRetValueToRootPointer
3820-
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
3821-
store!(builder, outloc, loc)
3822-
elseif direction == RootPointerToSRetValue
3823-
loc = load!(builder, ty, loc)
3824-
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3825-
elseif direction == NullifySRetValue
3826-
loc = unsafe_to_llvm(builder, nothing)
3827-
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3828-
elseif direction == RootPointerToSRetPointer
3829-
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3830-
loc = load!(builder, ty, loc)
3831-
push!(extracted, loc)
3832-
store!(builder, loc, outloc)
3833-
else
3834-
@assert false "Unhandled direction"
3835-
end
3836-
3837-
count += 1
3810+
3811+
if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3812+
loc = inbounds_gep!(
3813+
builder,
3814+
root_ty,
3815+
rootRet,
3816+
to_llvm(Cuint[count]),
3817+
)
3818+
end
3819+
3820+
if direction == SRetPointerToRootPointer
3821+
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3822+
outloc = load!(builder, ty, outloc)
3823+
store!(builder, outloc, loc)
3824+
elseif direction == SRetValueToRootPointer
3825+
outloc = Enzyme.API.e_extract_value!(builder, sret, path)
3826+
store!(builder, outloc, loc)
3827+
elseif direction == RootPointerToSRetValue
3828+
loc = load!(builder, ty, loc)
3829+
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3830+
elseif direction == NullifySRetValue
3831+
loc = unsafe_to_llvm(builder, nothing)
3832+
val = Enzyme.API.e_insert_value!(builder, val, loc, path)
3833+
elseif direction == RootPointerToSRetPointer
3834+
outloc = inbounds_gep!(builder, jltype, sret, to_llvm(path))
3835+
loc = load!(builder, ty, loc)
3836+
push!(extracted, loc)
3837+
store!(builder, loc, outloc)
3838+
else
3839+
@assert false "Unhandled direction"
3840+
end
3841+
3842+
count += 1
38383843
continue
38393844
end
38403845
if isa(ty, LLVM.ArrayType)
3841-
if any_jltypes(ty)
3842-
for i = 1:length(ty)
3843-
npath = copy(path)
3844-
push!(npath, i - 1)
3845-
push!(todo, (npath, eltype(ty)))
3846-
end
3846+
for i = 1:length(ty)
3847+
npath = copy(path)
3848+
push!(npath, i - 1)
3849+
push!(todo, (npath, eltype(ty)))
38473850
end
38483851
continue
38493852
end
38503853
if isa(ty, LLVM.VectorType)
3851-
if any_jltypes(ty)
3852-
for i = 1:size(ty)
3853-
npath = copy(path)
3854-
push!(npath, i - 1)
3855-
push!(todo, (npath, eltype(ty)))
3856-
end
3854+
for i = 1:size(ty)
3855+
npath = copy(path)
3856+
push!(npath, i - 1)
3857+
push!(todo, (npath, eltype(ty)))
38573858
end
38583859
continue
38593860
end

src/rules/customrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2051,7 +2051,7 @@ end
20512051
fop = called_operand(orig)::LLVM.Function
20522052
for (i, v) in enumerate(operands(orig)[1:end-1])
20532053
if v == val
2054-
if !has_arg_attr(fop, i, StringAttribute("enzymejl_returnRoots"))
2054+
if true || !has_arg_attr(fop, i, StringAttribute("enzymejl_returnRoots"))
20552055
non_rooting_use = true
20562056
break
20572057
end

0 commit comments

Comments
 (0)