Skip to content

Commit b623edf

Browse files
committed
pass context to enzyme!
1 parent 108da13 commit b623edf

File tree

4 files changed

+32
-30
lines changed

4 files changed

+32
-30
lines changed

src/compiler.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,13 @@ include("llvm/transforms.jl")
486486
include("llvm/passes.jl")
487487
include("typeutils/make_zero.jl")
488488

489-
function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
490-
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
491-
nested_codegen!(mode, mod, funcspec, world)
489+
function nested_codegen!(ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type))
490+
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, ctx.world)
491+
nested_codegen!(ctx, mode, mod, funcspec)
492492
end
493493

494494
function prepare_llvm(interp, mod::LLVM.Module, job, meta)
495+
# TODO: remove enzymejl_world
495496
for f in functions(mod)
496497
attributes = function_attributes(f)
497498
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
@@ -1234,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false)
12341235
const DumpPostNestedOpt = Ref(false)
12351236

12361237
function nested_codegen!(
1238+
ctx::EnzymeContext,
12371239
mode::API.CDerivativeMode,
12381240
mod::LLVM.Module,
12391241
funcspec::Core.MethodInstance,
1240-
world::UInt,
12411242
)
1243+
world = ctx.world
12421244
# TODO: Put a cache here index on `mod` and f->tt
12431245

12441246

