diff --git a/backends/cuda-ref/ceed-cuda-vector.c b/backends/cuda-ref/ceed-cuda-vector.c index bb373dd037..a6dcb2462c 100644 --- a/backends/cuda-ref/ceed-cuda-vector.c +++ b/backends/cuda-ref/ceed-cuda-vector.c @@ -13,6 +13,30 @@ #include #include "ceed-cuda-ref.h" + +//------------------------------------------------------------------------------ +// Check if host/device sync is needed +//------------------------------------------------------------------------------ +static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, + CeedMemType mem_type, bool *need_sync) { + int ierr; + CeedVector_Cuda *impl; + ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); + + bool has_valid_array = false; + ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); + switch (mem_type) { + case CEED_MEM_HOST: + *need_sync = has_valid_array && !impl->h_array; + break; + case CEED_MEM_DEVICE: + *need_sync = has_valid_array && !impl->d_array; + break; + } + + return CEED_ERROR_SUCCESS; +} + //------------------------------------------------------------------------------ // Sync host to device //------------------------------------------------------------------------------ @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) { //------------------------------------------------------------------------------ // Sync arrays //------------------------------------------------------------------------------ -static inline int CeedVectorSync_Cuda(const CeedVector vec, - CeedMemType mem_type) { +static int CeedVectorSyncArray_Cuda(const CeedVector vec, + CeedMemType mem_type) { + int ierr; + // Check whether device/host sync is needed + bool need_sync = false; + ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); + CeedChkBackend(ierr); + if (!need_sync) + return CEED_ERROR_SUCCESS; + switch (mem_type) { case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec); case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec); @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, return CEED_ERROR_SUCCESS; } -//------------------------------------------------------------------------------ -// Check if is any array of given type -//------------------------------------------------------------------------------ -static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, - CeedMemType mem_type, bool *need_sync) { - int ierr; - CeedVector_Cuda *impl; - ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); - - bool has_valid_array = false; - ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); - switch (mem_type) { - case CEED_MEM_HOST: - *need_sync = has_valid_array && !impl->h_array; - break; - case CEED_MEM_DEVICE: - *need_sync = has_valid_array && !impl->d_array; - break; - } - - return CEED_ERROR_SUCCESS; -} - //------------------------------------------------------------------------------ // Set array from host //------------------------------------------------------------------------------ @@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); // Sync array to requested mem_type - bool need_sync = false; - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); - if (need_sync) { - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); - } + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); // Update pointer switch (mem_type) { @@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, CeedVector_Cuda *impl; ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); - bool need_sync = false, has_array_of_type = true; - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); - ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type); - CeedChkBackend(ierr); - if (need_sync) { - // Sync array to requested mem_type - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); - } + // Sync array to requested mem_type + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); // Update pointer switch (mem_type) { @@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) { ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())(CeedVectorSetValue_Cuda)); CeedChkBackend(ierr); + ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", + CeedVectorSyncArray_Cuda); CeedChkBackend(ierr); ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda); CeedChkBackend(ierr); ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", diff --git a/backends/hip-ref/ceed-hip-ref-vector.c b/backends/hip-ref/ceed-hip-ref-vector.c index b8371225cb..49c6494025 100644 --- a/backends/hip-ref/ceed-hip-ref-vector.c +++ b/backends/hip-ref/ceed-hip-ref-vector.c @@ -13,6 +13,30 @@ #include #include "ceed-hip-ref.h" + +//------------------------------------------------------------------------------ +// Check if host/device sync is needed +//------------------------------------------------------------------------------ +static inline int CeedVectorNeedSync_Hip(const CeedVector vec, + CeedMemType mem_type, bool *need_sync) { + int ierr; + CeedVector_Hip *impl; + ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); + + bool has_valid_array = false; + ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); + switch (mem_type) { + case CEED_MEM_HOST: + *need_sync = has_valid_array && !impl->h_array; + break; + case CEED_MEM_DEVICE: + *need_sync = has_valid_array && !impl->d_array; + break; + } + + return CEED_ERROR_SUCCESS; +} + //------------------------------------------------------------------------------ // Sync host to device //------------------------------------------------------------------------------ @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) { //------------------------------------------------------------------------------ // Sync arrays //------------------------------------------------------------------------------ -static inline int CeedVectorSync_Hip(const CeedVector vec, - CeedMemType mem_type) { +static int CeedVectorSyncArray_Hip(const CeedVector vec, + CeedMemType mem_type) { + int ierr; + // Check whether device/host sync is needed + bool need_sync = false; + ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); + CeedChkBackend(ierr); + if (!need_sync) + return CEED_ERROR_SUCCESS; + switch (mem_type) { case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec); case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec); @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, return CEED_ERROR_SUCCESS; } -//------------------------------------------------------------------------------ -// Sync array of given type -//------------------------------------------------------------------------------ -static inline int CeedVectorNeedSync_Hip(const CeedVector vec, - CeedMemType mem_type, bool *need_sync) { - int ierr; - CeedVector_Hip *impl; - ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); - - bool has_valid_array = false; - ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); - switch (mem_type) { - case CEED_MEM_HOST: - *need_sync = has_valid_array && !impl->h_array; - break; - case CEED_MEM_DEVICE: - *need_sync = has_valid_array && !impl->d_array; - break; - } - - return CEED_ERROR_SUCCESS; -} - //------------------------------------------------------------------------------ // Set array from host //------------------------------------------------------------------------------ @@ -363,11 +372,7 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); // Sync array to requested mem_type - bool need_sync = false; - ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr); - if (need_sync) { - ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr); - } + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); // Update pointer switch (mem_type) { @@ -398,13 +403,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedVector_Hip *impl; ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); - bool need_sync = false; - ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr); - CeedChkBackend(ierr); - if (need_sync) { - // Sync array to requested mem_type - ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr); - } + // Sync array to requested mem_type + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); // Update pointer switch (mem_type) { @@ -758,6 +758,8 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) { CeedVectorTakeArray_Hip); CeedChkBackend(ierr); ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr); + ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", + CeedVectorSyncArray_Hip); CeedChkBackend(ierr); ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip); CeedChkBackend(ierr); ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", diff --git a/interface/ceed.c b/interface/ceed.c index 08ef354072..34635cdf4a 100644 --- a/interface/ceed.c +++ b/interface/ceed.c @@ -843,6 +843,7 @@ int CeedInit(const char *resource, Ceed *ceed) { CEED_FTABLE_ENTRY(CeedVector, SetArray), CEED_FTABLE_ENTRY(CeedVector, TakeArray), CEED_FTABLE_ENTRY(CeedVector, SetValue), + CEED_FTABLE_ENTRY(CeedVector, SyncArray), CEED_FTABLE_ENTRY(CeedVector, GetArray), CEED_FTABLE_ENTRY(CeedVector, GetArrayRead), CEED_FTABLE_ENTRY(CeedVector, GetArrayWrite),