Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,12 +608,31 @@ function logs_annotate!(ctx::Context, A::DArray, name::Union{String,Symbol})
end
end

# TODO: Allow `f` to return proc
mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle)))
function mapchunks(f, d::DArray{T,N,F}) where {T,N,F}
chunks = map(d.chunks) do chunk
owner = get_parent(chunk.processor).pid
remotecall_fetch(mapchunk, owner, f, chunk)
"""
    _unwrap_type(T)

Reduce the type `T` to the element type it stands for:

- an `AbstractArray` subtype yields its `eltype`;
- a `Union` yields the `Base.promote_type` of its unwrapped members;
- `Union{}` yields `Any`;
- every other type is returned unchanged.
"""
function _unwrap_type(T::Type)
    return T
end

function _unwrap_type(T::Type{<:AbstractArray})
    return eltype(T)
end

function _unwrap_type(T::Union)
    # Unwrap each member of the union first, then promote the results together.
    members = Base.uniontypes(T)
    return mapreduce(_unwrap_type, Base.promote_type, members)
end

function _unwrap_type(::Type{Union{}})
    return Any
end

"""
    _mapchunks_eltype(chunks, f, d::DArray{T,N}) where {T,N}

Determine the element type for the array built from `chunks`, the chunks obtained
by mapping `f` over the chunks of `d`.

When `chunks` is empty, the return type of `f` applied to an `AbstractArray{T,N}`
is queried; if that is `Any`, the input element type `T` is kept, otherwise the
unwrapped return type is used. When `chunks` is non-empty, the unwrapped
`chunktype` of every chunk is promoted to a single type, and the result is `Any`
as soon as any chunk unwraps to `Any`.
"""
function _mapchunks_eltype(chunks, f, d::DArray{T,N}) where {T,N}
    if isempty(chunks)
        # Nothing to inspect, so ask what `f` would return for an input-like array.
        RT = Base._return_type(f, Tuple{AbstractArray{T,N}})
        return RT === Any ? T : _unwrap_type(RT)
    end

    # promote types across all chunks
    # NOTE(review): assumes `chunktype` reports the type of the value a chunk
    # holds (or will hold) - confirm against its definition.
    types = [_unwrap_type(chunktype(c)) for c in chunks]
    return any(==(Any), types) ? Any : reduce(Base.promote_type, types)
end

# Options passed to `Dagger.spawn` for one chunk-like input. Every method sets
# `meta=false`; a `scope` is added only when a process id can be read off the input.
_spawn_options(::Any) = Options(meta=false)

# A `Chunk` gets a `ProcessScope` built from the pid of its processor's parent.
function _spawn_options(c::Chunk)
    owner = get_parent(c.processor).pid
    return Options(scope=ProcessScope(owner), meta=false)
end

# A `Thunk` gets a `ProcessScope` built from the first entry of its affinity, if set.
function _spawn_options(t::Thunk)
    if isnothing(t.affinity)
        return Options(meta=false)
    end
    return Options(scope=ProcessScope(t.affinity.first), meta=false)
end

"""
    mapchunks(f, d::DArray)

Apply `f` to every chunk of `d` and return a new `DArray` made of the results.

One task is spawned per chunk with `Dagger.spawn`, using the options chosen by
`_spawn_options` for that chunk. The result reuses the domain, subdomains,
partitioning and concatenation function of `d`; its element type is computed by
`_mapchunks_eltype`.
"""
function mapchunks(f, d::DArray)
    mapped = map(d.chunks) do chunk
        Dagger.spawn(f, _spawn_options(chunk), chunk)
    end
    result_eltype = _mapchunks_eltype(mapped, f, d)

    return DArray(
        result_eltype, d.domain, d.subdomains, mapped, d.partitioning, d.concat
    )
end
35 changes: 35 additions & 0 deletions test/array/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,38 @@ using MemPool
@test (aff[1]).pid in procs()
@test aff[2] == sizeof(Int)*10
end

# Exercises `mapchunks` on a 4x4 array split into 2x2 blocks, plus a few
# hand-built arrays covering special cases.
@testset "mapchunks" begin
# element-type change: mapping to Float32 must be reflected in the result type
A = DArray(reshape(1:16, 4, 4), Blocks(2, 2))
B = mapchunks(x -> Float32.(x), A)
@test B isa DArray{Float32,2}
@test collect(B) == Float32.(collect(A))

# `f` yields all-`true` only when it is handed an AbstractArray, so this checks
# what kind of value `f` receives for each chunk
D = mapchunks(x -> (x isa AbstractArray ? fill(true, size(x)) : fill(false, size(x))), A)
@test all(collect(D))

# chunks that are the results of `Dagger.spawn` rather than the original chunks
task_chunks = map(c -> Dagger.spawn(identity, c), chunks(A))
A_tasks = DArray(eltype(A), domain(A), domainchunks(A), task_chunks, A.partitioning, A.concat)
C = mapchunks(x -> x .+ 1, A_tasks)
@test collect(C) == collect(A) .+ 1

# heterogeneous chunk types: one Int chunk and one Float64 chunk are expected
# to give a Float64 result
domain_h = ArrayDomain(1:4)
subdomains_h = partition(Blocks(2), domain_h)
chunks_h = reshape(Any[Dagger.tochunk([1, 2]),Dagger.tochunk([3.0, 4.0])], size(subdomains_h))
A_hetero = DArray(Any, domain_h, subdomains_h, chunks_h, Blocks(2), cat)
B_hetero = mapchunks(identity, A_hetero)
@test eltype(B_hetero) == Float64
@test collect(B_hetero) == Float64.([1, 2, 3, 4])

# zero-length input: size and element type of the input must be kept
A0 = DArray{Int}(undef, Blocks(1), (0,))
B0 = mapchunks(x -> x .+ 1, A0)
@test size(B0) == (0,)
@test eltype(B0) == Int

# scalar return: each 2x2 block is reduced to its sum, giving one value per block
S = mapchunks(sum, A)
@test eltype(S) == Int
expected = [sum(A[1:2, 1:2]) sum(A[1:2, 3:4]); sum(A[3:4, 1:2]) sum(A[3:4, 3:4])]
@test collect(S) == expected
end
2 changes: 1 addition & 1 deletion test/imports.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LinearAlgebra, SparseArrays, Random, SharedArrays
import Dagger: DArray, chunks, domainchunks, treereduce_nd
import Dagger: DArray, chunks, domainchunks, treereduce_nd, mapchunks
import Distributed: myid, procs
import Statistics: mean, var, std
import OnlineStats
Loading