From c5f5df704d69aadf617d2698bf0f344c978ac17d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 5 Feb 2026 10:20:03 -0700 Subject: [PATCH 1/4] stencils: Support mixed boundaries, add Clamp and LinearExtrapolate --- docs/src/stencils.md | 83 ++++++- src/stencil.jl | 516 ++++++++++++++++++++++++++++++++++++------ test/array/stencil.jl | 171 +++++++++++++- 3 files changed, 702 insertions(+), 68 deletions(-) diff --git a/docs/src/stencils.md b/docs/src/stencils.md index cbaa92b67..c42e6f7b1 100644 --- a/docs/src/stencils.md +++ b/docs/src/stencils.md @@ -8,7 +8,7 @@ The fundamental structure of a `@stencil` block involves iterating over an impli ```julia using Dagger -import Dagger: @stencil, Wrap, Pad, Reflect +import Dagger: @stencil, Wrap, Pad, Reflect, Clamp, LinearExtrapolate # Initialize a DArray A = zeros(Blocks(2, 2), Int, 4, 4) @@ -31,13 +31,16 @@ The true power of stencils comes from accessing neighboring elements. The `@neig `@neighbors(array[idx], distance, boundary_condition)` - `array[idx]`: The array and current index from which to find neighbors. -- `distance`: An integer specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D). +- `distance`: An integer or `Tuple` of integers specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D). - `boundary_condition`: Defines how to handle accesses beyond the array boundaries. Available conditions are: - - `Wrap()`: Wraps around to the other side of the array. + - `Wrap()`: Wraps around to the other side of the array (periodic boundaries). - `Pad(value)`: Pads with a specified `value`. - `Reflect(symmetric)`: Reflects values back into the array at boundaries. The `symmetric` boolean controls whether the edge element is included in the reflection: - `Reflect(true)` (symmetric): Edge element IS repeated. For array `[a,b,c,d]`, extends as `[...,c,b,a,a,b,c,d,d,c,b,...]`. - `Reflect(false)` (mirror): Edge element NOT repeated. For array `[a,b,c,d]`, extends as `[...,d,c,b,a,b,c,d,c,b,a,...]`. + - `Clamp()`: Clamps to the boundary value (repeats edge elements). For array `[a,b,c,d]`, extends as `[...,a,a,a,a,b,c,d,d,d,d,...]`. + - `LinearExtrapolate()`: Linearly extrapolates using the slope at the boundary. Only works with `Real` element types. For array `[2,4,6,8]`, the slope at the low boundary is `4-2=2`, so index 0 would be `2-2=0`. + - **Mixed BCs (Tuple)**: You can specify different boundary conditions per dimension using a tuple. For example, `(Wrap(), Pad(0))` uses `Wrap` for dimension 1 and `Pad(0)` for dimension 2. ### Example: Averaging Neighbors with `Wrap` @@ -149,6 +152,80 @@ end @assert collect(B) == [5, 6, 9, 10] ``` +### Example: Edge Detection with `Clamp` + +The `Clamp` boundary condition repeats edge values, which is useful when you want boundary elements to have a neutral effect: + +```julia +import Dagger: Clamp + +# Array [1, 2, 3, 4] extends as [..., 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, ...] +A = DArray([1, 2, 3, 4], Blocks(2)) +B = zeros(Blocks(2), Int, 4) + +Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, Clamp())) + end +end + +# B[1]: indices 0,1,2 -> 0 clamps to 1, so [1,1,2] = 4 +# B[2]: indices 1,2,3 -> all in bounds, [1,2,3] = 6 +# B[3]: indices 2,3,4 -> all in bounds, [2,3,4] = 9 +# B[4]: indices 3,4,5 -> 5 clamps to 4, so [3,4,4] = 11 +@assert collect(B) == [4, 6, 9, 11] +``` + +### Example: Smooth Extrapolation with `LinearExtrapolate` + +The `LinearExtrapolate` boundary condition extrapolates linearly based on the slope at the boundary. This is useful for maintaining trends at edges: + +```julia +import Dagger: LinearExtrapolate + +# Array [2.0, 4.0, 6.0, 8.0] has slope 2.0 at both boundaries +# At low boundary: index 0 -> 2.0 + 2.0*(-1) = 0.0 +# At high boundary: index 5 -> 8.0 + 2.0*(1) = 10.0 +A = DArray([2.0, 4.0, 6.0, 8.0], Blocks(2)) +B = zeros(Blocks(2), Float64, 4) + +Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, LinearExtrapolate())) + end +end + +# B[1]: indices 0,1,2 -> [0.0, 2.0, 4.0] = 6.0 +# B[2]: indices 1,2,3 -> [2.0, 4.0, 6.0] = 12.0 +# B[3]: indices 2,3,4 -> [4.0, 6.0, 8.0] = 18.0 +# B[4]: indices 3,4,5 -> [6.0, 8.0, 10.0] = 24.0 +@assert collect(B) ≈ [6.0, 12.0, 18.0, 24.0] +``` + +### Example: Mixed Boundary Conditions + +You can specify different boundary conditions for each dimension using a tuple. This is useful when different boundaries have different physical meanings: + +```julia +import Dagger: Wrap, Pad + +# 2D array with Wrap in dimension 1 (rows) and Pad(0) in dimension 2 (columns) +A = DArray(reshape(1:16, 4, 4), Blocks(2, 2)) +B = zeros(Blocks(2, 2), Int, 4, 4) + +Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, (Wrap(), Pad(0)))) + end +end + +# For each element: +# - Row neighbors wrap around (periodic in rows) +# - Column neighbors are padded with 0 (zero-flux at column boundaries) +``` + +This is particularly useful for physical simulations where, for example, you might have periodic boundaries in one direction and fixed boundaries in another. + ## Sequential Semantics Expressions within a `@stencil` block are executed sequentially in terms of their effect on the data. This means that the result of one statement is visible to the subsequent statements, as if they were applied "all at once" across all indices before the next statement begins. diff --git a/src/stencil.jl b/src/stencil.jl index 6ebc96f9b..40154213c 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -21,6 +21,10 @@ end get_neigh_dist(neigh_dist::Integer, i::Int) = neigh_dist get_neigh_dist(neigh_dist::Tuple, i::Int) = neigh_dist[i] +# Get boundary condition for dimension i (supports single boundary condition or tuple of boundary conditions) +get_boundary(boundary, i::Int) = boundary +get_boundary(boundary::Tuple, i::Int) = boundary[i] + # Load a halo region from a neighboring chunk # region_code: N-tuple where each element is -1 (low), 0 (full extent), or +1 (high) # For dimensions with code 0, we take the full extent of the array @@ -45,85 +49,93 @@ function load_neighbor_region(arr, region_code::NTuple{N,Int}, neigh_dist) where # FIXME: Don't collect return move(task_processor(), collect(@view arr[start_idx:stop_idx])) end -function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary) - validate_neigh_dist(neigh_dist) - N = ndims(chunks) - # FIXME: Depends on neigh_dist and chunk size - chunk_dist = 1 +is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size))) - # Get the center - accesses = Any[chunks[idx]] +############################################################################# +# Boundary Condition Interface +############################################################################# +# +# To implement a custom boundary condition, define a struct and implement: +# +# Required: +# - boundary_has_transition(::MyBoundary) -> Bool +# - load_boundary_region(::MyBoundary, arr, region_code, neigh_dist, boundary_dims) +# +# Required if boundary_has_transition returns true: +# - boundary_transition(::MyBoundary, idx, size) -> CartesianIndex +# +# Required for mixed boundary condition support (when used in a tuple with other boundary conditions): +# - boundary_source_index(::MyBoundary, arr, rc, nd, idx_d, d) -> Int +# - apply_boundary_value(::MyBoundary, value, arr, rc, nd, idx_d, src_idx, d) [optional, default returns value unchanged] +# +############################################################################# + +# Default implementations for boundary_source_index and apply_boundary_value +# These are used when a boundary condition is part of a mixed boundary condition tuple - # Iterate over all 3^N - 1 halo regions (excluding center) - # Each region is identified by a code tuple where each element is -1, 0, or +1 - for i in 0:(3^N - 1) - region_code = ntuple(N) do d - ((i ÷ 3^(d-1)) % 3) - 1 # Maps 0,1,2 -> -1,0,+1 - end - all(==(0), region_code) && continue # Skip center +""" + boundary_source_index(boundary, arr, rc, nd, idx_d, d) -> Int - # Compute the chunk offset for this region - # For each dimension: -1 means go to previous chunk, +1 means go to next chunk, 0 means same chunk - chunk_offset = CartesianIndex(ntuple(N) do d - region_code[d] * chunk_dist - end) - new_idx = idx + chunk_offset +Compute the source index for dimension `d` when the boundary condition is used in a mixed boundary condition tuple. +- `boundary`: The boundary condition +- `arr`: The array being accessed +- `rc`: Region code for this dimension (-1, 0, or +1) +- `nd`: Neighborhood distance for this dimension +- `idx_d`: The index in the result array for this dimension +- `d`: The dimension number - if is_past_boundary(size(chunks), new_idx) - # Compute which dimensions are actually past boundary - boundary_dims = ntuple(N) do d - new_idx[d] < 1 || new_idx[d] > size(chunks)[d] - end - if boundary_has_transition(boundary) - new_idx = boundary_transition(boundary, new_idx, size(chunks)) - else - new_idx = idx - end - chunk = chunks[new_idx] - push!(accesses, Dagger.@spawn load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims)) - else - chunk = chunks[new_idx] - push!(accesses, Dagger.@spawn load_neighbor_region(chunk, region_code, neigh_dist)) - end - end +Default implementation clamps to valid array range. +""" +boundary_source_index(::Any, arr, rc, nd, idx_d, d) = clamp(idx_d, firstindex(arr, d), lastindex(arr, d)) - @assert length(accesses) == 3^N "Accesses mismatch: expected $(3^N), got $(length(accesses))" - return accesses -end -function build_halo(neigh_dist, boundary, center, all_halos...) - N = ndims(center) - expected_halos = 3^N - 1 - @assert length(all_halos) == expected_halos "Halo mismatch: N=$N expected $expected_halos halos, got $(length(all_halos))" - return HaloArray(center, (all_halos...,), ntuple(i->get_neigh_dist(neigh_dist, i), N)) -end -function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N} - start_idx = idx - CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr))) - stop_idx = idx + CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr))) - return @view arr[start_idx:stop_idx] -end -function inner_stencil!(f, output, read_vars) - processor = task_processor() - inner_stencil_proc!(processor, f, output, read_vars) -end -# Non-KA (for CPUs) -function inner_stencil_proc!(::ThreadProc, f, output, read_vars) - for idx in CartesianIndices(output) - f(idx, output, read_vars) - end - return -end +""" + apply_boundary_value(boundary, value, arr, rc, nd, idx_d, src_idx, d) + +Apply the boundary condition's value transformation for dimension `d` when used in a mixed boundary condition tuple. +- `boundary`: The boundary condition +- `value`: The current value from the source array +- `arr`: The array being accessed +- `rc`: Region code for this dimension (-1, 0, or +1) +- `nd`: Neighborhood distance for this dimension +- `idx_d`: The index in the result array for this dimension +- `src_idx`: The full source index tuple +- `d`: The dimension number + +Default implementation returns the value unchanged. +""" +apply_boundary_value(::Any, value, arr, rc, nd, idx_d, src_idx, d) = value -is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size))) +############################################################################# +# Wrap Boundary Condition +############################################################################# """ Wrap boundary condition. Non-local accesses wrap around to the other side of the array. """ struct Wrap end + boundary_has_transition(::Wrap) = true + boundary_transition(::Wrap, idx, size) = CartesianIndex(ntuple(i -> mod1(idx[i], size[i]), length(size))) -load_boundary_region(::Wrap, arr, region_code, neigh_dist, boundary_dims) = load_neighbor_region(arr, region_code, neigh_dist) + +load_boundary_region(::Wrap, arr, region_code, neigh_dist, boundary_dims) = + load_neighbor_region(arr, region_code, neigh_dist) + +function boundary_source_index(::Wrap, arr, rc, nd, idx_d, d) + if rc == -1 + return lastindex(arr, d) - nd + idx_d + elseif rc == +1 + return firstindex(arr, d) + idx_d - 1 + else + return idx_d + end +end + +############################################################################# +# Pad Boundary Condition +############################################################################# """ Pad boundary condition. Non-local accesses are padded with a specified value. @@ -131,7 +143,9 @@ Pad boundary condition. Non-local accesses are padded with a specified value. struct Pad{T} padval::T end + boundary_has_transition(::Pad) = false + function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N # Compute the size of this halo region # For dimensions with code 0, use full array size @@ -143,6 +157,197 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d return move(task_processor(), fill(pad.padval, region_size)) end +# Use edge as source index (value will be overridden by apply_boundary_value) +boundary_source_index(::Pad, arr, rc, nd, idx_d, d) = + rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d) + +# Override with pad value +apply_boundary_value(p::Pad, value, arr, rc, nd, idx_d, src_idx, d) = p.padval + +############################################################################# +# Clamp Boundary Condition +############################################################################# + +""" +Clamp boundary condition. Non-local accesses are clamped to the boundary value. +For example, an array [1,2,3,4] with neighborhood distance 2 would be extended as [1,1,1,2,3,4,4,4]. +""" +struct Clamp end + +boundary_has_transition(::Clamp) = true + +# Clamp to valid chunk indices - we stay at the boundary chunk +boundary_transition(::Clamp, idx, size) = + CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) + +function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N + # Compute the size of this halo region + region_size = ntuple(N) do i + region_code[i] == 0 ? size(arr, i) : get_neigh_dist(neigh_dist, i) + end + + result = similar(arr, region_size) + + for idx in CartesianIndices(result) + # Compute source index for each dimension + src_idx = CartesianIndex(ntuple(N) do i + nd = get_neigh_dist(neigh_dist, i) + if boundary_dims[i] && region_code[i] == -1 + # Low boundary - clamp to first element + firstindex(arr, i) + elseif boundary_dims[i] && region_code[i] == +1 + # High boundary - clamp to last element + lastindex(arr, i) + elseif region_code[i] == -1 + # Not at boundary but loading from low side of neighbor + lastindex(arr, i) - nd + idx[i] + elseif region_code[i] == +1 + # Not at boundary but loading from high side of neighbor + firstindex(arr, i) + idx[i] - 1 + else + # Full extent + idx[i] + end + end) + result[idx] = arr[src_idx] + end + + return move(task_processor(), result) +end + +function boundary_source_index(::Clamp, arr, rc, nd, idx_d, d) + if rc == -1 + return firstindex(arr, d) + elseif rc == +1 + return lastindex(arr, d) + else + return idx_d + end +end + +############################################################################# +# LinearExtrapolate Boundary Condition +############################################################################# + +""" +LinearExtrapolate boundary condition. Non-local accesses are extrapolated linearly +using the slope at the boundary. Only supports arrays with `Real` element types. + +For multi-dimensional arrays, extrapolation is applied along the first out-of-bounds +dimension only (other out-of-bounds dimensions are clamped). +""" +struct LinearExtrapolate end + +boundary_has_transition(::LinearExtrapolate) = true + +# Clamp to valid chunk indices - we stay at the boundary chunk +boundary_transition(::LinearExtrapolate, idx, size) = + CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) + +function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {T<:Real,N} + # Compute the size of this halo region + region_size = ntuple(N) do i + region_code[i] == 0 ? size(arr, i) : get_neigh_dist(neigh_dist, i) + end + + result = similar(arr, region_size) + + for idx in CartesianIndices(result) + # Find the first boundary dimension that needs extrapolation + extrap_dim = 0 + for d in 1:N + if boundary_dims[d] && region_code[d] != 0 + extrap_dim = d + break + end + end + + if extrap_dim == 0 + # No boundary dimensions - normal neighbor access + src_idx = CartesianIndex(ntuple(N) do i + nd = get_neigh_dist(neigh_dist, i) + if region_code[i] == -1 + lastindex(arr, i) - nd + idx[i] + elseif region_code[i] == +1 + firstindex(arr, i) + idx[i] - 1 + else + idx[i] + end + end) + result[idx] = arr[src_idx] + else + # Extrapolate along extrap_dim, clamp other boundary dimensions + nd = get_neigh_dist(neigh_dist, extrap_dim) + + # Compute base index (for other dimensions, clamp if at boundary) + base_idx = ntuple(N) do i + ndi = get_neigh_dist(neigh_dist, i) + if i == extrap_dim + # Will be set for slope computation + region_code[i] == -1 ? firstindex(arr, i) : lastindex(arr, i) + elseif boundary_dims[i] && region_code[i] == -1 + firstindex(arr, i) + elseif boundary_dims[i] && region_code[i] == +1 + lastindex(arr, i) + elseif region_code[i] == -1 + lastindex(arr, i) - ndi + idx[i] + elseif region_code[i] == +1 + firstindex(arr, i) + idx[i] - 1 + else + idx[i] + end + end + + # Compute slope at boundary + if region_code[extrap_dim] == -1 + # Low boundary: slope = arr[2] - arr[1] + idx1 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) : base_idx[i], N) + idx2 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) + 1 : base_idx[i], N) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = -(nd - idx[extrap_dim] + 1) + result[idx] = arr[CartesianIndex(idx1)] + slope * dist + else + # High boundary: slope = arr[end] - arr[end-1] + idx1 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) - 1 : base_idx[i], N) + idx2 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) : base_idx[i], N) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = idx[extrap_dim] + result[idx] = arr[CartesianIndex(idx2)] + slope * dist + end + end + end + + return move(task_processor(), result) +end + +# Use edge as source index (value will be computed by apply_boundary_value) +boundary_source_index(::LinearExtrapolate, arr, rc, nd, idx_d, d) = + rc == -1 ? firstindex(arr, d) : (rc == +1 ? lastindex(arr, d) : idx_d) + +function apply_boundary_value(::LinearExtrapolate, value, arr::AbstractArray{T}, rc, nd, idx_d, src_idx, d) where T<:Real + if rc == -1 + # Low boundary: extrapolate using slope from arr[1] to arr[2] + idx1 = ntuple(i -> i == d ? firstindex(arr, i) : src_idx[i], length(src_idx)) + idx2 = ntuple(i -> i == d ? firstindex(arr, i) + 1 : src_idx[i], length(src_idx)) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = -(nd - idx_d + 1) + return arr[CartesianIndex(idx1)] + slope * dist + elseif rc == +1 + # High boundary: extrapolate using slope from arr[end-1] to arr[end] + idx1 = ntuple(i -> i == d ? lastindex(arr, i) - 1 : src_idx[i], length(src_idx)) + idx2 = ntuple(i -> i == d ? lastindex(arr, i) : src_idx[i], length(src_idx)) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = idx_d + return arr[CartesianIndex(idx2)] + slope * dist + else + return value + end +end + +############################################################################# +# Reflect Boundary Condition +############################################################################# + """ Reflect boundary condition. Non-local accesses are reflected back into the array. If `symm` is true, the reflected values include the nearest center elements. @@ -150,10 +355,13 @@ If `symm` is false, the reflected values do not include the nearest center eleme """ struct Reflect{Symmetric} end Reflect(symm::Bool) = Reflect{symm}() + boundary_has_transition(::Reflect) = true + # Clamp to valid chunk indices - we stay at the boundary chunk boundary_transition(::Reflect, idx, size) = CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) + function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {N, Symm} # Only flip region_code for dimensions that are BOTH: # 1. Non-zero in region_code (we're accessing a neighbor in that dimension) @@ -213,6 +421,186 @@ function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int}, return region end +function boundary_source_index(::Reflect{Symm}, arr, rc, nd, idx_d, d) where Symm + skip = Symm ? 0 : 1 + if rc == -1 + # Reflecting from low boundary - source from start of array + return firstindex(arr, d) + skip + (nd - idx_d) + elseif rc == +1 + # Reflecting from high boundary - source from end of array + return lastindex(arr, d) - skip - (idx_d - 1) + else + return idx_d + end +end + +############################################################################# +# Mixed Boundary Conditions (Tuple of Boundary Conditions) +############################################################################# + +# Mixed boundary condition support: check if any dimension has a transition boundary condition +boundary_has_transition(boundary::Tuple) = any(boundary_has_transition, boundary) + +# Mixed boundary condition support: apply per-dimension transitions +function boundary_transition(boundary::Tuple, idx, size) + CartesianIndex(ntuple(length(size)) do i + dim_boundary = get_boundary(boundary, i) + if boundary_has_transition(dim_boundary) + # Apply the boundary condition's transition for this dimension only + single_idx = CartesianIndex(idx[i]) + single_size = (size[i],) + boundary_transition(dim_boundary, single_idx, single_size)[1] + else + # No transition - clamp to valid range (stay at current chunk) + clamp(idx[i], 1, size[i]) + end + end) +end + +# Internal helper: compute source index for a single dimension based on its boundary condition +function compute_source_index_for_dim(dim_boundary, arr, region_code, neigh_dist, boundary_dims, idx, d) + N = length(region_code) + nd = get_neigh_dist(neigh_dist, d) + + if !boundary_dims[d] + # Not at boundary - normal neighbor region access + if region_code[d] == -1 + return lastindex(arr, d) - nd + idx[d] + elseif region_code[d] == +1 + return firstindex(arr, d) + idx[d] - 1 + else + return idx[d] + end + end + + # At boundary - apply boundary condition-specific logic + return boundary_source_index(dim_boundary, arr, region_code[d], nd, idx[d], d) +end + +# Internal helper: compute the final value, handling special boundary conditions like Pad and LinearExtrapolate +function compute_boundary_value(boundary, arr, region_code, neigh_dist, boundary_dims, idx, src_idx) + N = length(region_code) + base_value = arr[CartesianIndex(src_idx)] + + # Check if any boundary dimension has a special boundary condition that overrides the value + for d in 1:N + if boundary_dims[d] && region_code[d] != 0 + dim_boundary = get_boundary(boundary, d) + base_value = apply_boundary_value(dim_boundary, base_value, arr, region_code[d], get_neigh_dist(neigh_dist, d), idx[d], src_idx, d) + end + end + + return base_value +end + +""" +Mixed boundary conditions. When a Tuple of boundary conditions is provided, each dimension uses its own boundary condition. +""" +function load_boundary_region(boundary::Tuple, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N + # Compute the size of this halo region + region_size = ntuple(N) do i + region_code[i] == 0 ? size(arr, i) : get_neigh_dist(neigh_dist, i) + end + + result = similar(arr, region_size) + + for idx in CartesianIndices(result) + # For each element, compute its value based on per-dimension boundary conditions + # Start by finding the source index in the array + src_idx = ntuple(N) do d + dim_boundary = get_boundary(boundary, d) + compute_source_index_for_dim(dim_boundary, arr, region_code, neigh_dist, boundary_dims, idx, d) + end + + # Compute the value using per-dimension logic + value = compute_boundary_value(boundary, arr, region_code, neigh_dist, boundary_dims, idx, src_idx) + result[idx] = value + end + + return move(task_processor(), result) +end + +############################################################################# +# Chunk Selection and Halo Building +############################################################################# + +function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary) + validate_neigh_dist(neigh_dist) + + N = ndims(chunks) + # FIXME: Depends on neigh_dist and chunk size + chunk_dist = 1 + + # Get the center + accesses = Any[chunks[idx]] + + # Iterate over all 3^N - 1 halo regions (excluding center) + # Each region is identified by a code tuple where each element is -1, 0, or +1 + for i in 0:(3^N - 1) + region_code = ntuple(N) do d + ((i ÷ 3^(d-1)) % 3) - 1 # Maps 0,1,2 -> -1,0,+1 + end + all(==(0), region_code) && continue # Skip center + + # Compute the chunk offset for this region + # For each dimension: -1 means go to previous chunk, +1 means go to next chunk, 0 means same chunk + chunk_offset = CartesianIndex(ntuple(N) do d + region_code[d] * chunk_dist + end) + new_idx = idx + chunk_offset + + if is_past_boundary(size(chunks), new_idx) + # Compute which dimensions are actually past boundary + boundary_dims = ntuple(N) do d + new_idx[d] < 1 || new_idx[d] > size(chunks)[d] + end + if boundary_has_transition(boundary) + new_idx = boundary_transition(boundary, new_idx, size(chunks)) + else + new_idx = idx + end + chunk = chunks[new_idx] + push!(accesses, Dagger.@spawn load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims)) + else + chunk = chunks[new_idx] + push!(accesses, Dagger.@spawn load_neighbor_region(chunk, region_code, neigh_dist)) + end + end + + @assert length(accesses) == 3^N "Accesses mismatch: expected $(3^N), got $(length(accesses))" + return accesses +end + +function build_halo(neigh_dist, boundary, center, all_halos...) + N = ndims(center) + expected_halos = 3^N - 1 + @assert length(all_halos) == expected_halos "Halo mismatch: N=$N expected $expected_halos halos, got $(length(all_halos))" + return HaloArray(center, (all_halos...,), ntuple(i->get_neigh_dist(neigh_dist, i), N)) +end + +function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N} + start_idx = idx - CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr))) + stop_idx = idx + CartesianIndex(ntuple(i->arr.halo_width[i], ndims(arr))) + return @view arr[start_idx:stop_idx] +end + +function inner_stencil!(f, output, read_vars) + processor = task_processor() + inner_stencil_proc!(processor, f, output, read_vars) +end + +# Non-KA (for CPUs) +function inner_stencil_proc!(::ThreadProc, f, output, read_vars) + for idx in CartesianIndices(output) + f(idx, output, read_vars) + end + return +end + +############################################################################# +# @stencil Macro +############################################################################# + """ @stencil begin body end diff --git a/test/array/stencil.jl b/test/array/stencil.jl index cfd2896f3..a707e6157 100644 --- a/test/array/stencil.jl +++ b/test/array/stencil.jl @@ -1,4 +1,4 @@ -import Dagger: @stencil, Wrap, Pad, Reflect +import Dagger: @stencil, Wrap, Pad, Reflect, Clamp, LinearExtrapolate function test_stencil() @testset "Simple assignment" begin @@ -67,6 +67,175 @@ function test_stencil() @test collect(B) == expected_B_pad end + @testset "Clamp boundary" begin + # Test clamping to boundary values + # For A = [1, 2, 3, 4] with Clamp(): + # idx=0 → 1, idx=-1 → 1 (clamp to first element) + # idx=5 → 4, idx=6 → 4 (clamp to last element) + A = DArray([1, 2, 3, 4], Blocks(2)) + B = zeros(Blocks(2), Int, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, Clamp())) + end + end + # B[1]: neighbors at indices 0, 1, 2 -> clamped 0 becomes 1, so [1, 1, 2] = 4 + # B[2]: neighbors at indices 1, 2, 3 -> [1, 2, 3] = 6 + # B[3]: neighbors at indices 2, 3, 4 -> [2, 3, 4] = 9 + # B[4]: neighbors at indices 3, 4, 5 -> clamped 5 becomes 4, so [3, 4, 4] = 11 + expected_B_clamp = [4, 6, 9, 11] + @test collect(B) == expected_B_clamp + end + + @testset "Clamp boundary 2D" begin + # Test 2D clamping with a gradient pattern + A = DArray(reshape(1:16, 4, 4), Blocks(2, 2)) + B = zeros(Blocks(2, 2), Int, 4, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, Clamp())) + end + end + A_collected = collect(A) + expected_B_clamp = zeros(Int, 4, 4) + for i in 1:4, j in 1:4 + sum_val = 0 + for di in -1:1, dj in -1:1 + ni, nj = i + di, j + dj + # Apply clamp logic + ni = clamp(ni, 1, 4) + nj = clamp(nj, 1, 4) + sum_val += A_collected[ni, nj] + end + expected_B_clamp[i, j] = sum_val + end + @test collect(B) == expected_B_clamp + end + + @testset "LinearExtrapolate boundary" begin + # Test linear extrapolation using slope at boundary + # For A = [2.0, 4.0, 6.0, 8.0] with LinearExtrapolate(): + # slope at low boundary = 4.0 - 2.0 = 2.0 + # slope at high boundary = 8.0 - 6.0 = 2.0 + # idx=0 → 2.0 + 2.0*(-1) = 0.0 + # idx=5 → 8.0 + 2.0*(1) = 10.0 + A = DArray([2.0, 4.0, 6.0, 8.0], Blocks(2)) + B = zeros(Blocks(2), Float64, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, LinearExtrapolate())) + end + end + # B[1]: neighbors at indices 0, 1, 2 -> extrapolated 0 becomes 0.0, so [0.0, 2.0, 4.0] = 6.0 + # B[2]: neighbors at indices 1, 2, 3 -> [2.0, 4.0, 6.0] = 12.0 + # B[3]: neighbors at indices 2, 3, 4 -> [4.0, 6.0, 8.0] = 18.0 + # B[4]: neighbors at indices 3, 4, 5 -> extrapolated 5 becomes 10.0, so [6.0, 8.0, 10.0] = 24.0 + expected_B_extrap = [6.0, 12.0, 18.0, 24.0] + @test collect(B) ≈ expected_B_extrap + end + + @testset "LinearExtrapolate boundary 2D" begin + # Test 2D linear extrapolation with a gradient pattern + A = DArray(Float64.(reshape(1:16, 4, 4)), Blocks(2, 2)) + B = zeros(Blocks(2, 2), Float64, 4, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, LinearExtrapolate())) + end + end + A_collected = collect(A) + expected_B_extrap = zeros(Float64, 4, 4) + for i in 1:4, j in 1:4 + sum_val = 0.0 + for di in -1:1, dj in -1:1 + ni, nj = i + di, j + dj + val = 0.0 + # Apply linear extrapolation logic for each dimension + if ni < 1 + # Low boundary in dim 1: extrapolate using slope from A[1,:] to A[2,:] + base_nj = clamp(nj, 1, 4) + slope = A_collected[2, base_nj] - A_collected[1, base_nj] + val = A_collected[1, base_nj] + slope * (ni - 1) + elseif ni > 4 + # High boundary in dim 1: extrapolate using slope from A[3,:] to A[4,:] + base_nj = clamp(nj, 1, 4) + slope = A_collected[4, base_nj] - A_collected[3, base_nj] + val = A_collected[4, base_nj] + slope * (ni - 4) + elseif nj < 1 + # Low boundary in dim 2: extrapolate using slope from A[:,1] to A[:,2] + slope = A_collected[ni, 2] - A_collected[ni, 1] + val = A_collected[ni, 1] + slope * (nj - 1) + elseif nj > 4 + # High boundary in dim 2: extrapolate using slope from A[:,3] to A[:,4] + slope = A_collected[ni, 4] - A_collected[ni, 3] + val = A_collected[ni, 4] + slope * (nj - 4) + else + val = A_collected[ni, nj] + end + sum_val += val + end + expected_B_extrap[i, j] = sum_val + end + @test collect(B) ≈ expected_B_extrap + end + + @testset "Mixed boundary conditions" begin + # Test different BCs per dimension using a Tuple + # Use Wrap in dimension 1 and Pad(0) in dimension 2 + A = DArray(reshape(1:16, 4, 4), Blocks(2, 2)) + B = zeros(Blocks(2, 2), Int, 4, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, (Wrap(), Pad(0)))) + end + end + A_collected = collect(A) + expected_B_mixed = zeros(Int, 4, 4) + for i in 1:4, j in 1:4 + sum_val = 0 + for di in -1:1, dj in -1:1 + # Dim 1: Wrap + ni = mod1(i + di, 4) + # Dim 2: Pad(0) + nj = j + dj + if nj < 1 || nj > 4 + # Padded with 0 + sum_val += 0 + else + sum_val += A_collected[ni, nj] + end + end + expected_B_mixed[i, j] = sum_val + end + @test collect(B) == expected_B_mixed + end + + @testset "Mixed boundary conditions (Clamp, Reflect)" begin + # Test Clamp in dimension 1 and Reflect(true) in dimension 2 + A = DArray(reshape(1:16, 4, 4), Blocks(2, 2)) + B = zeros(Blocks(2, 2), Int, 4, 4) + Dagger.spawn_datadeps() do + @stencil begin + B[idx] = sum(@neighbors(A[idx], 1, (Clamp(), Reflect(true)))) + end + end + A_collected = collect(A) + expected_B_mixed = zeros(Int, 4, 4) + for i in 1:4, j in 1:4 + sum_val = 0 + for di in -1:1, dj in -1:1 + # Dim 1: Clamp + ni = clamp(i + di, 1, 4) + # Dim 2: Reflect(true) - symmetric + nj = j + dj + nj = nj < 1 ? 1 - nj : (nj > 4 ? 2*4 + 1 - nj : nj) + sum_val += A_collected[ni, nj] + end + expected_B_mixed[i, j] = sum_val + end + @test collect(B) == expected_B_mixed + end + @testset "Reflect boundary (symmetric)" begin # Test symmetric reflection (edge element IS included/repeated) # For A = [1, 2, 3, 4] with Reflect(true): From 87bcb175b80702b3c90b521541afb215e1503cce Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 6 Feb 2026 10:25:34 -0500 Subject: [PATCH 2/4] fixup! stencils: Support mixed boundaries, add Clamp and LinearExtrapolate --- src/Dagger.jl | 4 +- src/stencil.jl | 182 ++++++++++++++++++++++++++----------------------- 2 files changed, 100 insertions(+), 86 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index ff6161aa4..edb93a7c9 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -46,6 +46,8 @@ end import MacroTools: @capture, prewalk +import KernelAbstractions, Adapt + include("lib/util.jl") include("utils/dagdebug.jl") @@ -122,8 +124,6 @@ include("array/mul.jl") include("array/cholesky.jl") include("array/lu.jl") -import KernelAbstractions, Adapt - # GPU include("gpu.jl") diff --git a/src/stencil.jl b/src/stencil.jl index 40154213c..7015b475b 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -180,6 +180,34 @@ boundary_has_transition(::Clamp) = true boundary_transition(::Clamp, idx, size) = CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) +KernelAbstractions.@kernel function load_boundary_region_kernel(::Clamp, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N + raw_idx = KernelAbstractions.@index(Global) + + # Convert linear index to Cartesian index + idx = CartesianIndices(result)[raw_idx] + + # Compute source index for each dimension + src_idx = CartesianIndex(ntuple(N) do i + nd = get_neigh_dist(neigh_dist, i) + if boundary_dims[i] && region_code[i] == -1 + # Low boundary - clamp to first element + firstindex(arr, i) + elseif boundary_dims[i] && region_code[i] == +1 + # High boundary - clamp to last element + lastindex(arr, i) + elseif region_code[i] == -1 + # Not at boundary but loading from low side of neighbor + lastindex(arr, i) - nd + idx[i] + elseif region_code[i] == +1 + # Not at boundary but loading from high side of neighbor + firstindex(arr, i) + idx[i] - 1 + else + # Full extent + idx[i] + end + end) + result[idx] = arr[src_idx] +end function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N # Compute the size of this halo region region_size = ntuple(N) do i @@ -188,29 +216,7 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di result = similar(arr, region_size) - for idx in CartesianIndices(result) - # Compute source index for each dimension - src_idx = CartesianIndex(ntuple(N) do i - nd = get_neigh_dist(neigh_dist, i) - if boundary_dims[i] && region_code[i] == -1 - # Low boundary - clamp to first element - firstindex(arr, i) - elseif boundary_dims[i] && region_code[i] == +1 - # High boundary - clamp to last element - lastindex(arr, i) - elseif region_code[i] == -1 - # Not at boundary but loading from low side of neighbor - lastindex(arr, i) - nd + idx[i] - elseif region_code[i] == +1 - # Not at boundary but loading from high side of neighbor - firstindex(arr, i) + idx[i] - 1 - else - # Full extent - idx[i] - end - end) - result[idx] = arr[src_idx] - end + Kernel(load_boundary_region_kernel)(Clamp(), result, arr, region_code, neigh_dist, boundary_dims; ndrange=size(result)) return move(task_processor(), result) end @@ -244,6 +250,66 @@ boundary_has_transition(::LinearExtrapolate) = true boundary_transition(::LinearExtrapolate, idx, size) = CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) +KernelAbstractions.@kernel function load_boundary_region_kernel(::LinearExtrapolate, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}, ::Val{extrap_dim}, ::Val{nd}) where {N,extrap_dim,nd} + raw_idx = KernelAbstractions.@index(Global) + + # Convert linear index to Cartesian index + idx = CartesianIndices(result)[raw_idx] + + if extrap_dim == 0 + # No boundary dimensions - normal neighbor access + src_idx = CartesianIndex(ntuple(Val(N)) do i + ndi = get_neigh_dist(neigh_dist, i)::Int + if region_code[i] == -1 + lastindex(arr, i) - ndi + idx[i] + elseif region_code[i] == +1 + firstindex(arr, i) + idx[i] - 1 + else + idx[i] + end + end) + result[idx] = arr[src_idx] + else + # Extrapolate along extrap_dim, clamp other boundary dimensions + #nd = get_neigh_dist(neigh_dist, extrap_dim)::Int + + # Compute base index (for other dimensions, clamp if at boundary) + base_idx = ntuple(Val(N)) do i + ndi = get_neigh_dist(neigh_dist, i) + if i == extrap_dim + # Will be set for slope computation + region_code[i] == -1 ? firstindex(arr, i) : lastindex(arr, i) + elseif boundary_dims[i] && region_code[i] == -1 + firstindex(arr, i) + elseif boundary_dims[i] && region_code[i] == +1 + lastindex(arr, i) + elseif region_code[i] == -1 + lastindex(arr, i) - ndi + idx[i] + elseif region_code[i] == +1 + firstindex(arr, i) + idx[i] - 1 + else + idx[i] + end + end + + # Compute slope at boundary + if region_code[extrap_dim] == -1 + # Low boundary: slope = arr[2] - arr[1] + idx1 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) : base_idx[i], Val(N)) + idx2 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) + 1 : base_idx[i], Val(N)) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = -(nd - idx[extrap_dim] + 1) + result[idx] = arr[CartesianIndex(idx1)] + slope * dist + else + # High boundary: slope = arr[end] - arr[end-1] + idx1 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) - 1 : base_idx[i], Val(N)) + idx2 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) : base_idx[i], Val(N)) + slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] + dist = idx[extrap_dim] + result[idx] = arr[CartesianIndex(idx2)] + slope * dist + end + end +end function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {T<:Real,N} # Compute the size of this halo region region_size = ntuple(N) do i @@ -252,70 +318,18 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region result = similar(arr, region_size) - for idx in CartesianIndices(result) - # Find the first boundary dimension that needs extrapolation - extrap_dim = 0 - for d in 1:N - if boundary_dims[d] && region_code[d] != 0 - extrap_dim = d - break - end + # Find the first boundary dimension that needs extrapolation + extrap_dim = 0 + for d in 1:N + if boundary_dims[d] && region_code[d] != 0 + extrap_dim = d + break end + end - if extrap_dim == 0 - # No boundary dimensions - normal neighbor access - src_idx = CartesianIndex(ntuple(N) do i - nd = get_neigh_dist(neigh_dist, i) - if region_code[i] == -1 - lastindex(arr, i) - nd + idx[i] - elseif region_code[i] == +1 - firstindex(arr, i) + idx[i] - 1 - else - idx[i] - end - end) - result[idx] = arr[src_idx] - else - # Extrapolate along extrap_dim, clamp other boundary dimensions - nd = get_neigh_dist(neigh_dist, extrap_dim) - - # Compute base index (for other dimensions, clamp if at boundary) - base_idx = ntuple(N) do i - ndi = get_neigh_dist(neigh_dist, i) - if i == extrap_dim - # Will be set for slope computation - region_code[i] == -1 ? firstindex(arr, i) : lastindex(arr, i) - elseif boundary_dims[i] && region_code[i] == -1 - firstindex(arr, i) - elseif boundary_dims[i] && region_code[i] == +1 - lastindex(arr, i) - elseif region_code[i] == -1 - lastindex(arr, i) - ndi + idx[i] - elseif region_code[i] == +1 - firstindex(arr, i) + idx[i] - 1 - else - idx[i] - end - end + nd = get_neigh_dist(neigh_dist, extrap_dim) - # Compute slope at boundary - if region_code[extrap_dim] == -1 - # Low boundary: slope = arr[2] - arr[1] - idx1 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) : base_idx[i], N) - idx2 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) + 1 : base_idx[i], N) - slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] - dist = -(nd - idx[extrap_dim] + 1) - result[idx] = arr[CartesianIndex(idx1)] + slope * dist - else - # High boundary: slope = arr[end] - arr[end-1] - idx1 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) - 1 : base_idx[i], N) - idx2 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) : base_idx[i], N) - slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)] - dist = idx[extrap_dim] - result[idx] = arr[CartesianIndex(idx2)] + slope * dist - end - end - end + Kernel(load_boundary_region_kernel)(LinearExtrapolate(), result, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=size(result)) return move(task_processor(), result) end From 44ca6177abe9b057da3af3afa3af1866e46f3a7c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 6 Feb 2026 11:54:44 -0500 Subject: [PATCH 3/4] fixup! fixup! stencils: Support mixed boundaries, add Clamp and LinearExtrapolate --- src/Dagger.jl | 4 +++- src/stencil.jl | 16 +++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index edb93a7c9..740d941a2 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -46,7 +46,9 @@ end import MacroTools: @capture, prewalk -import KernelAbstractions, Adapt +import KernelAbstractions +import KernelAbstractions: @kernel, @index +import Adapt include("lib/util.jl") include("utils/dagdebug.jl") diff --git a/src/stencil.jl b/src/stencil.jl index 7015b475b..767676096 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -180,8 +180,8 @@ boundary_has_transition(::Clamp) = true boundary_transition(::Clamp, idx, size) = CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) -KernelAbstractions.@kernel function load_boundary_region_kernel(::Clamp, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N - raw_idx = KernelAbstractions.@index(Global) +@kernel function load_boundary_region_kernel(::Clamp, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N + raw_idx = @index(Global, Linear) # Convert linear index to Cartesian index idx = CartesianIndices(result)[raw_idx] @@ -216,7 +216,7 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di result = similar(arr, region_size) - Kernel(load_boundary_region_kernel)(Clamp(), result, arr, region_code, neigh_dist, boundary_dims; ndrange=size(result)) + Kernel(load_boundary_region_kernel)(Clamp(), result, arr, region_code, neigh_dist, boundary_dims; ndrange=length(result)) return move(task_processor(), result) end @@ -250,8 +250,8 @@ boundary_has_transition(::LinearExtrapolate) = true boundary_transition(::LinearExtrapolate, idx, size) = CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size))) -KernelAbstractions.@kernel function load_boundary_region_kernel(::LinearExtrapolate, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}, ::Val{extrap_dim}, ::Val{nd}) where {N,extrap_dim,nd} - raw_idx = KernelAbstractions.@index(Global) +@kernel function load_boundary_region_kernel(::LinearExtrapolate, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}, ::Val{extrap_dim}, ::Val{nd}) where {N,extrap_dim,nd} + raw_idx = @index(Global, Linear) # Convert linear index to Cartesian index idx = CartesianIndices(result)[raw_idx] @@ -270,9 +270,6 @@ KernelAbstractions.@kernel function load_boundary_region_kernel(::LinearExtrapol end) result[idx] = arr[src_idx] else - # Extrapolate along extrap_dim, clamp other boundary dimensions - #nd = get_neigh_dist(neigh_dist, extrap_dim)::Int - # Compute base index (for other dimensions, clamp if at boundary) base_idx = ntuple(Val(N)) do i ndi = get_neigh_dist(neigh_dist, i) @@ -327,9 +324,10 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region end end + # Extrapolate along extrap_dim, clamp other boundary dimensions nd = get_neigh_dist(neigh_dist, extrap_dim) - Kernel(load_boundary_region_kernel)(LinearExtrapolate(), result, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=size(result)) + Kernel(load_boundary_region_kernel)(LinearExtrapolate(), result, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=length(result)) return move(task_processor(), result) end From 49dbb606f41ef430910ba5ffb20bf6430230da33 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 6 Feb 2026 13:31:48 -0500 Subject: [PATCH 4/4] fixup! fixup! fixup! stencils: Support mixed boundaries, add Clamp and LinearExtrapolate --- src/stencil.jl | 55 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/src/stencil.jl b/src/stencil.jl index 767676096..555add31d 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -505,6 +505,48 @@ function compute_boundary_value(boundary, arr, region_code, neigh_dist, boundary return base_value end +# GPU-compatible helper: recursively apply boundary value transformations for mixed boundaries. +# Uses Val{d} to ensure boundary[d] is resolved at compile time (avoiding type instability +# from indexing a heterogeneous Tuple with a runtime variable). +@inline function _fold_boundary_value(boundary::Tuple, base_value, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}, idx, src_idx, ::Val{d}) where {N, d} + if boundary_dims[d] && region_code[d] != 0 + dim_boundary = boundary[d] + base_value = apply_boundary_value(dim_boundary, base_value, arr, region_code[d], get_neigh_dist(neigh_dist, d), idx[d], src_idx, d) + end + if d < N + return _fold_boundary_value(boundary, base_value, arr, region_code, neigh_dist, boundary_dims, idx, src_idx, Val(d + 1)) + end + return base_value +end + +@kernel function load_boundary_region_kernel(boundary::B, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {B<:Tuple, N} + raw_idx = @index(Global, Linear) + + # Convert linear index to Cartesian index + idx = CartesianIndices(result)[raw_idx] + + # Compute source index for each dimension + src_idx = ntuple(Val(N)) do d + dim_boundary = boundary[d] + nd = get_neigh_dist(neigh_dist, d) + if !boundary_dims[d] + if region_code[d] == -1 + lastindex(arr, d) - nd + idx[d] + elseif region_code[d] == +1 + firstindex(arr, d) + idx[d] - 1 + else + idx[d] + end + else + boundary_source_index(dim_boundary, arr, region_code[d], nd, idx[d], d) + end + end + + # Get base value and apply boundary transformations dimension by dimension + base_value = arr[CartesianIndex(src_idx)] + result[idx] = _fold_boundary_value(boundary, base_value, arr, region_code, neigh_dist, boundary_dims, idx, src_idx, Val(1)) +end + """ Mixed boundary conditions. When a Tuple of boundary conditions is provided, each dimension uses its own boundary condition. """ @@ -516,18 +558,7 @@ function load_boundary_region(boundary::Tuple, arr, region_code::NTuple{N,Int}, result = similar(arr, region_size) - for idx in CartesianIndices(result) - # For each element, compute its value based on per-dimension boundary conditions - # Start by finding the source index in the array - src_idx = ntuple(N) do d - dim_boundary = get_boundary(boundary, d) - compute_source_index_for_dim(dim_boundary, arr, region_code, neigh_dist, boundary_dims, idx, d) - end - - # Compute the value using per-dimension logic - value = compute_boundary_value(boundary, arr, region_code, neigh_dist, boundary_dims, idx, src_idx) - result[idx] = value - end + Kernel(load_boundary_region_kernel)(boundary, result, arr, region_code, neigh_dist, boundary_dims; ndrange=length(result)) return move(task_processor(), result) end