Skip to content

Commit 2687373

Browse files
committed
Make sure that we pick up new definition of typetree
1 parent c889e43 commit 2687373

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/compiler.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import Enzyme:
1414
eltype,
1515
API,
1616
TypeTree,
17-
typetree,
17+
typetree, typetree_total,
1818
TypeTreeTable,
1919
only!,
2020
shift!,
@@ -1078,7 +1078,7 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
10781078

10791079
byref = arg.cc
10801080

1081-
rest = copy(typetree(arg.typ, ctx, dl))
1081+
rest = copy(typetree_total(job, arg.typ, ctx, dl))
10821082

10831083
if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF
10841084
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
@@ -1103,7 +1103,7 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11031103
if sret !== nothing
11041104
idx = 0
11051105
if !in(0, parmsRemoved)
1106-
rest = typetree(sret, ctx, dl)
1106+
rest = typetree_total(job, sret, ctx, dl)
11071107
push!(
11081108
parameter_attributes(f, idx + 1),
11091109
StringAttribute("enzyme_type", string(rest)),
@@ -1125,12 +1125,12 @@ function set_module_types!(mod::LLVM.Module, primalf::Union{Nothing, LLVM.Functi
11251125
LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType()
11261126
@assert !retRemoved
11271127
rest = if llRT == Ptr{RT}
1128-
typeTree = copy(typetree(RT, ctx, dl))
1128+
typeTree = copy(typetree_total(job, RT, ctx, dl))
11291129
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
11301130
only!(typeTree, -1)
11311131
typeTree
11321132
else
1133-
typetree(RT, ctx, dl)
1133+
typetree_total(job, RT, ctx, dl)
11341134
end
11351135
push!(return_attributes(f), StringAttribute("enzyme_type", string(rest)))
11361136
end
@@ -2313,7 +2313,7 @@ function enzyme!(
23132313
else
23142314
error("illegal annotation type $T")
23152315
end
2316-
typeTree = typetree(source_typ, ctx, dl, seen)
2316+
typeTree = typetree_total(job, source_typ, ctx, dl, seen)
23172317
if isboxed
23182318
typeTree = copy(typeTree)
23192319
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
@@ -2355,7 +2355,7 @@ function enzyme!(
23552355
in(Any, actualRetType.parameters)
23562356
TypeTree()
23572357
else
2358-
typeTree = typetree(actualRetType, ctx, dl, seen)
2358+
typeTree = typetree_total(job, actualRetType, ctx, dl, seen)
23592359
if !isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)
23602360
typeTree = copy(typeTree)
23612361
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
@@ -3703,7 +3703,7 @@ function lower_convention(
37033703
metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
37043704
end
37053705

3706-
typeTree = copy(typetree(actualRetType, ctx, dl, seen))
3706+
typeTree = copy(typetree(job, actualRetType, ctx, dl, seen))
37073707
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
37083708
only!(typeTree, -1)
37093709
metadata(sretPtr)["enzyme_type"] = to_md(typeTree, ctx)

src/typetree.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ end
195195

196196
const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}}
197197

198+
"""
199+
typetree_total(job, T, ctx, dl, seen=TypeTreeTable())
200+
201+
A wrapper around `typetree` that ensures the call happens in the correct world for GPUCompiler.
202+
Useful when using typetree from a generated function since typetree is user-extendable.
203+
"""
204+
function typetree_total(@nospecialize(job::GPUCompiler.CompilerJob), @nospecialize(T), ctx, dl, seen=TypeTreeTable())
205+
return Core._call_in_world_total(job.world, typetree, T, ctx, dl)
206+
end
207+
198208
"""
199209
function typetree(T, ctx, dl, seen=TypeTreeTable())
200210

0 commit comments

Comments
 (0)