Skip to content

Commit 1bfee3a

Browse files
committed
pass context to enzyme!
1 parent 79dc110 commit 1bfee3a

File tree

4 files changed

+32
-30
lines changed

4 files changed

+32
-30
lines changed

src/compiler.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -484,12 +484,13 @@ include("llvm/transforms.jl")
484484
include("llvm/passes.jl")
485485
include("typeutils/make_zero.jl")
486486

487-
function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
488-
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
489-
nested_codegen!(mode, mod, funcspec, world)
487+
function nested_codegen!(ctx::EnzymeContext, mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type))
488+
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, ctx.world)
489+
nested_codegen!(ctx, mode, mod, funcspec)
490490
end
491491

492492
function prepare_llvm(interp, mod::LLVM.Module, job, meta)
493+
# TODO: remove the `enzymejl_world` function attribute
493494
for f in functions(mod)
494495
attributes = function_attributes(f)
495496
push!(attributes, StringAttribute("enzymejl_world", string(job.world)))
@@ -1228,11 +1229,12 @@ const DumpPreNestedOpt = Ref(false)
12281229
const DumpPostNestedOpt = Ref(false)
12291230

12301231
function nested_codegen!(
1232+
ctx::EnzymeContext,
12311233
mode::API.CDerivativeMode,
12321234
mod::LLVM.Module,
12331235
funcspec::Core.MethodInstance,
1234-
world::UInt,
12351236
)
1237+
world = ctx.world
12361238
# TODO: Put a cache here, indexed on `mod` and f->tt
12371239

12381240

