diff --git a/Project.toml b/Project.toml index 8ff48b8ab..5876dac72 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DistributedNext = "fab6aee4-877b-4bac-a744-3eca44acbb6f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -69,6 +70,7 @@ DataStructures = "0.18, 0.19" DistributedNext = "1.0.0" Distributions = "0.25" FillArrays = "1.13.0" +GPUArraysCore = "0.2.0" GraphViz = "0.2" Graphs = "1" JSON3 = "1" diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 9f9b8df4d..de4002a6b 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -303,13 +303,11 @@ end CuArray(H::Dagger.HaloArray) = convert(CuArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CuArray} = Dagger.HaloArray(C(H.center), - C.(H.edges), - C.(H.corners), + C.(H.halos), H.halo_width) Adapt.adapt_structure(to::CUDA.KernelAdaptor, H::Dagger.HaloArray) = Dagger.HaloArray(adapt(to, H.center), - adapt.(Ref(to), H.edges), - adapt.(Ref(to), H.corners), + adapt.(Ref(to), H.halos), H.halo_width) function Dagger.inner_stencil_proc!(::CuArrayDeviceProc, f, output, read_vars) Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output)) diff --git a/ext/IntelExt.jl b/ext/IntelExt.jl index 08d54ee81..b0de56f38 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -271,13 +271,11 @@ end oneArray(H::Dagger.HaloArray) = convert(oneArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:oneArray} = Dagger.HaloArray(C(H.center), - C.(H.edges), - C.(H.corners), + C.(H.halos), H.halo_width) Adapt.adapt_structure(to::oneAPI.KernelAdaptor, H::Dagger.HaloArray) = Dagger.HaloArray(adapt(to, H.center), - adapt.(Ref(to), H.edges), - adapt.(Ref(to), H.corners), + adapt.(Ref(to), H.halos), H.halo_width) function Dagger.inner_stencil_proc!(::oneArrayDeviceProc, f, output, read_vars) Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output)) diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 38d5906dd..43a15a38f 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -287,8 +287,11 @@ end MtlArray(H::Dagger.HaloArray) = convert(MtlArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:MtlArray} = Dagger.HaloArray(C(H.center), - C.(H.edges), - C.(H.corners), + C.(H.halos), + H.halo_width) +Adapt.adapt_structure(to::Metal.Adaptor, H::Dagger.HaloArray) = + Dagger.HaloArray(adapt(to, H.center), + adapt.(Ref(to), H.halos), H.halo_width) function Dagger.inner_stencil_proc!(::MtlArrayDeviceProc, f, output, read_vars) Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output)) diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index f8eac930c..144a1ba77 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -254,13 +254,11 @@ end CLArray(H::Dagger.HaloArray) = convert(CLArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:CLArray} = Dagger.HaloArray(C(H.center), - C.(H.edges), - C.(H.corners), + C.(H.halos), H.halo_width) Adapt.adapt_structure(to::OpenCL.KernelAdaptor, H::Dagger.HaloArray) = Dagger.HaloArray(adapt(to, H.center), - adapt.(Ref(to), H.edges), - adapt.(Ref(to), H.corners), + adapt.(Ref(to), H.halos), H.halo_width) function Dagger.inner_stencil_proc!(::CLArrayDeviceProc, f, output, read_vars) Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output)) diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 773c2bb95..444ff919d 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -273,8 +273,11 @@ end ROCArray(H::Dagger.HaloArray) = convert(ROCArray, H) Base.convert(::Type{C}, H::Dagger.HaloArray) where {C<:ROCArray} = Dagger.HaloArray(C(H.center), - C.(H.edges), - C.(H.corners), + C.(H.halos), + H.halo_width) +Adapt.adapt_structure(to::AMDGPU.Runtime.Adaptor, H::Dagger.HaloArray) = + Dagger.HaloArray(adapt(to, H.center), + adapt.(Ref(to), H.halos), H.halo_width) function Dagger.inner_stencil_proc!(::ROCArrayDeviceProc, f, output, read_vars) Dagger.Kernel(_inner_stencil!)(f, output, read_vars; ndrange=size(output)) diff --git a/src/Dagger.jl b/src/Dagger.jl index 102a76149..ff6161aa4 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -32,6 +32,8 @@ import TimespanLogging: timespan_start, timespan_finish import Adapt +import GPUArraysCore + import Preferences: @load_preference, @set_preferences! if @load_preference("distributed-package") == "DistributedNext" diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 1b83d9b27..a39be4390 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -34,7 +34,7 @@ with_index_caching(f, size::Integer=1) = with(f, GETINDEX_CACHE_SIZE=>size) end # Return the value - return part[offset_idx...] + return GPUArraysCore.@allowscalar part[offset_idx...] end function partition_for(A::DArray, idx::NTuple{N,Int}) where N part_idx = zeros(Int, N) @@ -112,7 +112,10 @@ Base.getindex(A::DArray, idx::ArrayDomain) = part = A.chunks[part_idx...] space = memory_space(part) scope = UnionScope(map(ExactScope, collect(processors(space)))) - return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...)) + return fetch(Dagger.@spawn scope=scope setindex_allowscalar!(part, value, offset_idx...)) +end +function setindex_allowscalar!(part, value, offset_idx...) + GPUArraysCore.@allowscalar setindex!(part, value, offset_idx...) end Base.setindex!(A::DArray, value, idx::Integer...) = setindex!(A, value, idx) diff --git a/src/stencil.jl b/src/stencil.jl index b283c119e..aa09ac54a 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -18,83 +18,77 @@ function validate_neigh_dist(neigh_dist, size) end end -function load_neighbor_edge(arr, dim, dir, neigh_dist) +# Load a halo region from a neighboring chunk +# region_code: N-tuple where each element is -1 (low), 0 (full extent), or +1 (high) +# For dimensions with code 0, we take the full extent of the array +# For dimensions with code -1, we take the last neigh_dist elements (to go to neighbor's low side) +# For dimensions with code +1, we take the first neigh_dist elements (to go to neighbor's high side) +function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where N validate_neigh_dist(neigh_dist, size(arr)) - if dir == -1 - start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> i == dim ? lastindex(arr, i) : lastindex(arr, i), ndims(arr))) - elseif dir == 1 - start_idx = CartesianIndex(ntuple(i -> i == dim ? firstindex(arr, i) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr))) - end + start_idx = CartesianIndex(ntuple(N) do i + if region_code[i] == -1 + lastindex(arr, i) - neigh_dist + 1 + else + firstindex(arr, i) + end + end) + stop_idx = CartesianIndex(ntuple(N) do i + if region_code[i] == +1 + firstindex(arr, i) + neigh_dist - 1 + else + lastindex(arr, i) + end + end) # FIXME: Don't collect return move(task_processor(), collect(@view arr[start_idx:stop_idx])) end -function load_neighbor_corner(arr, corner_side, neigh_dist) - validate_neigh_dist(neigh_dist, size(arr)) - start_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(arr, i) : (firstindex(arr, i) + neigh_dist - 1), ndims(arr))) - return move(task_processor(), collect(@view arr[start_idx:stop_idx])) -end function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary) validate_neigh_dist(neigh_dist) + N = ndims(chunks) # FIXME: Depends on neigh_dist and chunk size chunk_dist = 1 + # Get the center accesses = Any[chunks[idx]] - # Get the edges - for dim in 1:ndims(chunks) - for dir in (-1, +1) - new_idx = idx + CartesianIndex(ntuple(i -> i == dim ? dir*chunk_dist : 0, ndims(chunks))) - if is_past_boundary(size(chunks), new_idx) - if boundary_has_transition(boundary) - new_idx = boundary_transition(boundary, new_idx, size(chunks)) - else - new_idx = idx - end - chunk = chunks[new_idx] - push!(accesses, Dagger.@spawn load_boundary_edge(boundary, chunk, dim, dir, neigh_dist)) - else - chunk = chunks[new_idx] - push!(accesses, Dagger.@spawn load_neighbor_edge(chunk, dim, dir, neigh_dist)) - end + # Iterate over all 3^N - 1 halo regions (excluding center) + # Each region is identified by a code tuple where each element is -1, 0, or +1 + for i in 0:(3^N - 1) + region_code = ntuple(N) do d + ((i ÷ 3^(d-1)) % 3) - 1 # Maps 0,1,2 -> -1,0,+1 end - end + all(==(0), region_code) && continue # Skip center - # Get the corners - for corner_num in 1:(2^ndims(chunks)) - corner_side = CartesianIndex(reverse(ntuple(ndims(chunks)) do i - ((corner_num-1) >> (((ndims(chunks) - i) + 1) - 1)) & 1 - end)) - corner_new_idx = CartesianIndex(ntuple(ndims(chunks)) do i - corner_shift = iszero(corner_side[i]) ? -1 : 1 - return idx[i] + corner_shift + # Compute the chunk offset for this region + # For each dimension: -1 means go to previous chunk, +1 means go to next chunk, 0 means same chunk + chunk_offset = CartesianIndex(ntuple(N) do d + region_code[d] * chunk_dist end) - if is_past_boundary(size(chunks), corner_new_idx) + new_idx = idx + chunk_offset + + if is_past_boundary(size(chunks), new_idx) if boundary_has_transition(boundary) - corner_new_idx = boundary_transition(boundary, corner_new_idx, size(chunks)) + new_idx = boundary_transition(boundary, new_idx, size(chunks)) else - corner_new_idx = idx + new_idx = idx end - chunk = chunks[corner_new_idx] - push!(accesses, Dagger.@spawn load_boundary_corner(boundary, chunk, corner_side, neigh_dist)) + chunk = chunks[new_idx] + push!(accesses, Dagger.@spawn load_boundary_region(boundary, chunk, region_code, neigh_dist)) else - chunk = chunks[corner_new_idx] - push!(accesses, Dagger.@spawn load_neighbor_corner(chunk, corner_side, neigh_dist)) + chunk = chunks[new_idx] + push!(accesses, Dagger.@spawn load_neighbor_region(chunk, region_code, neigh_dist)) end end - @assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))" + @assert length(accesses) == 3^N "Accesses mismatch: expected $(3^N), got $(length(accesses))" return accesses end -function build_halo(neigh_dist, boundary, center, all_neighbors...) +function build_halo(neigh_dist, boundary, center, all_halos...) N = ndims(center) - edges = all_neighbors[1:(2*N)] - corners = all_neighbors[((2^N)+1):end] - @assert length(edges) == 2*N && length(corners) == 2^N "Halo mismatch: edges=$(length(edges)) corners=$(length(corners))" - return HaloArray(center, (edges...,), (corners...,), ntuple(_->neigh_dist, N)) + expected_halos = 3^N - 1 + @assert length(all_halos) == expected_halos "Halo mismatch: N=$N expected $expected_halos halos, got $(length(all_halos))" + return HaloArray(center, (all_halos...,), ntuple(_->neigh_dist, N)) end function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N} @assert all(arr.halo_width .== arr.halo_width[1]) @@ -121,31 +115,21 @@ struct Wrap end boundary_has_transition(::Wrap) = true boundary_transition(::Wrap, idx, size) = CartesianIndex(ntuple(i -> mod1(idx[i], size[i]), length(size))) -load_boundary_edge(::Wrap, arr, dim, dir, neigh_dist) = load_neighbor_edge(arr, dim, dir, neigh_dist) -load_boundary_corner(::Wrap, arr, corner_side, neigh_dist) = load_neighbor_corner(arr, corner_side, neigh_dist) +load_boundary_region(::Wrap, arr, region_code, neigh_dist) = load_neighbor_region(arr, region_code, neigh_dist) struct Pad{T} padval::T end boundary_has_transition(::Pad) = false -function load_boundary_edge(pad::Pad, arr, dim, dir, neigh_dist) - if dir == -1 - start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> i == dim ? lastindex(arr, i) : lastindex(arr, i), ndims(arr))) - elseif dir == 1 - start_idx = CartesianIndex(ntuple(i -> i == dim ? firstindex(arr, i) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr))) +function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_dist) where N + # Compute the size of this halo region + # For dimensions with code 0, use full array size + # For dimensions with code -1 or +1, use neigh_dist + region_size = ntuple(N) do i + region_code[i] == 0 ? size(arr, i) : neigh_dist end - edge_size = ntuple(i -> length(start_idx[i]:stop_idx[i]), ndims(arr)) - # FIXME: return Fill(pad.padval, edge_size) - return move(task_processor(), fill(pad.padval, edge_size)) -end -function load_boundary_corner(pad::Pad, arr, corner_side, neigh_dist) - start_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? (lastindex(arr, i) - neigh_dist + 1) : firstindex(arr, i), ndims(arr))) - stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(arr, i) : (firstindex(arr, i) + neigh_dist - 1), ndims(arr))) - corner_size = ntuple(i -> length(start_idx[i]:stop_idx[i]), ndims(arr)) - # FIXME: return Fill(pad.padval, corner_size) - return move(task_processor(), fill(pad.padval, corner_size)) + # FIXME: return Fill(pad.padval, region_size) + return move(task_processor(), fill(pad.padval, region_size)) end """ diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index e34cf2b05..41014045c 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -1,140 +1,181 @@ -# Define the HaloArray type with minimized halo storage -struct HaloArray{T,N,E,C,A,EAT<:Tuple,CAT<:Tuple} <: AbstractArray{T,N} +# Define the HaloArray type with generalized halo storage for any dimensionality +# Each halo region is identified by a "region code" - an N-tuple where each element is: +# -1 = low halo (index < 1) +# 0 = center (1 <= index <= center_size) +# +1 = high halo (index > center_size) +# There are 3^N - 1 halo regions (excluding the all-zeros center code) + +struct HaloArray{T,N,A<:AbstractArray{T,N},H<:Tuple} <: AbstractArray{T,N} center::A - edges::EAT - corners::CAT + halos::H # Tuple of 3^N - 1 arrays in canonical order halo_width::NTuple{N,Int} end -# Helper function to create an empty HaloArray with minimized halo storage -function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) where {T,N} - center = Array{T,N}(undef, center_size...) - edges = ntuple(2N) do i - prev_dims = center_size[1:(cld(i,2)-1)] - next_dims = center_size[(cld(i,2)+1):end] - return Array{T,N}(undef, prev_dims..., halo_width[cld(i,2)], next_dims...) +# Number of halo regions for N dimensions +num_halo_regions(N::Int) = 3^N - 1 + +# Generate all region codes in canonical order (excluding center) +# Order: iterate through base-3 representation, skip the center (all zeros) +function all_region_codes(::Val{N}) where N + codes = NTuple{N,Int}[] + for i in 0:(3^N - 1) + code = ntuple(N) do d + ((i ÷ 3^(d-1)) % 3) - 1 # Maps 0,1,2 -> -1,0,+1 + end + if !all(==(0), code) # Skip center + push!(codes, code) + end + end + return Tuple(codes) +end + +# Convert region code to flat index (1 to 3^N - 1) +# The index is based on treating the code as a base-3 number (with offset) +@inline function region_index(code::NTuple{N,Int}) where N + # Map: -1 → 0, 0 → 1, +1 → 2, treat as base-3 number + raw = 0 + for i in 1:N + raw += (code[i] + 1) * 3^(i-1) end - corners = ntuple(2^N) do i - return Array{T,N}(undef, halo_width) + # Center (all zeros) maps to raw index where all digits are 1 + # center_raw = sum(1 * 3^(i-1) for i in 1:N) = (3^N - 1) / 2 + center_raw = (3^N - 1) ÷ 2 + # Adjust index to skip center + return raw < center_raw ? raw + 1 : raw +end + +# Compute the size of a halo region given its code +@inline function halo_region_size(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}, code::NTuple{N,Int}) where N + return ntuple(N) do i + code[i] == 0 ? center_size[i] : halo_width[i] end - return HaloArray(center, edges, corners, halo_width) end -HaloArray(center::AT, edges::EAT, corners::CAT, halo_width::NTuple{N, Int}) where {T,N,AT<:AbstractArray{T,N},CAT<:Tuple,EAT<:Tuple} = - HaloArray{T,N,length(edges),length(corners),AT,EAT,CAT}(center, edges, corners, halo_width) +# Helper function to create an empty HaloArray with generalized halo storage +function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) where {T,N} + center = Array{T,N}(undef, center_size...) + + # Create all 3^N - 1 halo regions + codes = all_region_codes(Val(N)) + halos = ntuple(length(codes)) do i + code = codes[i] + region_size = halo_region_size(center_size, halo_width, code) + Array{T,N}(undef, region_size...) + end + + return HaloArray{T,N,typeof(center),typeof(halos)}(center, halos, halo_width) +end Base.size(tile::HaloArray) = size(tile.center) .+ 2 .* tile.halo_width -function Base.axes(tile::HaloArray{T,N,H}) where {T,N,H} +function Base.axes(tile::HaloArray{T,N}) where {T,N} ntuple(N) do i first_ind = 1 - tile.halo_width[i] last_ind = size(tile.center, i) + tile.halo_width[i] return first_ind:last_ind end end -function Base.similar(tile::HaloArray{T,N,H}, ::Type{T}, dims::NTuple{N,Int}) where {T,N,H} +function Base.similar(tile::HaloArray{T,N}, ::Type{T}, dims::NTuple{N,Int}) where {T,N} center_size = dims halo_width = tile.halo_width - return HaloArray{T,N,H}(center_size, halo_width) + return HaloArray{T,N}(center_size, halo_width) end -function Base.copy(tile::HaloArray{T,N,H}) where {T,N,H} +function Base.copy(tile::HaloArray{T,N}) where {T,N} center = copy(tile.center) - halo = ntuple(i->copy(tile.edges[i]), H) + halos = ntuple(i -> copy(tile.halos[i]), length(tile.halos)) halo_width = tile.halo_width - return HaloArray{T,N,H}(center, halo, halo_width) + return HaloArray(center, halos, halo_width) +end + +# Compute the region code for a given index +@inline function compute_region_code(tile::HaloArray{T,N}, I::NTuple{N,Int}) where {T,N} + return ntuple(N) do i + if I[i] < 1 + -1 + elseif I[i] > size(tile.center, i) + +1 + else + 0 + end + end +end + +# Compute local index within a halo region +@inline function compute_local_index(tile::HaloArray{T,N}, I::NTuple{N,Int}, code::NTuple{N,Int}) where {T,N} + return ntuple(N) do i + if code[i] == -1 + I[i] + tile.halo_width[i] + elseif code[i] == +1 + I[i] - size(tile.center, i) + else + I[i] # Inside this dimension, keep index + end + end end # Define getindex for HaloArray @inline function Base.getindex(tile::HaloArray{T,N}, I::Vararg{Int,N}) where {T,N} Base.@boundscheck checkbounds(tile, I...) - if all(1 .<= I .<= size(tile.center)) - return tile.center[I...] - elseif !any(1 .<= I .<= size(tile.center)) - # Corner - # N.B. Corner indexes are in binary, e.g. 0b01, 0b10, 0b11 - corner_idx = sum(ntuple(i->(I[i] < 1 ? 0 : 1) * (2^(i-1)), N)) + 1 - corner_offset = CartesianIndex(I) + CartesianIndex(ntuple(i->(I[i] < 1 ? tile.halo_width[i] : -size(tile.center, i)), N)) - return tile.corners[corner_idx][corner_offset] + code = compute_region_code(tile, I) + + if all(==(0), code) + # Center + return @inbounds tile.center[I...] else - for d in 1:N - if I[d] < 1 - halo_idx = ntuple(i->i == d ? I[i] + tile.halo_width[i] : I[i], N) - return tile.edges[(2*(d-1))+1][halo_idx...] - elseif I[d] > size(tile.center, d) - halo_idx = ntuple(i->i == d ? I[i] - size(tile.center, d) : I[i], N) - return tile.edges[(2*(d-1))+2][halo_idx...] - end - end + # Halo region + idx = region_index(code) + local_idx = compute_local_index(tile, I, code) + return @inbounds tile.halos[idx][local_idx...] end - error("Index out of bounds") end # Define setindex! for HaloArray @inline function Base.setindex!(tile::HaloArray{T,N}, value, I::Vararg{Int,N}) where {T,N} Base.@boundscheck checkbounds(tile, I...) - if all(1 .<= I .<= size(tile.center)) + code = compute_region_code(tile, I) + + if all(==(0), code) # Center - return tile.center[I...] = value - elseif !any(1 .<= I .<= size(tile.center)) - # Corner - # N.B. Corner indexes are in binary, e.g. 0b01, 0b10, 0b11 - corner_idx = sum(ntuple(i->(I[i] < 1 ? 0 : 1) * (2^(i-1)), N)) + 1 - corner_offset = CartesianIndex(I) + CartesianIndex(ntuple(i->(I[i] < 1 ? tile.halo_width[i] : -size(tile.center, i)), N)) - return tile.corners[corner_idx][corner_offset] = value + return @inbounds tile.center[I...] = value else - # Edge - for d in 1:N - if I[d] < 1 - halo_idx = ntuple(i->i == d ? I[i] + tile.halo_width[i] : I[i], N) - return tile.edges[(2*(d-1))+1][halo_idx...] = value - elseif I[d] > size(tile.center, d) - halo_idx = ntuple(i->i == d ? I[i] - size(tile.center, d) : I[i], N) - return tile.edges[(2*(d-1))+2][halo_idx...] = value - end - end + # Halo region + idx = region_index(code) + local_idx = compute_local_index(tile, I, code) + return @inbounds tile.halos[idx][local_idx...] = value end - error("Index out of bounds") end Adapt.adapt_structure(to, H::Dagger.HaloArray) = HaloArray(Adapt.adapt(to, H.center), - Adapt.adapt.(Ref(to), H.edges), - Adapt.adapt.(Ref(to), H.corners), + Adapt.adapt.(Ref(to), H.halos), H.halo_width) function aliasing(A::HaloArray) - return CombinedAliasing([aliasing(A.center), map(aliasing, A.edges)..., map(aliasing, A.corners)...]) + return CombinedAliasing([aliasing(A.center), map(aliasing, A.halos)...]) end memory_space(A::HaloArray) = memory_space(A.center) + function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, A::HaloArray) center_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.center) - edge_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.edges[i]), length(A.edges)) - corner_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.corners[i]), length(A.corners)) + halo_chunks = ntuple(i -> rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.halos[i]), length(A.halos)) halo_width = A.halo_width to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, halo_chunks, halo_width center_new = move(from_proc, to_proc, center_chunk) - edges_new = ntuple(i->move(from_proc, to_proc, edge_chunks[i]), length(edge_chunks)) - corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) - return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) + halos_new = ntuple(i -> move(from_proc, to_proc, halo_chunks[i]), length(halo_chunks)) + return tochunk(HaloArray(center_new, halos_new, halo_width), to_proc) end end + function find_object_holding_ptr(object::HaloArray, ptr::UInt64) - for i in 1:length(object.edges) - edge = object.edges[i] - span = LocalMemorySpan(pointer(edge), length(edge)*sizeof(eltype(edge))) + for i in 1:length(object.halos) + halo = object.halos[i] + span = LocalMemorySpan(pointer(halo), length(halo)*sizeof(eltype(halo))) if span_start(span) <= ptr <= span_end(span) - return edge - end - end - for i in 1:length(object.corners) - corner = object.corners[i] - span = LocalMemorySpan(pointer(corner), length(corner)*sizeof(eltype(corner))) - if span_start(span) <= ptr <= span_end(span) - return corner + return halo end end center = object.center span = LocalMemorySpan(pointer(center), length(center)*sizeof(eltype(center))) @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in HaloArray" return center -end \ No newline at end of file +end diff --git a/test/array/stencil.jl b/test/array/stencil.jl index 69b7bb069..7160b82a8 100644 --- a/test/array/stencil.jl +++ b/test/array/stencil.jl @@ -119,6 +119,24 @@ function test_stencil() @test collect(B) == expected_B_pad_val end + # From issue #669 + for N in 3:4 + @testset "$(N)D array" begin + A = ones(Blocks(ntuple(_->1, N)...), Int, ntuple(_->3, N)...) + Dagger.allowscalar() do + A[:] = 1:length(A) + end + B = zeros(Blocks(ntuple(_->1, N)...), Float32, ntuple(_->3, N)...) + + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, Wrap())) / length(A) + end + end + @test all(==(Float64(sum(1:length(A)) / length(A))), collect(B)) + end + end + @testset "Invalid neighborhood distance" begin A = ones(Blocks(1, 1), Int, 2, 2) B = zeros(Blocks(1, 1), Int, 2, 2)