Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ext/AMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ import AMDGPU
import FiniteElementContainers

# public API
function FiniteElementContainers.rocm(x)
function FiniteElementContainers.to_backend(::AMDGPU.ROCBackend, x)
return Adapt.adapt_structure(AMDGPU.ROCArray, x)
end

# back-compat alias
FiniteElementContainers.rocm(x) =
FiniteElementContainers.to_backend(AMDGPU.ROCBackend(), x)

# private API
function FiniteElementContainers._coo_matrix_constructor(::AMDGPU.ROCBackend)
return AMDGPU.rocSPARSE.ROCSparseMatrixCOO
Expand Down
6 changes: 5 additions & 1 deletion ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ import CUDA
import FiniteElementContainers

# public API
function FiniteElementContainers.cuda(x)
function FiniteElementContainers.to_backend(::CUDA.CUDABackend, x)
return Adapt.adapt_structure(CUDA.CuArray, x)
end

# back-compat alias
FiniteElementContainers.cuda(x) =
FiniteElementContainers.to_backend(CUDA.CUDABackend(), x)

# private API
function FiniteElementContainers._coo_matrix_constructor(::CUDA.CUDABackend)
return CUDA.CUSPARSE.CuSparseMatrixCOO
Expand Down
9 changes: 9 additions & 0 deletions src/FiniteElementContainers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module FiniteElementContainers
export cpu
export cuda
export rocm
export to_backend

# Assemblers
export SparseMatrixAssembler
Expand Down Expand Up @@ -207,6 +208,14 @@ cpu(x) = adapt(Array, x)
function cuda end
function rocm end

# Move `x` onto the given KernelAbstractions backend. CPU is identity —
# CPU-built data already lives on the CPU backend. GPU backends (CUDABackend,
# ROCBackend) are provided by the CUDA / AMDGPU package extensions.
to_backend(::KA.CPU, x) = x
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only suggestion here is that in the future we add an additional method to_host(x) = to_backend(KA.CPU(), x)

to_backend(b::KA.Backend, x) = error(
"to_backend is not implemented for backend $(typeof(b)); load the " *
"corresponding GPU package (CUDA.jl or AMDGPU.jl) so its extension activates.")

# function communication_graph end
function create_partition end
function create_matrix_sparsity_pattern end
Expand Down
Loading