diff --git a/src/array/darray.jl b/src/array/darray.jl index 20722b8ed..8d8b5b663 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -608,12 +608,31 @@ function logs_annotate!(ctx::Context, A::DArray, name::Union{String,Symbol}) end end -# TODO: Allow `f` to return proc -mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle))) -function mapchunks(f, d::DArray{T,N,F}) where {T,N,F} - chunks = map(d.chunks) do chunk - owner = get_parent(chunk.processor).pid - remotecall_fetch(mapchunk, owner, f, chunk) +_unwrap_type(::Type{Union{}}) = Any +_unwrap_type(T::Union) = mapreduce(_unwrap_type, Base.promote_type, Base.uniontypes(T)) +_unwrap_type(T::Type{<:AbstractArray}) = eltype(T) +_unwrap_type(T::Type) = T + +function _mapchunks_eltype(chunks, f, d::DArray{T,N}) where {T,N} + if isempty(chunks) + RT = Base._return_type(f, Tuple{AbstractArray{T,N}}) + return RT === Any ? T : _unwrap_type(RT) end - DArray{T,N,F}(d.domain, d.subdomains, chunks, d.concat) + + # promote types across all chunks + types = [_unwrap_type(chunktype(c)) for c in chunks] + return any(==(Any), types) ? Any : mapreduce(identity, Base.promote_type, types) +end + +_spawn_options(::Any) = Options(meta=false) +_spawn_options(c::Chunk) = Options(scope=ProcessScope(get_parent(c.processor).pid), meta=false) +_spawn_options(t::Thunk) = !isnothing(t.affinity) ? + Options(scope=ProcessScope(t.affinity.first), meta=false) : + Options(meta=false) + +function mapchunks(f, d::DArray) + new_chunks = map(c -> Dagger.spawn(f, _spawn_options(c), c), d.chunks) + new_eltype = _mapchunks_eltype(new_chunks, f, d) + + return DArray(new_eltype, d.domain, d.subdomains, new_chunks, d.partitioning, d.concat) end diff --git a/test/array/core.jl b/test/array/core.jl index 0cdcb0456..c399c68c5 100644 --- a/test/array/core.jl +++ b/test/array/core.jl @@ -128,3 +128,38 @@ using MemPool @test (aff[1]).pid in procs() @test aff[2] == sizeof(Int)*10 end + +@testset "mapchunks" begin + A = DArray(reshape(1:16, 4, 4), Blocks(2, 2)) + B = mapchunks(x -> Float32.(x), A) + @test B isa DArray{Float32,2} + @test collect(B) == Float32.(collect(A)) + + D = mapchunks(x -> (x isa AbstractArray ? fill(true, size(x)) : fill(false, size(x))), A) + @test all(collect(D)) + + task_chunks = map(c -> Dagger.spawn(identity, c), chunks(A)) + A_tasks = DArray(eltype(A), domain(A), domainchunks(A), task_chunks, A.partitioning, A.concat) + C = mapchunks(x -> x .+ 1, A_tasks) + @test collect(C) == collect(A) .+ 1 + + # heterogeneous chunk types + domain_h = ArrayDomain(1:4) + subdomains_h = partition(Blocks(2), domain_h) + chunks_h = reshape(Any[Dagger.tochunk([1, 2]),Dagger.tochunk([3.0, 4.0])], size(subdomains_h)) + A_hetero = DArray(Any, domain_h, subdomains_h, chunks_h, Blocks(2), cat) + B_hetero = mapchunks(identity, A_hetero) + @test eltype(B_hetero) == Float64 + @test collect(B_hetero) == Float64.([1, 2, 3, 4]) + + A0 = DArray{Int}(undef, Blocks(1), (0,)) + B0 = mapchunks(x -> x .+ 1, A0) + @test size(B0) == (0,) + @test eltype(B0) == Int + + # scalar return + S = mapchunks(sum, A) + @test eltype(S) == Int + expected = [sum(A[1:2, 1:2]) sum(A[1:2, 3:4]); sum(A[3:4, 1:2]) sum(A[3:4, 3:4])] + @test collect(S) == expected +end diff --git a/test/imports.jl b/test/imports.jl index dad3d02b2..d7cd96747 100644 --- a/test/imports.jl +++ b/test/imports.jl @@ -1,5 +1,5 @@ using LinearAlgebra, SparseArrays, Random, SharedArrays -import Dagger: DArray, chunks, domainchunks, treereduce_nd +import Dagger: DArray, chunks, domainchunks, treereduce_nd, mapchunks import Distributed: myid, procs import Statistics: mean, var, std import OnlineStats