From 84aa45b55741f030ec394d1f4ecb3cdd42163f0c Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Mon, 16 Mar 2026 23:57:48 -0600 Subject: [PATCH 1/2] Enable GPU-aware halo exchanges via OpenACC directives This commit enables execution of halo exchanges on GPUs via OpenACC directives, if the MPAS atmosphere core has been built with an appropriate GPU-aware MPI distribution. Module mpas_halo is modified in the following ways to enable GPU-aware halo exchanges: - In the call to mpas_halo_exch_group_complete, OpenACC directives copy to device all the relevant fields and metadata that are required for the packing and unpacking loops later. - OpenACC directives are introduced around the packing and unpacking loops to perform the field to/from send/recv buffer operations on the device. The attach clauses introduced to the parallel constructs ensures that the device pointers are attached to the device targets at the start of the parallel region and detached at the end of the region. - The actual MPI_Isend and MPI_Irecv operations use GPU-aware MPI, by wrapping these calls within !$acc host_data constructs. Note: This commit introduces temporary host-device data movements in the atm_core_init routine around the two calls to exchange_halo_group. This is required just for this commit as all halo-exchanges occur on the device and fields not yet present on the device must be copied over to it before the halo exchanges and back to host after it. These copies will be removed in subsequent commits. --- .../dynamics/mpas_atm_time_integration.F | 34 --------- src/core_atmosphere/mpas_atm_core.F | 19 +++++ src/framework/mpas_halo.F | 76 +++++++++++++++++++ 3 files changed, 95 insertions(+), 34 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 19c9ed8978..77109da9bc 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -2115,9 +2115,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'scalars', scalars_1, 1) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - !$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p') - !$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_rk_integration_setup') @@ -2209,9 +2207,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'exner', exner) - !$acc update self(exner) call exchange_halo_group(domain, 'dynamics:exner') - !$acc update device(exner) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -2293,9 +2289,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! 
tend_u MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(tend, 'u', tend_u) - !$acc update self(tend_u) call exchange_halo_group(domain, 'dynamics:tend_u') - !$acc update device(tend_u) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('small_step_prep') @@ -2374,9 +2368,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rho_pp', rho_pp) - !$acc update self(rho_pp) call exchange_halo_group(domain, 'dynamics:rho_pp') - !$acc update device(rho_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_advance_acoustic_step') @@ -2401,9 +2393,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! This is the only communications needed during the acoustic steps because we solve for u on all edges of owned cells MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !$acc update self(rtheta_pp) call exchange_halo_group(domain, 'dynamics:rtheta_pp') - !$acc update device(rtheta_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! complete update of horizontal momentum by including 3d divergence damping at the end of the acoustic step @@ -2429,9 +2419,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(diag, 'rw_p', rw_p) call mpas_pool_get_array(diag, 'rho_pp', rho_pp) call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp') - !$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_recover_large_step_variables') @@ -2502,14 +2490,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'u', u, 2) - !$acc update self(u) ! u if (config_apply_lbcs) then call exchange_halo_group(domain, 'dynamics:u_123') else call exchange_halo_group(domain, 'dynamics:u_3') end if - !$acc update device(u) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! scalar advection: RK3 scheme of Skamarock and Gassmann (2011). @@ -2524,9 +2510,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2583,22 +2567,18 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'w', w, 2) call mpas_pool_get_array(diag, 'pv_edge', pv_edge) call mpas_pool_get_array(diag, 'rho_edge', rho_edge) - !$acc update self(w,pv_edge,rho_edge) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2], scalars[1,2] ! call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars') - !$acc update device(scalars_2) else ! ! 
Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2] ! call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge') end if - !$acc update device(w,pv_edge,rho_edge) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! set the zero-gradient condition on w for regional_MPAS @@ -2615,9 +2595,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! w halo values needs resetting after regional boundary update call mpas_pool_get_array(state, 'w', w, 2) - !$acc update self(w) call exchange_halo_group(domain, 'dynamics:w') - !$acc update device(w) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if ! end of regional_MPAS addition @@ -2633,9 +2611,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'theta_m', theta_m, 2) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - !$acc update self(theta_m,pressure_p,rtheta_p) call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p') - !$acc update device(theta_m,pressure_p,rtheta_p) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! @@ -2704,9 +2680,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! need to fill halo for horizontal filter call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2735,9 +2709,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (rk_step < 3) then MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if @@ -2867,9 +2839,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -5226,7 +5196,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$acc end parallel MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update self(scalars_old) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER @@ -5236,7 +5205,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update device(scalars_old) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') ! @@ -5633,7 +5601,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge ! 
MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update self(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER @@ -5643,7 +5610,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update device(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$acc parallel diff --git a/src/core_atmosphere/mpas_atm_core.F b/src/core_atmosphere/mpas_atm_core.F index 7a79527910..5da37f931b 100644 --- a/src/core_atmosphere/mpas_atm_core.F +++ b/src/core_atmosphere/mpas_atm_core.F @@ -60,6 +60,7 @@ function atm_core_init(domain, startTimeStamp) result(ierr) character (len=StrKIND), pointer :: initial_time1, initial_time2 type (MPAS_Time_Type) :: startTime + real (kind=RKIND), dimension(:,:), pointer :: u, ru, rw, pv_edge real (kind=RKIND), pointer :: nominalMinDc real (kind=RKIND), pointer :: config_len_disp real (kind=RKIND), pointer :: Time @@ -238,7 +239,14 @@ function atm_core_init(domain, startTimeStamp) result(ierr) startTime = mpas_get_clock_time(clock, MPAS_START_TIME, ierr) call mpas_get_time(startTime, dateTimeString=startTimeStamp) + ! This copy is temporarily needed to ensure that the field u is available on the device + ! prior to the GPU-direct halo exchange in the call to exchange_halo_group. It will be + ! removed in subsequent commits that introduce a namelist option to select whether halo exchanges + ! are performed on the device or not. + call mpas_pool_get_array(domain % blocklist % allFields, 'u', u, 1) + !$acc enter data copyin(u) call exchange_halo_group(domain, 'initialization:u') + !$acc exit data copyout(u) ! @@ -276,7 +284,18 @@ function atm_core_init(domain, startTimeStamp) result(ierr) block => block % next end do + ! This copy is temporarily needed to ensure that ru, rw, pv_edge are available on the device + ! prior to the GPU-direct halo exchange in the call to exchange_halo_group. It will be + ! removed in subsequent commits that introduce a namelist option to select whether halo exchanges + ! are performed on the device or not. + call mpas_pool_get_array(domain % blocklist % allFields, 'ru', ru, 1) + !$acc enter data copyin(ru) + call mpas_pool_get_array(domain % blocklist % allFields, 'rw', rw, 1) + !$acc enter data copyin(rw) + call mpas_pool_get_array(domain % blocklist % allFields, 'pv_edge', pv_edge) + !$acc enter data copyin(pv_edge) call exchange_halo_group(domain, 'initialization:pv_edge,ru,rw') + !$acc exit data copyout(pv_edge,ru,rw) call mpas_atm_diag_setup(domain % streamManager, domain % blocklist % configs, & domain % blocklist % structs, domain % clock, domain % dminfo) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 4ab8817c23..582afa7c34 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -280,6 +280,29 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) + ! Logic to return early if there are no neighbors to send to + if ( newGroup% nGroupSendNeighbors <=0 ) then + return + end if + + + ! Always copy in the main data member first + !$acc enter data copyin(newGroup) + ! 
Then the data in the members of the type + !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) + !$acc enter data copyin(newGroup % fields(:)) + do i = 1, newGroup % nFields + !$acc enter data copyin(newGroup % fields(i)) + !$acc enter data copyin(newGroup % fields(i) % nSendLists(:,:)) + !$acc enter data copyin(newGroup % fields(i) % packOffsets(:)) + !$acc enter data copyin(newGroup % fields(i) % sendListSrc(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % sendListDst(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % nRecvLists(:,:)) + !$acc enter data copyin(newGroup % fields(i) % unpackOffsets(:)) + !$acc enter data copyin(newGroup % fields(i) % recvListSrc(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % recvListDst(:,:,:)) + end do + end subroutine mpas_halo_exch_group_complete @@ -349,15 +372,26 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % fields(i) % compactHaloInfo) deallocate(cursor % fields(i) % compactSendLists) deallocate(cursor % fields(i) % compactRecvLists) + !$acc exit data delete(cursor % fields(i) % nSendLists(:,:)) deallocate(cursor % fields(i) % nSendLists) + !$acc exit data delete(cursor % fields(i) % sendListSrc(:,:,:)) deallocate(cursor % fields(i) % sendListSrc) + !$acc exit data delete(cursor % fields(i) % sendListDst(:,:,:)) deallocate(cursor % fields(i) % sendListDst) + !$acc exit data delete(cursor % fields(i) % packOffsets(:)) deallocate(cursor % fields(i) % packOffsets) + !$acc exit data delete(cursor % fields(i) % nRecvLists(:,:)) deallocate(cursor % fields(i) % nRecvLists) + !$acc exit data delete(cursor % fields(i) % recvListSrc(:,:,:)) deallocate(cursor % fields(i) % recvListSrc) + !$acc exit data delete(cursor % fields(i) % recvListDst(:,:,:)) deallocate(cursor % fields(i) % recvListDst) + !$acc exit data delete(cursor % fields(i) % unpackOffsets(:)) deallocate(cursor % fields(i) % unpackOffsets) + !$acc exit data delete(cursor % fields(i)) end do + ! Use finalize here in-case the copyins in ..._complete increment the reference counter + !$acc exit data finalize delete(cursor % fields(:)) deallocate(cursor % fields) deallocate(cursor % groupPackOffsets) deallocate(cursor % groupSendNeighbors) @@ -368,10 +402,14 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % groupToFieldRecvIdx) deallocate(cursor % groupRecvOffsets) deallocate(cursor % groupRecvCounts) + !$acc exit data delete(cursor % sendBuf(:)) deallocate(cursor % sendBuf) + !$acc exit data delete(cursor % recvBuf(:)) deallocate(cursor % recvBuf) deallocate(cursor % sendRequests) deallocate(cursor % recvRequests) + ! Finalize here as well, just in-case + !$acc exit data finalize delete(cursor) deallocate(cursor) end subroutine mpas_halo_exch_group_destroy @@ -577,6 +615,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) messageType=MPAS_LOG_CRIT) end if + ! Logic to return early if there are no neighbors to send to + if ( group% nGroupSendNeighbors <= 0 ) then + return + end if + ! ! Get the rank of this task and the MPI communicator to use from the first field in ! the group; all fields should be using the same communicator, so this should not @@ -598,9 +641,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 !TO DO: how do we determine appropriate type here? 
+ !$acc host_data use_device(group % recvBuf) call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) + !$acc end host_data else group % recvRequests(i) = MPI_REQUEST_NULL end if @@ -642,8 +687,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! + !$acc parallel default(present) attach(group % fields(i) % r1arr) + !$acc loop gang collapse(2) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos + !$acc loop vector do j = 1, maxNSendList if (j <= nSendLists(iHalo,iEndp)) then group % sendBuf(packOffsets(iEndp) + sendListDst(j,iHalo,iEndp)) = & @@ -652,6 +700,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end parallel ! ! Packing code for 2-d real-valued fields @@ -663,9 +712,16 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! + ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of + ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' + ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) + + !$acc parallel default(present) attach(group % fields(i) % r2arr) + !$acc loop gang collapse(3) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList + !$acc loop vector do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then group % sendBuf(packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1) = & @@ -675,6 +731,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end parallel ! ! Packing code for 3-d real-valued fields @@ -686,10 +743,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! + !$acc parallel default(present) attach(group % fields(i) % r3arr) + !$acc loop gang collapse(4) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList do i2 = 1, dim2 + !$acc loop vector do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then group % sendBuf(packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & @@ -701,6 +761,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end parallel end select end if @@ -714,9 +775,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 !TO DO: how do we determine appropriate type here? + !$acc host_data use_device(group % sendBuf) call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) + !$acc end host_data else group % sendRequests(i) = MPI_REQUEST_NULL end if @@ -771,7 +834,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! 
+ !$acc parallel default(present) attach(group % fields(i) % r1arr) + !$acc loop gang do iHalo = 1, nHalos + !$acc loop vector do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then group % fields(i) % r1arr(recvListDst(j,iHalo,iEndp)) = & @@ -779,6 +845,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end if end do end do + !$acc end parallel ! ! Unpacking code for 2-d real-valued fields @@ -787,8 +854,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! + !$acc parallel default(present) attach(group % fields(i) % r2arr) + !$acc loop gang do iHalo = 1, nHalos + !$acc loop worker do j = 1, maxNRecvList + !$acc loop vector do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then group % fields(i) % r2arr(i1, recvListDst(j,iHalo,iEndp)) = & @@ -797,6 +868,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end parallel ! ! Unpacking code for 3-d real-valued fields @@ -805,8 +877,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! + !$acc parallel default(present) attach(group % fields(i) % r3arr) + !$acc loop gang collapse(2) do iHalo = 1, nHalos do j = 1, maxNRecvList + !$acc loop vector collapse(2) do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then @@ -818,6 +893,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end parallel end select end if From b0ce63013ab05bd81c39a6488e0b610720084c64 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Tue, 17 Mar 2026 09:28:06 -0600 Subject: [PATCH 2/2] New namelist option to switch on or off GPU-aware MPI halo exchanges Introducing a new namelist option under development, config_gpu_aware_mpi, which will control whether the MPAS runs on GPUs will use GPU-aware MPI or perform a device<->host update of variables around the call to a purely CPU-based halo exchange. 
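For example, a GPU build running against a GPU-aware MPI library could enable the
option in namelist.atmosphere alongside the existing halo-exchange method setting
(the enclosing namelist record is omitted here as an illustration only; Registry.xml
defines where these options live and their default values):

    config_halo_exch_method = 'mpas_halo'
    config_gpu_aware_mpi = true
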
Note: This feature is not available to use when config_halo_exch_method is set to 'mpas_dmpar' --- src/core_atmosphere/Registry.xml | 5 ++ .../dynamics/mpas_atm_time_integration.F | 89 +++++++++++++------ src/core_atmosphere/mpas_atm_core.F | 19 ---- src/framework/mpas_dmpar.F | 9 +- src/framework/mpas_halo.F | 25 ++++-- src/framework/mpas_halo_interface.inc | 3 +- 6 files changed, 95 insertions(+), 55 deletions(-) diff --git a/src/core_atmosphere/Registry.xml b/src/core_atmosphere/Registry.xml index c389db81dd..abf6fc58fe 100644 --- a/src/core_atmosphere/Registry.xml +++ b/src/core_atmosphere/Registry.xml @@ -443,6 +443,11 @@ units="-" description="Method to use for exchanging halos" possible_values="`mpas_dmpar', `mpas_halo'"/> + + #ifdef MPAS_USE_MUSICA diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 77109da9bc..68329ef3b3 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -1918,6 +1918,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) logical, pointer :: config_scalar_advection logical, pointer :: config_positive_definite logical, pointer :: config_monotonic + logical, pointer :: config_gpu_aware_mpi character (len=StrKIND), pointer :: config_microp_scheme character (len=StrKIND), pointer :: config_convection_scheme @@ -1962,6 +1963,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_config(block % configs, 'config_scalar_advection', config_scalar_advection) call mpas_pool_get_config(block % configs, 'config_positive_definite', config_positive_definite) call mpas_pool_get_config(block % configs, 'config_monotonic', config_monotonic) + call mpas_pool_get_config(block % configs, 'config_gpu_aware_mpi', config_gpu_aware_mpi) call mpas_pool_get_config(block % configs, 'config_IAU_option', config_IAU_option) ! config variables for dynamics-transport splitting, WCS 18 November 2014 call mpas_pool_get_config(block % configs, 'config_split_dynamics_transport', config_split_dynamics_transport) @@ -2115,7 +2117,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'scalars', scalars_1, 1) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p') + !$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p', config_gpu_aware_mpi) + !$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_rk_integration_setup') @@ -2207,7 +2211,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'exner', exner) - call exchange_halo_group(domain, 'dynamics:exner') + !$acc update self(exner) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:exner', config_gpu_aware_mpi) + !$acc update device(exner) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
@@ -2289,7 +2295,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! tend_u MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(tend, 'u', tend_u) - call exchange_halo_group(domain, 'dynamics:tend_u') + !$acc update self(tend_u) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:tend_u', config_gpu_aware_mpi) + !$acc update device(tend_u) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('small_step_prep') @@ -2368,7 +2376,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rho_pp', rho_pp) - call exchange_halo_group(domain, 'dynamics:rho_pp') + !$acc update self(rho_pp) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rho_pp', config_gpu_aware_mpi) + !$acc update device(rho_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_advance_acoustic_step') @@ -2393,7 +2403,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! This is the only communications needed during the acoustic steps because we solve for u on all edges of owned cells MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rtheta_pp') + !$acc update self(rtheta_pp) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rtheta_pp', config_gpu_aware_mpi) + !$acc update device(rtheta_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! complete update of horizontal momentum by including 3d divergence damping at the end of the acoustic step @@ -2419,7 +2431,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(diag, 'rw_p', rw_p) call mpas_pool_get_array(diag, 'rho_pp', rho_pp) call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp') + !$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp', config_gpu_aware_mpi) + !$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_recover_large_step_variables') @@ -2490,12 +2504,14 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'u', u, 2) + !$acc update self(u) if (.not. config_gpu_aware_mpi) ! u if (config_apply_lbcs) then - call exchange_halo_group(domain, 'dynamics:u_123') + call exchange_halo_group(domain, 'dynamics:u_123', config_gpu_aware_mpi) else - call exchange_halo_group(domain, 'dynamics:u_3') + call exchange_halo_group(domain, 'dynamics:u_3', config_gpu_aware_mpi) end if + !$acc update device(u) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! scalar advection: RK3 scheme of Skamarock and Gassmann (2011). @@ -2504,13 +2520,15 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (config_scalar_advection .and. (.not. 
config_split_dynamics_transport) ) then call advance_scalars('scalars', domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) if (config_apply_lbcs) then ! adjust boundary tendencies for regional_MPAS scalar transport MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2567,18 +2585,22 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'w', w, 2) call mpas_pool_get_array(diag, 'pv_edge', pv_edge) call mpas_pool_get_array(diag, 'rho_edge', rho_edge) + !$acc update self(w,pv_edge,rho_edge) if (.not. config_gpu_aware_mpi) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2], scalars[1,2] ! call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars') + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) else ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2] ! - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge') + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge', config_gpu_aware_mpi) end if + !$acc update device(w,pv_edge,rho_edge) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! set the zero-gradient condition on w for regional_MPAS @@ -2595,7 +2617,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! w halo values needs resetting after regional boundary update call mpas_pool_get_array(state, 'w', w, 2) - call exchange_halo_group(domain, 'dynamics:w') + !$acc update self(w) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:w', config_gpu_aware_mpi) + !$acc update device(w) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if ! end of regional_MPAS addition @@ -2611,7 +2635,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'theta_m', theta_m, 2) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p') + !$acc update self(theta_m,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p', config_gpu_aware_mpi) + !$acc update device(theta_m,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! 
@@ -2673,14 +2699,16 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call advance_scalars('scalars', domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) if (config_apply_lbcs) then ! adjust boundary tendencies for regional_MPAS scalar transport MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! need to fill halo for horizontal filter call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2709,7 +2737,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (rk_step < 3) then MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if @@ -2839,7 +2869,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2893,7 +2925,7 @@ end subroutine atm_srk3 ! !----------------------------------------------------------------------- subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) implicit none @@ -2906,6 +2938,7 @@ subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_mono logical, intent(in) :: config_positive_definite integer, intent(in) :: config_time_integration_order logical, intent(in) :: config_split_dynamics_transport + logical, intent(in) :: config_gpu_aware_mpi procedure (halo_exchange_routine) :: exchange_halo_group ! 
Local variables @@ -3037,7 +3070,7 @@ subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_mono edgeThreadStart(thread), edgeThreadEnd(thread), & cellSolveThreadStart(thread), cellSolveThreadEnd(thread), & scalar_old_arr, scalar_new_arr, s_max_arr, s_min_arr, wdtn_arr, & - flux_array, flux_upwind_tmp_arr, flux_tmp_arr, & + flux_array, flux_upwind_tmp_arr, flux_tmp_arr, config_gpu_aware_mpi, & exchange_halo_group, & advance_density=config_split_dynamics_transport, rho_zz_int=rho_zz_int) end if @@ -4937,7 +4970,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, cellStart, cellEnd, edgeStart, edgeEnd, & cellSolveStart, cellSolveEnd, & scalar_old, scalar_new, s_max, s_min, wdtn, flux_arr, & - flux_upwind_tmp, flux_tmp, exchange_halo_group, advance_density, rho_zz_int) + flux_upwind_tmp, flux_tmp, config_gpu_aware_mpi, exchange_halo_group, advance_density, rho_zz_int) implicit none @@ -4958,6 +4991,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, real (kind=RKIND), dimension(:,:), intent(inout) :: wdtn real (kind=RKIND), dimension(:,:), intent(inout) :: flux_arr real (kind=RKIND), dimension(:,:), intent(inout) :: flux_upwind_tmp, flux_tmp + logical, intent(in) :: config_gpu_aware_mpi procedure (halo_exchange_routine) :: exchange_halo_group logical, intent(in), optional :: advance_density real (kind=RKIND), dimension(:,:), intent(inout), optional :: rho_zz_int @@ -5036,7 +5070,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, edgesOnCell, edgesOnCell_sign, nEdgesOnCell, fnm, fnp, rdnw, nAdvCellsForEdge, & advCellsForEdge, adv_coefs, adv_coefs_3rd, scalar_old, scalar_new, s_max, s_min, & wdtn, scale_arr, flux_arr, flux_upwind_tmp, flux_tmp, & - bdyMaskCell, bdyMaskEdge, & + bdyMaskCell, bdyMaskEdge, config_gpu_aware_mpi, & exchange_halo_group, advance_density, rho_zz_int) call mpas_deallocate_scratch_field(scale) @@ -5084,7 +5118,7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge edgesOnCell, edgesOnCell_sign, nEdgesOnCell, fnm, fnp, rdnw, nAdvCellsForEdge, & advCellsForEdge, adv_coefs, adv_coefs_3rd, scalar_old, scalar_new, s_max, s_min, & wdtn, scale_arr, flux_arr, flux_upwind_tmp, flux_tmp, & - bdyMaskCell, bdyMaskEdge, & + bdyMaskCell, bdyMaskEdge, config_gpu_aware_mpi, & exchange_halo_group, advance_density, rho_zz_int) use mpas_atm_dimensions, only : nVertLevels @@ -5100,6 +5134,7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge real (kind=RKIND), intent(in) :: dt integer, intent(in) :: cellStart, cellEnd, edgeStart, edgeEnd integer, intent(in) :: cellSolveStart, cellSolveEnd + logical, intent(in) :: config_gpu_aware_mpi procedure (halo_exchange_routine) :: exchange_halo_group logical, intent(in), optional :: advance_density real (kind=RKIND), dimension(:,:), intent(inout), optional :: rho_zz_int @@ -5196,15 +5231,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$acc end parallel MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') + !$acc update self(scalars_old) if (.not. 
config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER !$OMP MASTER - call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old') + call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old', config_gpu_aware_mpi) !$OMP END MASTER !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') + !$acc update device(scalars_old) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') ! @@ -5601,15 +5638,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge ! MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') + !$acc update self(scale_arr) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER !$OMP MASTER - call exchange_halo_group(block % domain, 'dynamics:scale') + call exchange_halo_group(block % domain, 'dynamics:scale', config_gpu_aware_mpi) !$OMP END MASTER !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') + !$acc update device(scale_arr) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$acc parallel diff --git a/src/core_atmosphere/mpas_atm_core.F b/src/core_atmosphere/mpas_atm_core.F index 5da37f931b..7a79527910 100644 --- a/src/core_atmosphere/mpas_atm_core.F +++ b/src/core_atmosphere/mpas_atm_core.F @@ -60,7 +60,6 @@ function atm_core_init(domain, startTimeStamp) result(ierr) character (len=StrKIND), pointer :: initial_time1, initial_time2 type (MPAS_Time_Type) :: startTime - real (kind=RKIND), dimension(:,:), pointer :: u, ru, rw, pv_edge real (kind=RKIND), pointer :: nominalMinDc real (kind=RKIND), pointer :: config_len_disp real (kind=RKIND), pointer :: Time @@ -239,14 +238,7 @@ function atm_core_init(domain, startTimeStamp) result(ierr) startTime = mpas_get_clock_time(clock, MPAS_START_TIME, ierr) call mpas_get_time(startTime, dateTimeString=startTimeStamp) - ! This copy is temporarily needed to ensure that the field u is available on the device - ! prior to the GPU-direct halo exchange in the call to exchange_halo_group. It will be - ! removed in subsequent commits that introduce a namelist option to select whether halo exchanges - ! are performed on the device or not. - call mpas_pool_get_array(domain % blocklist % allFields, 'u', u, 1) - !$acc enter data copyin(u) call exchange_halo_group(domain, 'initialization:u') - !$acc exit data copyout(u) ! @@ -284,18 +276,7 @@ function atm_core_init(domain, startTimeStamp) result(ierr) block => block % next end do - ! This copy is temporarily needed to ensure that ru, rw, pv_edge are available on the device - ! prior to the GPU-direct halo exchange in the call to exchange_halo_group. It will be - ! removed in subsequent commits that introduce a namelist option to select whether halo exchanges - ! are performed on the device or not. 
- call mpas_pool_get_array(domain % blocklist % allFields, 'ru', ru, 1) - !$acc enter data copyin(ru) - call mpas_pool_get_array(domain % blocklist % allFields, 'rw', rw, 1) - !$acc enter data copyin(rw) - call mpas_pool_get_array(domain % blocklist % allFields, 'pv_edge', pv_edge) - !$acc enter data copyin(pv_edge) call exchange_halo_group(domain, 'initialization:pv_edge,ru,rw') - !$acc exit data copyout(pv_edge,ru,rw) call mpas_atm_diag_setup(domain % streamManager, domain % blocklist % configs, & domain % blocklist % structs, domain % clock, domain % dminfo) diff --git a/src/framework/mpas_dmpar.F b/src/framework/mpas_dmpar.F index a107412d97..e18cc4a310 100644 --- a/src/framework/mpas_dmpar.F +++ b/src/framework/mpas_dmpar.F @@ -7450,10 +7450,11 @@ end subroutine mpas_dmpar_exch_group_end_halo_exch!}}} !> exchange is complete. ! !----------------------------------------------------------------------- - subroutine mpas_dmpar_exch_group_full_halo_exch(domain, groupName, iErr)!{{{ + subroutine mpas_dmpar_exch_group_full_halo_exch(domain, groupName, withGPUAwareMPI, iErr)!{{{ type (domain_type), intent(inout) :: domain character (len=*), intent(in) :: groupName + logical, optional, intent(in) :: withGPUAwareMPI integer, optional, intent(out) :: iErr type (mpas_exchange_group), pointer :: exchGroupPtr @@ -7463,6 +7464,12 @@ subroutine mpas_dmpar_exch_group_full_halo_exch(domain, groupName, iErr)!{{{ iErr = MPAS_DMPAR_NOERR end if + if (present(withGPUAwareMPI)) then + if (withGPUAwareMPI) then + call mpas_log_write(' GPU-aware MPI not implemented in this module', MPAS_LOG_CRIT) + end if + end if + nLen = len_trim(groupName) DMPAR_DEBUG_WRITE(' -- Trying to perform a full exchange for group ' // trim(groupName)) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 582afa7c34..c596ab08e0 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -533,7 +533,7 @@ end subroutine mpas_halo_exch_group_add_field !> exchange group. ! !----------------------------------------------------------------------- - subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) + subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMPI, iErr) #ifdef MPAS_USE_MPI_F08 use mpi_f08, only : MPI_Datatype, MPI_Comm @@ -565,6 +565,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Arguments type (domain_type), intent(inout) :: domain character (len=*), intent(in) :: groupName + logical, optional, intent(in) :: withGPUAwareMPI integer, optional, intent(out) :: iErr ! Local variables @@ -580,6 +581,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) integer :: comm #endif integer :: mpi_ierr + logical:: useGPUAwareMPI type (mpas_halo_group), pointer :: group integer, dimension(:), pointer :: compactHaloInfo integer, dimension(:), pointer :: compactSendLists @@ -598,6 +600,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) iErr = 0 end if + useGPUAwareMPI = .false. + if (present(withGPUAwareMPI)) then + useGPUAwareMPI = withGPUAwareMPI + end if + ! ! Find this halo exhange group in the list of groups ! @@ -641,7 +648,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 !TO DO: how do we determine appropriate type here? 
- !$acc host_data use_device(group % recvBuf) + !$acc host_data use_device(group % recvBuf) if(useGPUAwareMPI) call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) @@ -687,7 +694,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! - !$acc parallel default(present) attach(group % fields(i) % r1arr) + !$acc parallel default(present) attach(group % fields(i) % r1arr) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -716,7 +723,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - !$acc parallel default(present) attach(group % fields(i) % r2arr) + !$acc parallel default(present) attach(group % fields(i) % r2arr) if(useGPUAwareMPI) !$acc loop gang collapse(3) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -743,7 +750,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! - !$acc parallel default(present) attach(group % fields(i) % r3arr) + !$acc parallel default(present) attach(group % fields(i) % r3arr) if(useGPUAwareMPI) !$acc loop gang collapse(4) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -775,7 +782,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 !TO DO: how do we determine appropriate type here? - !$acc host_data use_device(group % sendBuf) + !$acc host_data use_device(group % sendBuf) if(useGPUAwareMPI) call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) @@ -834,7 +841,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc parallel default(present) attach(group % fields(i) % r1arr) + !$acc parallel default(present) attach(group % fields(i) % r1arr) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop vector @@ -854,7 +861,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc parallel default(present) attach(group % fields(i) % r2arr) + !$acc parallel default(present) attach(group % fields(i) % r2arr) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop worker @@ -877,7 +884,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc parallel default(present) attach(group % fields(i) % r3arr) + !$acc parallel default(present) attach(group % fields(i) % r3arr) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iHalo = 1, nHalos do j = 1, maxNRecvList diff --git a/src/framework/mpas_halo_interface.inc b/src/framework/mpas_halo_interface.inc index 8f0934fbb0..b1dd9a9c99 100644 --- a/src/framework/mpas_halo_interface.inc +++ b/src/framework/mpas_halo_interface.inc @@ -3,12 +3,13 @@ ! in a named group ! 
abstract interface - subroutine halo_exchange_routine(domain, halo_group, ierr) + subroutine halo_exchange_routine(domain, halo_group, withGPUAwareMPI, ierr) use mpas_derived_types, only : domain_type type (domain_type), intent(inout) :: domain character(len=*), intent(in) :: halo_group + logical, intent(in), optional :: withGPUAwareMPI integer, intent(out), optional :: ierr end subroutine halo_exchange_routine
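
For any one exchanged field, the pattern these two patches converge on can be
summarized with the short sketch below; 'field' is a placeholder for an exchanged
array, and declarations and error handling are omitted:

    ! Caller side (e.g. atm_srk3): stage through the host only when GPU-aware
    ! MPI is disabled; otherwise the exchange operates directly on device data.
    !$acc update self(field) if (.not. config_gpu_aware_mpi)
    call exchange_halo_group(domain, 'dynamics:field', config_gpu_aware_mpi)
    !$acc update device(field) if (.not. config_gpu_aware_mpi)

    ! Inside mpas_halo: when requested, hand the device address of the packed
    ! send buffer straight to MPI instead of a host copy.
    !$acc host_data use_device(group % sendBuf) if(useGPUAwareMPI)
    call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
                   group % groupSendNeighbors(i), rank, comm, &
                   group % sendRequests(i), mpi_ierr)
    !$acc end host_data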