@@ -3130,6 +3130,7 @@ function create_abi_wrapper(
31303130
31313131 returnRoots = false
31323132 root_ty = nothing
3133+ tracked = nothing
31333134 if uses_sret
31343135 returnRoots = deserves_rooting (jltype)
31353136 if returnRoots
@@ -3192,6 +3193,7 @@ function create_abi_wrapper(
31923193 end
31933194 push! (parameter_attributes (llvm_f, 1 ), attr)
31943195 push! (parameter_attributes (llvm_f, 1 ), EnumAttribute (" noalias" ))
3196+ push! (parameter_attributes (llvm_f, 2 ), StringAttribute (" enzymejl_returnRoots" , string (Int (tracked. count))))
31953197 push! (parameter_attributes (llvm_f, 2 ), EnumAttribute (" noalias" ))
31963198 elseif jltype != T_void
31973199 sret = alloca! (builder, jltype)
@@ -3802,58 +3804,59 @@ function move_sret_tofrom_roots!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMType,
38023804 # aka bfs/etc
38033805 while length (todo) != 0
38043806 path, ty = popfirst! (todo)
3807+ if ! any_jltypes (ty)
3808+ continue
3809+ end
3810+
38053811 if isa (ty, LLVM. PointerType)
3806- if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3807- loc = inbounds_gep! (
3808- builder,
3809- root_ty,
3810- rootRet,
3811- to_llvm (Cuint[count]),
3812- )
3813- end
3814-
3815- if direction == SRetPointerToRootPointer
3816- outloc = inbounds_gep! (builder, jltype, sret, to_llvm (path))
3817- outloc = load! (builder, ty, outloc)
3818- store! (builder, outloc, loc)
3819- elseif direction == SRetValueToRootPointer
3820- outloc = Enzyme. API. e_extract_value! (builder, sret, path)
3821- store! (builder, outloc, loc)
3822- elseif direction == RootPointerToSRetValue
3823- loc = load! (builder, ty, loc)
3824- val = Enzyme. API. e_insert_value! (builder, val, loc, path)
3825- elseif direction == NullifySRetValue
3826- loc = unsafe_to_llvm (builder, nothing )
3827- val = Enzyme. API. e_insert_value! (builder, val, loc, path)
3828- elseif direction == RootPointerToSRetPointer
3829- outloc = inbounds_gep! (builder, jltype, sret, to_llvm (path))
3830- loc = load! (builder, ty, loc)
3831- push! (extracted, loc)
3832- store! (builder, loc, outloc)
3833- else
3834- @assert false " Unhandled direction"
3835- end
3836-
3837- count += 1
3812+
3813+ if direction == SRetPointerToRootPointer || direction == SRetValueToRootPointer || direction == RootPointerToSRetPointer || direction == RootPointerToSRetValue
3814+ loc = inbounds_gep! (
3815+ builder,
3816+ root_ty,
3817+ rootRet,
3818+ to_llvm (Cuint[count]),
3819+ )
3820+ end
3821+
3822+ if direction == SRetPointerToRootPointer
3823+ outloc = inbounds_gep! (builder, jltype, sret, to_llvm (path))
3824+ outloc = load! (builder, ty, outloc)
3825+ store! (builder, outloc, loc)
3826+ elseif direction == SRetValueToRootPointer
3827+ outloc = Enzyme. API. e_extract_value! (builder, sret, path)
3828+ store! (builder, outloc, loc)
3829+ elseif direction == RootPointerToSRetValue
3830+ loc = load! (builder, ty, loc)
3831+ val = Enzyme. API. e_insert_value! (builder, val, loc, path)
3832+ elseif direction == NullifySRetValue
3833+ loc = unsafe_to_llvm (builder, nothing )
3834+ val = Enzyme. API. e_insert_value! (builder, val, loc, path)
3835+ elseif direction == RootPointerToSRetPointer
3836+ outloc = inbounds_gep! (builder, jltype, sret, to_llvm (path))
3837+ loc = load! (builder, ty, loc)
3838+ push! (extracted, loc)
3839+ store! (builder, loc, outloc)
3840+ else
3841+ @assert false " Unhandled direction"
3842+ end
3843+
3844+ count += 1
38383845 continue
38393846 end
38403847 if isa (ty, LLVM. ArrayType)
3841- if any_jltypes (ty)
3842- for i = 1 : length (ty)
3843- npath = copy (path)
3844- push! (npath, i - 1 )
3845- push! (todo, (npath, eltype (ty)))
3846- end
3848+ for i = 1 : length (ty)
3849+ npath = copy (path)
3850+ push! (npath, i - 1 )
3851+ push! (todo, (npath, eltype (ty)))
38473852 end
38483853 continue
38493854 end
38503855 if isa (ty, LLVM. VectorType)
3851- if any_jltypes (ty)
3852- for i = 1 : size (ty)
3853- npath = copy (path)
3854- push! (npath, i - 1 )
3855- push! (todo, (npath, eltype (ty)))
3856- end
3856+ for i = 1 : size (ty)
3857+ npath = copy (path)
3858+ push! (npath, i - 1 )
3859+ push! (todo, (npath, eltype (ty)))
38573860 end
38583861 continue
38593862 end
@@ -6849,7 +6852,7 @@ function deferred_id_generator(world::UInt, source::Union{Method, LineNumberNode
68496852 id = Base. reinterpret (Int, pointer (addr))
68506853 deferred_codegen_jobs[id] = job
68516854
6852- code = Any[Core. Compiler. ReturnNode (reinterpret (Ptr{Cvoid} , id))]
6855+ code = Any[Core. Compiler. ReturnNode (reinterpret (UInt , id))]
68536856 ci = create_fresh_codeinfo (deferred_id_codegen, source, world, slotnames, code)
68546857
68556858 ci. edges = Any[mi]
@@ -6903,7 +6906,7 @@ end
69036906 @nospecialize (strongzero:: Val )
69046907)
69056908 id = deferred_id_codegen (fa, a, tt, mode, width, modifiedbetween, returnprimal, shadowinit, expectedtapetype, erriffuncwritten, runtimeactivity, strongzero)
6906- ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Ptr{Cvoid} ,), id)
6909+ ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (UInt ,), id)
69076910end
69086911
69096912include (" compiler/reflection.jl" )
0 commit comments