Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 46 additions & 31 deletions ext/IntelExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function Dagger.memory_space(x::oneArray)
return IntelVRAMMemorySpace(myid(), device_id)
end
_device_id(dev::ZeDevice) = findfirst(other_dev->other_dev === dev, collect(oneAPI.devices()))
_device_id(x::oneArray) = _device_id(oneAPI.device(x))
function Dagger.aliasing(x::oneArray{T}) where T
space = Dagger.memory_space(x)
S = typeof(space)
Expand All @@ -65,7 +66,7 @@ end
function with_context!(device_id::Integer)
driver!(DRIVERS[device_id])
device!(DEVICES[device_id])
context!(CONTEXTS[device_id])
context!(_get_context(DRIVERS[device_id]))
end
function with_context!(proc::oneArrayDeviceProc)
@assert Dagger.root_worker_id(proc) == myid()
Expand All @@ -92,6 +93,26 @@ function with_context(f, x)
end
end

function _ensure_device(x::oneArray, device_id::Int)
dev = DEVICES[device_id]
if oneAPI.device(x) === dev
return x
end
src_dev_id = _device_id(x)
@assert src_dev_id !== nothing "Unknown source device for oneArray"
# Ensure any pending work on the source device is complete before DtoD copy
Comment thread
fda-tome marked this conversation as resolved.
# N.B. oneAPI does not synchronize the source in `copyto!`, making this necessary
with_context(src_dev_id) do
oneAPI.synchronize()
end
return with_context(device_id) do
arr = similar(x)
copyto!(arr, x) # direct DtoD within shared context
oneAPI.synchronize()
return arr
end
end

function _sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace})
with_context(x) do
oneAPI.synchronize()
Expand Down Expand Up @@ -159,15 +180,7 @@ function Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, x::Chunk)
end
end
function Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, x::oneArray)
if oneAPI.device(x) == to_device(to_proc)
return x
end
with_context(to_proc) do
_x = similar(x)
copyto!(_x, x)
oneAPI.synchronize()
return _x
end
return _ensure_device(x, to_proc.device_id)
end

# Out-of-place DtoH
Expand Down Expand Up @@ -206,46 +219,43 @@ function Dagger.move(from_proc::oneArrayDeviceProc, to_proc::oneArrayDeviceProc,
with_context(oneAPI.synchronize, from_proc)
return arr
elseif Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
# Same process but different GPUs, use DtoD copy
# Same process but different GPUs, rehome to target device
from_arr = unwrap(x)
with_context(oneAPI.synchronize, from_proc)
return with_context(to_proc) do
to_arr = similar(from_arr)
copyto!(to_arr, from_arr)
oneAPI.synchronize()
return to_arr
end
return _ensure_device(from_arr, to_proc.device_id)
else
# Different node, use DtoH, serialization, HtoD
return oneArray(remotecall_fetch(from_proc.owner, x) do x
host_copy = remotecall_fetch(from_proc.owner, x) do x
Array(unwrap(x))
end)
end
return with_context(to_proc) do
oneArray(host_copy)
end
end
end

function Dagger.move(from_proc::oneArrayDeviceProc, to_proc::oneArrayDeviceProc, x::oneArray)
return _ensure_device(x, to_proc.device_id)
end

# Adapt generic functions
Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, x::Function) = x
Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, x::Chunk{T}) where {T<:Function} =
Dagger.move(from_proc, to_proc, fetch(x))

#= FIXME: Adapt BLAS/LAPACK functions
import LinearAlgebra: BLAS, LAPACK
for lib in [BLAS, LAPACK]
for name in names(lib; all=true)
name == nameof(lib) && continue
startswith(string(name), '#') && continue
endswith(string(name), '!') || continue

for culib in [CUBLAS, CUSOLVER]
if name in names(culib; all=true)
fn = getproperty(lib, name)
cufn = getproperty(culib, name)
@eval Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, ::$(typeof(fn))) = $cufn
end
if name in names(oneAPI.oneMKL; all=true)
fn = getproperty(lib, name)
mklfn = getproperty(oneAPI.oneMKL, name)
@eval Dagger.move(from_proc::CPUProc, to_proc::oneArrayDeviceProc, ::$(typeof(fn))) = $mklfn
end
end
end
=#

# Task execution
function Dagger.execute!(proc::oneArrayDeviceProc, f, args...; kwargs...)
Expand Down Expand Up @@ -332,7 +342,13 @@ Dagger.scope_key_precedence(::Val{:intel_gpus}) = 1

const DEVICES = Dict{Int, ZeDevice}()
const DRIVERS = Dict{Int, ZeDriver}()
const CONTEXTS = Dict{Int, ZeContext}()
const CONTEXTS = IdDict{ZeDriver, ZeContext}()

function _get_context(driver::ZeDriver)
return get!(CONTEXTS, driver) do
ZeContext(driver)
end
end

function __init__()
if oneAPI.functional()
Expand All @@ -344,8 +360,7 @@ function __init__()
driver!(dev.driver)
DRIVERS[device_id] = dev.driver
device!(dev)
ctx = ZeContext(dev.driver)
CONTEXTS[device_id] = ctx
_get_context(dev.driver)
return proc
end
end
Expand Down
Loading