@@ -1248,6 +1250,7 @@ function nested_codegen!(
12481250
GPUCompiler.prepare_job!(job)
12491251
otherMod, meta = GPUCompiler.emit_llvm(job)
12501252

1253+
# TODO: interp should be cached since it contains internal caches
12511254
interp = GPUCompiler.get_interpreter(job)
12521255
prepare_llvm(interp, otherMod, job, meta)
12531256

@@ -2389,6 +2392,7 @@ const DumpPostEnzyme = Ref(false)
23892392
const DumpPostWrap = Ref(false)
23902393

23912394
function enzyme!(
2395+
enzyme_context::EnzymeContext,
23922396
job::CompilerJob,
23932397
interp,
23942398
mod::LLVM.Module,
@@ -2504,7 +2508,6 @@ function enzyme!(
25042508
convert(API.CDIFFE_TYPE, rt)
25052509
end
25062510

2507-
enzyme_context = EnzymeContext(job.world)
25082511
GC.@preserve enzyme_context begin
25092512
LLVM.@dispose logic = Logic(enzyme_context) begin
25102513

@@ -2574,6 +2577,7 @@ function enzyme!(
25742577

25752578
if wrap
25762579
augmented_primalf = create_abi_wrapper(
2580+
enzyme_context,
25772581
augmented_primalf,
25782582
TT,
25792583
rt,
@@ -2583,7 +2587,6 @@ function enzyme!(
25832587
width,
25842588
returnPrimal,
25852589
shadow_init,
2586-
world,
25872590
interp,
25882591
runtimeActivity,
25892592
)
@@ -2616,6 +2619,7 @@ function enzyme!(
26162619
) #=atomicAdd=#
26172620
if wrap
26182621
adjointf = create_abi_wrapper(
2622+
enzyme_context,
26192623
adjointf,
26202624
TT,
26212625
rt,
@@ -2625,7 +2629,6 @@ function enzyme!(
26252629
width,
26262630
false,
26272631
shadow_init,
2628-
world,
26292632
interp,
26302633
runtimeActivity
26312634
) #=returnPrimal=#
@@ -2657,6 +2660,7 @@ function enzyme!(
26572660
augmented_primalf = nothing
26582661
if wrap
26592662
adjointf = create_abi_wrapper(
2663+
enzyme_context,
26602664
adjointf,
26612665
TT,
26622666
rt,
@@ -2666,7 +2670,6 @@ function enzyme!(
26662670
width,
26672671
returnPrimal,
26682672
shadow_init,
2669-
world,
26702673
interp,
26712674
runtimeActivity
26722675
)
@@ -2702,6 +2705,7 @@ function enzyme!(
27022705
if wrap
27032706
pf = adjointf
27042707
adjointf = create_abi_wrapper(
2708+
enzyme_context,
27052709
adjointf,
27062710
TT,
27072711
rt,
@@ -2711,7 +2715,6 @@ function enzyme!(
27112715
width,
27122716
returnPrimal,
27132717
shadow_init,
2714-
world,
27152718
interp,
27162719
runtimeActivity
27172720
)
@@ -2786,6 +2789,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27862789
end
27872790

27882791
function create_abi_wrapper(
2792+
ctx::EnzymeContext,
27892793
enzymefn::LLVM.Function,
27902794
@nospecialize(TT::Type),
27912795
@nospecialize(rettype::Type),
@@ -2795,10 +2799,10 @@ function create_abi_wrapper(
27952799
width::Int,
27962800
returnPrimal::Bool,
27972801
shadow_init::Bool,
2798-
world::UInt,
27992802
interp,
28002803
runtime_activity::Bool
28012804
)
2805+
world = ctx.world
28022806
is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined
28032807
is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal
28042808
needs_tape = Mode == API.DEM_ReverseModeGradient
@@ -3081,6 +3085,7 @@ function create_abi_wrapper(
30813085
realparms = LLVM.Value[]
30823086
i = 1
30833087

3088+
# TODO(vchuravy): remove this forwarding of the `enzymejl_world` attribute to the wrapper
30843089
for attr in collect(function_attributes(enzymefn))
30853090
if kind(attr) == "enzymejl_world"
30863091
push!(function_attributes(llvm_f), attr)
@@ -3225,7 +3230,7 @@ function create_abi_wrapper(
32253230
elseif T <: BatchDuplicatedFunc
32263231
Func = get_func(T)
32273232
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3228-
llvmf = nested_codegen!(Mode, mod, funcspec, world)
3233+
llvmf = nested_codegen!(ctx, Mode, mod, funcspec)
32293234
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
32303235
Func_RT = return_type(interp, funcspec)
32313236
@assert Func_RT == NTuple{width,T′}
@@ -5089,15 +5094,17 @@ end
50895094
end
50905095
end
50915096

5097+
ctx = EnzymeContext(job.world)
50925098
if params.run_enzyme
50935099
# Generate the adjoint
50945100
memcpy_alloca_to_loadstore(mod)
50955101
force_recompute!(mod)
50965102
API.EnzymeDetectReadonlyOrThrow(mod)
50975103

50985104
adjointf, augmented_primalf, TapeType = enzyme!(
5105+
ctx,
50995106
job,
5100-
interp,
5107+
interp,
51015108
mod,
51025109
primalf,
51035110
TT,
@@ -5195,7 +5202,7 @@ end
51955202
fname = String(name) * pf
51965203
if haskey(functions(mod), fname)
51975204
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job.world)
5198-
llvmf = nested_codegen!(mode, mod, funcspec, job.world)
5205+
llvmf = nested_codegen!(ctx, mode, mod, funcspec)
51995206
push!(function_attributes(llvmf), StringAttribute("implements", fname))
52005207
end
52015208
end

src/rules/customrules.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,8 @@ end
481481

482482
curent_bb = position(B)
483483
fn = LLVM.parent(curent_bb)
484-
world = enzyme_extract_world(fn)
485-
@assert world == enzyme_context(gutils).world
486484

487-
llvmf = nested_codegen!(mode, mod, fmi, world)
485+
llvmf = nested_codegen!(enzyme_context(gutils), mode, mod, fmi)
488486

489487
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
490488

@@ -922,8 +920,8 @@ function enzyme_custom_common_rev(
922920

923921
curent_bb = position(B)
924922
fn = LLVM.parent(curent_bb)
925-
world = enzyme_extract_world(fn)
926-
@assert world == enzyme_context(gutils).world
923+
ctx = enzyme_context(gutils)
924+
world = ctx.world
927925

928926
mode = get_mode(gutils)
929927

@@ -984,7 +982,7 @@ function enzyme_custom_common_rev(
984982
applicablefn = true
985983

986984
if forward
987-
llvmf = nested_codegen!(mode, mod, ami, world)
985+
llvmf = nested_codegen!(ctx, mode, mod, ami)
988986
@assert llvmf !== nothing
989987
rev_RT = nothing
990988
else
@@ -1026,7 +1024,7 @@ function enzyme_custom_common_rev(
10261024

10271025
rmi = rmi::Core.MethodInstance
10281026
rev_RT = rev_RT::Type
1029-
llvmf = nested_codegen!(mode, mod, rmi, world)
1027+
llvmf = nested_codegen!(ctx, mode, mod, rmi)
10301028
end
10311029

10321030
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

src/rules/parallelrules.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,8 @@ end
550550

551551
tt = Tuple{thunkTy,dfuncT,Bool}
552552
mode = get_mode(gutils)
553-
world = enzyme_extract_world(LLVM.parent(position(B)))
554-
@assert world == enzyme_context(gutils).world
555-
entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world)
553+
ctx = enzyme_context(gutils)
554+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_fwd, tt)
556555
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
557556

558557
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -595,9 +594,8 @@ end
595594
Bool,
596595
}
597596
mode = get_mode(gutils)
598-
world = enzyme_extract_world(LLVM.parent(position(B)))
599-
@assert world == enzyme_context(gutils).world
600-
entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world)
597+
ctx = enzyme_context(gutils)
598+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_augfwd, tt)
601599
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
602600

603601
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))
@@ -629,8 +627,6 @@ end
629627

630628
@register_rev function threadsfor_rev(B, orig, gutils, tape)
631629
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
632-
world = enzyme_extract_world(LLVM.parent(position(B)))
633-
@assert world == enzyme_context(gutils).world
634630
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
635631
return
636632
end
@@ -653,7 +649,8 @@ end
653649
Bool,
654650
}
655651
mode = get_mode(gutils)
656-
entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt, world)
652+
ctx = enzyme_context(gutils)
653+
entry = nested_codegen!(ctx, mode, mod, runtime_pfor_rev, tt)
657654
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
658655

659656
pval = const_ptrtoint(functions(mod)[sname], convert(LLVMType, Ptr{Cvoid}))

src/rules/typeunstablerules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ end
10361036
if legal
10371037
@assert legal
10381038
world = enzyme_extract_world(LLVM.parent(position(B)))
1039-
@assert world == enzyme_context(gutils).world
1039+
@assert world == enzyme_context(gutils).world
10401040
torun = !guaranteed_nonactive(TT, world)
10411041
else
10421042
torun = true

0 commit comments

Comments
 (0)