@@ -484,12 +484,13 @@ include("llvm/transforms.jl")
484484include (" llvm/passes.jl" )
485485include (" typeutils/make_zero.jl" )
486486
487- function nested_codegen! (mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ), world :: UInt )
488- funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, world)
489- nested_codegen! (mode, mod, funcspec, world )
487+ function nested_codegen! (ctx :: EnzymeContext , mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ))
488+ funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, ctx . world)
489+ nested_codegen! (ctx, mode, mod, funcspec)
490490end
491491
492492function prepare_llvm (interp, mod:: LLVM.Module , job, meta)
493+ # TODO : remove enzymejl_world
493494 for f in functions (mod)
494495 attributes = function_attributes (f)
495496 push! (attributes, StringAttribute (" enzymejl_world" , string (job. world)))
@@ -1228,11 +1229,12 @@ const DumpPreNestedOpt = Ref(false)
12281229const DumpPostNestedOpt = Ref (false )
12291230
12301231function nested_codegen! (
1232+ ctx:: EnzymeContext ,
12311233 mode:: API.CDerivativeMode ,
12321234 mod:: LLVM.Module ,
12331235 funcspec:: Core.MethodInstance ,
1234- world:: UInt ,
12351236)
1237+ world = ctx. world
12361238 # TODO : Put a cache here index on `mod` and f->tt
12371239
12381240
@@ -1248,6 +1250,7 @@ function nested_codegen!(
12481250 GPUCompiler. prepare_job! (job)
12491251 otherMod, meta = GPUCompiler. emit_llvm (job)
12501252
1253+ # TODO : interp should be cached since it contains internal caches
12511254 interp = GPUCompiler. get_interpreter (job)
12521255 prepare_llvm (interp, otherMod, job, meta)
12531256
@@ -2389,6 +2392,7 @@ const DumpPostEnzyme = Ref(false)
23892392const DumpPostWrap = Ref (false )
23902393
23912394function enzyme! (
2395+ enzyme_context:: EnzymeContext ,
23922396 job:: CompilerJob ,
23932397 interp,
23942398 mod:: LLVM.Module ,
@@ -2504,7 +2508,6 @@ function enzyme!(
25042508 convert (API. CDIFFE_TYPE, rt)
25052509 end
25062510
2507- enzyme_context = EnzymeContext (job. world)
25082511 GC. @preserve enzyme_context begin
25092512 LLVM. @dispose logic = Logic (enzyme_context) begin
25102513
@@ -2574,6 +2577,7 @@ function enzyme!(
25742577
25752578 if wrap
25762579 augmented_primalf = create_abi_wrapper (
2580+ enzyme_context,
25772581 augmented_primalf,
25782582 TT,
25792583 rt,
@@ -2583,7 +2587,6 @@ function enzyme!(
25832587 width,
25842588 returnPrimal,
25852589 shadow_init,
2586- world,
25872590 interp,
25882591 runtimeActivity,
25892592 )
@@ -2616,6 +2619,7 @@ function enzyme!(
26162619 ) #= atomicAdd=#
26172620 if wrap
26182621 adjointf = create_abi_wrapper (
2622+ enzyme_context,
26192623 adjointf,
26202624 TT,
26212625 rt,
@@ -2625,7 +2629,6 @@ function enzyme!(
26252629 width,
26262630 false ,
26272631 shadow_init,
2628- world,
26292632 interp,
26302633 runtimeActivity
26312634 ) #= returnPrimal=#
@@ -2657,6 +2660,7 @@ function enzyme!(
26572660 augmented_primalf = nothing
26582661 if wrap
26592662 adjointf = create_abi_wrapper (
2663+ enzyme_context,
26602664 adjointf,
26612665 TT,
26622666 rt,
@@ -2666,7 +2670,6 @@ function enzyme!(
26662670 width,
26672671 returnPrimal,
26682672 shadow_init,
2669- world,
26702673 interp,
26712674 runtimeActivity
26722675 )
@@ -2702,6 +2705,7 @@ function enzyme!(
27022705 if wrap
27032706 pf = adjointf
27042707 adjointf = create_abi_wrapper (
2708+ enzyme_context,
27052709 adjointf,
27062710 TT,
27072711 rt,
@@ -2711,7 +2715,6 @@ function enzyme!(
27112715 width,
27122716 returnPrimal,
27132717 shadow_init,
2714- world,
27152718 interp,
27162719 runtimeActivity
27172720 )
@@ -2786,6 +2789,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27862789end
27872790
27882791function create_abi_wrapper (
2792+ ctx:: EnzymeContext ,
27892793 enzymefn:: LLVM.Function ,
27902794 @nospecialize (TT:: Type ),
27912795 @nospecialize (rettype:: Type ),
@@ -2795,10 +2799,10 @@ function create_abi_wrapper(
27952799 width:: Int ,
27962800 returnPrimal:: Bool ,
27972801 shadow_init:: Bool ,
2798- world:: UInt ,
27992802 interp,
28002803 runtime_activity:: Bool
28012804)
2805+ world = ctx. world
28022806 is_adjoint = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModeCombined
28032807 is_split = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModePrimal
28042808 needs_tape = Mode == API. DEM_ReverseModeGradient
@@ -3081,6 +3085,7 @@ function create_abi_wrapper(
30813085 realparms = LLVM. Value[]
30823086 i = 1
30833087
3088+ # TODO (vchuravy): remove
30843089 for attr in collect (function_attributes (enzymefn))
30853090 if kind (attr) == " enzymejl_world"
30863091 push! (function_attributes (llvm_f), attr)
@@ -3225,7 +3230,7 @@ function create_abi_wrapper(
32253230 elseif T <: BatchDuplicatedFunc
32263231 Func = get_func (T)
32273232 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3228- llvmf = nested_codegen! (Mode, mod, funcspec, world )
3233+ llvmf = nested_codegen! (ctx, Mode, mod, funcspec)
32293234 push! (function_attributes (llvmf), EnumAttribute (" alwaysinline" , 0 ))
32303235 Func_RT = return_type (interp, funcspec)
32313236 @assert Func_RT == NTuple{width,T′}
@@ -5089,15 +5094,17 @@ end
50895094 end
50905095 end
50915096
5097+ ctx = EnzymeContext (job. world)
50925098 if params. run_enzyme
50935099 # Generate the adjoint
50945100 memcpy_alloca_to_loadstore (mod)
50955101 force_recompute! (mod)
50965102 API. EnzymeDetectReadonlyOrThrow (mod)
50975103
50985104 adjointf, augmented_primalf, TapeType = enzyme! (
5105+ ctx,
50995106 job,
5100- interp,
5107+ interp,
51015108 mod,
51025109 primalf,
51035110 TT,
@@ -5195,7 +5202,7 @@ end
51955202 fname = String (name) * pf
51965203 if haskey (functions (mod), fname)
51975204 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job. world)
5198- llvmf = nested_codegen! (mode, mod, funcspec, job . world )
5205+ llvmf = nested_codegen! (ctx, mode, mod, funcspec)
51995206 push! (function_attributes (llvmf), StringAttribute (" implements" , fname))
52005207 end
52015208 end
0 commit comments