@@ -486,12 +486,13 @@ include("llvm/transforms.jl")
486486include (" llvm/passes.jl" )
487487include (" typeutils/make_zero.jl" )
488488
489- function nested_codegen! (mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ), world :: UInt )
490- funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, world)
491- nested_codegen! (mode, mod, funcspec, world )
489+ function nested_codegen! (ctx :: EnzymeContext , mode:: API.CDerivativeMode , mod:: LLVM.Module , @nospecialize (f), @nospecialize (tt:: Type ))
490+ funcspec = my_methodinstance (mode == API. DEM_ForwardMode ? Forward : Reverse, typeof (f), tt, ctx . world)
491+ nested_codegen! (ctx, mode, mod, funcspec)
492492end
493493
494494function prepare_llvm (interp, mod:: LLVM.Module , job, meta)
495+ # TODO : remove enzymejl_world
495496 for f in functions (mod)
496497 attributes = function_attributes (f)
497498 push! (attributes, StringAttribute (" enzymejl_world" , string (job. world)))
@@ -1234,11 +1235,12 @@ const DumpPreNestedOpt = Ref(false)
12341235const DumpPostNestedOpt = Ref (false )
12351236
12361237function nested_codegen! (
1238+ ctx:: EnzymeContext ,
12371239 mode:: API.CDerivativeMode ,
12381240 mod:: LLVM.Module ,
12391241 funcspec:: Core.MethodInstance ,
1240- world:: UInt ,
12411242)
1243+ world = ctx. world
12421244 # TODO : Put a cache here index on `mod` and f->tt
12431245
12441246
@@ -1254,6 +1256,7 @@ function nested_codegen!(
12541256 GPUCompiler. prepare_job! (job)
12551257 otherMod, meta = GPUCompiler. emit_llvm (job)
12561258
1259+ # TODO : interp should be cached since it contains internal caches
12571260 interp = GPUCompiler. get_interpreter (job)
12581261 prepare_llvm (interp, otherMod, job, meta)
12591262
@@ -2398,6 +2401,7 @@ const DumpPostEnzyme = Ref(false)
23982401const DumpPostWrap = Ref (false )
23992402
24002403function enzyme! (
2404+ enzyme_context:: EnzymeContext ,
24012405 job:: CompilerJob ,
24022406 interp,
24032407 mod:: LLVM.Module ,
@@ -2513,7 +2517,6 @@ function enzyme!(
25132517 convert (API. CDIFFE_TYPE, rt)
25142518 end
25152519
2516- enzyme_context = EnzymeContext (job. world)
25172520 GC. @preserve enzyme_context begin
25182521 LLVM. @dispose logic = Logic (enzyme_context) begin
25192522
@@ -2583,6 +2586,7 @@ function enzyme!(
25832586
25842587 if wrap
25852588 augmented_primalf = create_abi_wrapper (
2589+ enzyme_context,
25862590 augmented_primalf,
25872591 TT,
25882592 rt,
@@ -2592,7 +2596,6 @@ function enzyme!(
25922596 width,
25932597 returnPrimal,
25942598 shadow_init,
2595- world,
25962599 interp,
25972600 runtimeActivity,
25982601 )
@@ -2625,6 +2628,7 @@ function enzyme!(
26252628 ) #= atomicAdd=#
26262629 if wrap
26272630 adjointf = create_abi_wrapper (
2631+ enzyme_context,
26282632 adjointf,
26292633 TT,
26302634 rt,
@@ -2634,7 +2638,6 @@ function enzyme!(
26342638 width,
26352639 false ,
26362640 shadow_init,
2637- world,
26382641 interp,
26392642 runtimeActivity
26402643 ) #= returnPrimal=#
@@ -2666,6 +2669,7 @@ function enzyme!(
26662669 augmented_primalf = nothing
26672670 if wrap
26682671 adjointf = create_abi_wrapper (
2672+ enzyme_context,
26692673 adjointf,
26702674 TT,
26712675 rt,
@@ -2675,7 +2679,6 @@ function enzyme!(
26752679 width,
26762680 returnPrimal,
26772681 shadow_init,
2678- world,
26792682 interp,
26802683 runtimeActivity
26812684 )
@@ -2711,6 +2714,7 @@ function enzyme!(
27112714 if wrap
27122715 pf = adjointf
27132716 adjointf = create_abi_wrapper (
2717+ enzyme_context,
27142718 adjointf,
27152719 TT,
27162720 rt,
@@ -2720,7 +2724,6 @@ function enzyme!(
27202724 width,
27212725 returnPrimal,
27222726 shadow_init,
2723- world,
27242727 interp,
27252728 runtimeActivity
27262729 )
@@ -2792,6 +2795,7 @@ function set_subprogram!(f::LLVM.Function, sp)
27922795end
27932796
27942797function create_abi_wrapper (
2798+ ctx:: EnzymeContext ,
27952799 enzymefn:: LLVM.Function ,
27962800 @nospecialize (TT:: Type ),
27972801 @nospecialize (rettype:: Type ),
@@ -2801,10 +2805,10 @@ function create_abi_wrapper(
28012805 width:: Int ,
28022806 returnPrimal:: Bool ,
28032807 shadow_init:: Bool ,
2804- world:: UInt ,
28052808 interp,
28062809 runtime_activity:: Bool
28072810)
2811+ world = ctx. world
28082812 is_adjoint = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModeCombined
28092813 is_split = Mode == API. DEM_ReverseModeGradient || Mode == API. DEM_ReverseModePrimal
28102814 needs_tape = Mode == API. DEM_ReverseModeGradient
@@ -3087,6 +3091,7 @@ function create_abi_wrapper(
30873091 realparms = LLVM. Value[]
30883092 i = 1
30893093
3094+ # TODO (vchuravy): remove
30903095 for attr in collect (function_attributes (enzymefn))
30913096 if kind (attr) == " enzymejl_world"
30923097 push! (function_attributes (llvm_f), attr)
@@ -3231,7 +3236,7 @@ function create_abi_wrapper(
32313236 elseif T <: BatchDuplicatedFunc
32323237 Func = get_func (T)
32333238 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world)
3234- llvmf = nested_codegen! (Mode, mod, funcspec, world )
3239+ llvmf = nested_codegen! (ctx, Mode, mod, funcspec)
32353240 push! (function_attributes (llvmf), EnumAttribute (" alwaysinline" , 0 ))
32363241 Func_RT = return_type (interp, funcspec)
32373242 @assert Func_RT == NTuple{width,T′}
@@ -5102,15 +5107,17 @@ end
51025107 end
51035108 end
51045109
5110+ ctx = EnzymeContext (job. world)
51055111 if params. run_enzyme
51065112 # Generate the adjoint
51075113 memcpy_alloca_to_loadstore (mod)
51085114 force_recompute! (mod)
51095115 API. EnzymeDetectReadonlyOrThrow (mod)
51105116
51115117 adjointf, augmented_primalf, TapeType = enzyme! (
5118+ ctx,
51125119 job,
5113- interp,
5120+ interp,
51145121 mod,
51155122 primalf,
51165123 TT,
@@ -5209,7 +5216,7 @@ end
52095216 fname = String (name) * pf
52105217 if haskey (functions (mod), fname)
52115218 funcspec = my_methodinstance (Mode == API. DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, job. world)
5212- llvmf = nested_codegen! (mode, mod, funcspec, job . world )
5219+ llvmf = nested_codegen! (ctx, mode, mod, funcspec)
52135220 push! (function_attributes (llvmf), StringAttribute (" implements" , fname))
52145221 end
52155222 end
0 commit comments