@@ -1254,6 +1256,7 @@ function nested_codegen!(
12541256
GPUCompiler.prepare_job!(job)
12551257
otherMod, meta = GPUCompiler.emit_llvm(job)
12561258

1259+
# TODO: interp should be cached since it contains internal caches
12571260
interp = GPUCompiler.get_interpreter(job)
12581261
prepare_llvm(interp, otherMod, job, meta)
12591262

@@ -2398,6 +2401,7 @@ const DumpPostEnzyme = Ref(false)
23982401
const DumpPostWrap = Ref(false)
23992402

24002403
function enzyme!(
2404+
enzyme_context::EnzymeContext,
24012405
job::CompilerJob,
24022406
interp,
24032407
mod::LLVM.Module,
@@ -2513,7 +2517,6 @@ function enzyme!(
25132517
convert(API.CDIFFE_TYPE, rt)
25142518
end
25152519

2516-
enzyme_context = EnzymeContext(job.world)
25172520
GC.@preserve enzyme_context begin
25182521
LLVM.@dispose logic = Logic(enzyme_context) begin
25192522

@@ -2583,6 +2586,7 @@ function enzyme!(
25832586

25842587
if wrap
25852588
augmented_primalf = create_abi_wrapper(
2589+
enzyme_context,
25862590
augmented_primalf,
25872591
TT,
25882592
rt,
@@ -2592,7 +2596,6 @@ function enzyme!(
25922596
width,
25932597
returnPrimal,
25942598
shadow_init,
2595-
world,
25962599
interp,
25972600
runtimeActivity,
25982601
)
@@ -2625,6 +2628,7 @@ function enzyme!(
26252628
) #=atomicAdd=#
26262629
if wrap
26272630
adjointf = create_abi_wrapper(
2631+
enzyme_context,
26282632
adjointf,
26292633
TT,
26302634
rt,
@@ -2634,7 +2638,6 @@ function enzyme!(
26342638
width,
26352639
false,
26362640
shadow_init,
2637-
world,
26382641
interp,
26392642
runtimeActivity
26402643
) #=returnPrimal=#
@@ -2666,6 +2669,7 @@ function enzyme!(
26662669
augmented_primalf = nothing
26672670
if wrap
26682671
adjointf = create_abi_wrapper(
2672+
enzyme_context,
26692673
adjointf,
26702674
TT,
26712675
rt,
@@ -2675,7 +2679,6 @@ function enzyme!(
26752679
width,
26762680
returnPrimal,
26772681
shadow_init,
2678-
world,
26792682
interp,
26802683
runtimeActivity
26812684
)
@@ -2711,6 +2714,7 @@ function enzyme!(
27112714
if wrap
27122715
pf = adjointf
27132716
adjointf = create_abi_wrapper(
2717+
enzyme_context,
27142718
adjointf,
27152719
TT,
27162720
rt,
@@ -2720,7 +2724,6 @@ function enzyme!(
27202724
width,
27212725
returnPrimal,
27222726
shadow_init,
2723-
world,
27242727
interp,
27252728
runtimeActivity
27262729
)
@@ -2792,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27922795
end
27932796

27942797
function create_abi_wrapper(
2798+
ctx::EnzymeContext,
27952799
enzymefn::LLVM.Function,
27962800
@nospecialize(TT::Type),
27972801
@nospecialize(rettype::Type),
@@ -2801,10 +2805,10 @@ function create_abi_wrapper(
28012805
width::Int,
28022806
returnPrimal::Bool,
28032807
shadow_init::Bool,
2804-
world::UInt,
28052808
interp,
28062809
runtime_activity::Bool
28072810
)
2811+
world = ctx.world
28082812
is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined
28092813
is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal
28102814
needs_tape = Mode == API.DEM_ReverseModeGradient
@@ -3087,6 +3091,7 @@ function create_abi_wrapper(
30873091
realparms = LLVM.Value[]
30883092
i = 1
30893093

3094+
# TODO(vchuravy): remove
30903095
for attr in collect(function_attributes(enzymefn))
30913096
if kind(attr) == "enzymejl_world"
30923097
push!(function_attributes(llvm_f), attr)
@@ -3231,7 +3236,7 @@ function create_abi_wrapper(
32313236
elseif T <: BatchDuplicatedFunc
32323237
Func = get_func(T)
32333238
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3234-
llvmf = nested_codegen!(Mode, mod, funcspec, world)
3239+
llvmf = nested_codegen!(ctx, Mode, mod, funcspec)
32353240
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
32363241
Func_RT = return_type(interp, funcspec)
32373242
@assert Func_RT == NTuple{width,T′}
@@ -5102,15 +5107,17 @@ end
51025107
end
51035108
end
51045109

5110+
ctx = EnzymeContext(job.world)
51055111
if params.run_enzyme
51065112
# Generate the adjoint
51075113
memcpy_alloca_to_loadstore(mod)
51085114
force_recompute!(mod)
51095115
API.EnzymeDetectReadonlyOrThrow(mod)
51105116

51115117
adjointf, augmented_primalf, TapeType = enzyme!(
5118+
ctx,
51125119
job,
5113-
interp,
5120+
interp,
51145121
mod,
51155122
primalf,
51165123
TT,
@@ -5209,7 +5216,7 @@ end
52095216
fname = String(name) * pf
52105217
if haskey(functions(mod), fname)
52115218
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job.world)
5212-
llvmf = nested_codegen!(mode, mod, funcspec, job.world)
5219+
llvmf = nested_codegen!(ctx, mode, mod, funcspec)
52135220
push!(function_attributes(llvmf), StringAttribute("implements", fname))
52145221
end
52155222
end

src/rules/customrules.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,8 @@ end
592592

593593
curent_bb = position(B)
594594
fn = LLVM.parent(curent_bb)
595-
world = enzyme_extract_world(fn)
596-
@assert world == enzyme_context(gutils).world
597595

598-
llvmf = nested_codegen!(mode, mod, fmi, world)
596+
llvmf = nested_codegen!(enzyme_context(gutils), mode, mod, fmi)
599597

600598
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
601599

@@ -1053,8 +1051,8 @@ function enzyme_custom_common_rev(
10531051

10541052
curent_bb = position(B)
10551053
fn = LLVM.parent(curent_bb)
1056-
world = enzyme_extract_world(fn)
1057-
@assert world == enzyme_context(gutils).world
1054+
ctx = enzyme_context(gutils)
1055+
world = ctx.world
10581056

10591057
mode = get_mode(gutils)
10601058

@@ -1115,7 +1113,7 @@ function enzyme_custom_common_rev(
11151113
applicablefn = true
11161114

11171115
if forward
1118-
llvmf = nested_codegen!(mode, mod, ami, world)
1116+
llvmf = nested_codegen!(ctx, mode, mod, ami)
11191117
@assert llvmf !== nothing
11201118
rev_RT = nothing
11211119
else
@@ -1157,7 +1155,7 @@ function enzyme_custom_common_rev(
11571155

11581156
rmi = rmi::Core.MethodInstance
11591157
rev_RT = rev_RT::Type
1160-
llvmf = nested_codegen!(mode, mod, rmi, world)
1158+
llvmf = nested_codegen!(ctx, mode, mod, rmi)
11611159
end
11621160

11631161
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

src/rules/parallelrules.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,8 @@ end
550550

551551
tt = Tuple{thunkTy,dfuncT,Bool}
552552
mode = get_mode(gutils)
553-
world = enzyme_extract_world(LLVM.parent(position(B)))
554-
@assert world == enzyme_context(gutils).world
555-
entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world)
553+
ctx = enzyme_context(gutils)
554+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_fwd, tt)
556555
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
557556

558557
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -595,9 +594,8 @@ end
595594
Bool,
596595
}
597596
mode = get_mode(gutils)
598-
world = enzyme_extract_world(LLVM.parent(position(B)))
599-
@assert world == enzyme_context(gutils).world
600-
entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world)
597+
ctx = enzyme_context(gutils)
598+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_augfwd, tt)
601599
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
602600

603601
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -629,8 +627,6 @@ end
629627

630628
@register_rev function threadsfor_rev(B, orig, gutils, tape)
631629
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
632-
world = enzyme_extract_world(LLVM.parent(position(B)))
633-
@assert world == enzyme_context(gutils).world
634630
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
635631
return
636632
end
@@ -653,7 +649,8 @@ end
653649
Bool,
654650
}
655651
mode = get_mode(gutils)
656-
entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world)
652+
ctx = enzyme_context(gutils)
653+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_rev, tt)
657654
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
658655

659656
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))

src/rules/typeunstablerules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ end
10361036
if legal
10371037
@assert legal
10381038
world = enzyme_extract_world(LLVM.parent(position(B)))
1039-
@assert world == enzyme_context(gutils).world
1039+
@assert world == enzyme_context(gutils).world
10401040
torun = !guaranteed_nonactive(TT, world)
10411041
else
10421042
torun = true

0 commit comments

Comments
 (0)