Skip to content

Commit d9b78c5

Browse files
authored
Merge pull request #103 from JuliaGPU/tb/compile
Make the compilation example first-class functionality.
2 parents 89da86f + b92664a commit d9b78c5

File tree

7 files changed

+124
-106
lines changed

7 files changed

+124
-106
lines changed

examples/compilation/README.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/compilation/usage.jl

Lines changed: 0 additions & 69 deletions
This file was deleted.

examples/vadd.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using CUDAdrv, CUDArt
2+
using Base.Test
3+
4+
using Compat
5+
6+
dev = CuDevice(0)
7+
ctx = CuContext(dev)
8+
9+
CUDArt.@compile dev kernel_vadd """
10+
__global__ void kernel_vadd(const float *a, const float *b, float *c)
11+
{
12+
int i = blockIdx.x *blockDim.x + threadIdx.x;
13+
c[i] = a[i] + b[i];
14+
}
15+
"""
16+
17+
dims = (3,4)
18+
a = round.(rand(Float32, dims) * 100)
19+
b = round.(rand(Float32, dims) * 100)
20+
21+
d_a = CuArray(a)
22+
d_b = CuArray(b)
23+
d_c = similar(d_a)
24+
25+
len = prod(dims)
26+
cudacall(kernel_vadd, len, 1, Tuple{Ptr{Cfloat},Ptr{Cfloat},Ptr{Cfloat}}, d_a, d_b, d_c)
27+
c = Array(d_c)
28+
@test a+b ≈ c
29+
30+
destroy!(ctx)

src/CUDArt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include("device.jl")
4646
include("stream.jl")
4747
include("pointer.jl")
4848
include("arrays.jl")
49+
include("compile.jl")
4950
include("execute.jl")
5051

5152
include("precompile.jl")

src/compile.jl

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
# EXCLUDE FROM TESTING
2-
3-
using CUDArt
4-
using Compat
1+
export CompileError
52

63
# Generate a temporary file with specific suffix
74
# NOTE: mkstemps is glibc 2.19+, so emulate its behavior
@@ -28,9 +25,8 @@ macro compile(dev, kernel, code)
2825
:($(esc(kernel)) = _compile($(esc(dev)), $kernel_name, $code, $containing_file)))
2926
end
3027

31-
type CompileError <: Base.WrappedException
28+
immutable CompileError <: Exception
3229
message::String
33-
error
3430
end
3531

3632
const builddir = joinpath(@__DIR__, ".cache")
@@ -43,7 +39,7 @@ function _compile(dev, kernel, code, containing_file)
4339
mkpath(builddir)
4440
end
4541

46-
# Check if we need to compile
42+
# check if we need to compile
4743
codehash = hex(hash(code))
4844
output = "$builddir/$(kernel)_$(codehash)-$(arch).ptx"
4945
if isfile(output)
@@ -52,51 +48,44 @@ function _compile(dev, kernel, code, containing_file)
5248
need_compile = true
5349
end
5450

55-
# Compile the source, if necessary
51+
# compile the source, if necessary
5652
if need_compile
57-
# Write the source into a compilable file
53+
# write the source to a compilable file
5854
(source, io) = mkstemps(".cu")
5955
write(io, """
6056
extern "C"
6157
{
6258
$code
6359
}
6460
""")
65-
close(io)
61+
Base.close(io)
6662

6763
compile_flags = vcat(CUDArt.toolchain_flags, ["--gpu-architecture", arch])
68-
try
69-
# TODO: capture STDERR
70-
run(pipeline(`$(CUDArt.toolchain_nvcc) $(compile_flags) -ptx -o $output $source`, stderr=DevNull))
71-
catch ex
72-
isa(ex, ErrorException) || rethrow(ex)
73-
rethrow(CompileError("compilation of kernel $kernel failed (typo in C++ source?)", ex))
74-
finally
75-
rm(source)
64+
err = Pipe()
65+
cmd = `$(CUDArt.toolchain_nvcc) $(compile_flags) -ptx -o $output $source`
66+
result = success(pipeline(cmd; stdout=DevNull, stderr=err))
67+
Base.close(err.in)
68+
rm(source)
69+
70+
errors = readstring(err)
71+
if !result
72+
throw(CompileError("compilation of kernel $kernel failed\n$errors"))
73+
elseif !isempty(errors)
74+
warn("during compilation of kernel $kernel:\n$errors")
7675
end
7776

7877
if !isfile(output)
7978
error("compilation of kernel $kernel failed (no output generated)")
8079
end
8180
end
8281

83-
# Pass the module to the CUDA driver
84-
mod = try
85-
CuModuleFile(output)
86-
catch ex
87-
rethrow(CompileError("loading of kernel $kernel failed (invalid CUDA code?)", ex))
88-
end
89-
90-
# Load the function pointer
91-
func = try
92-
CuFunction(mod, kernel)
93-
catch ex
94-
rethrow(CompileError("could not find kernel $kernel in the compiled binary (wrong function name?)", ex))
95-
end
96-
97-
return func
82+
mod = CUDAdrv.CuModuleFile(output)
83+
return CUDAdrv.CuFunction(mod, kernel)
9884
end
9985

10086
function clean_cache()
101-
rm(builddir; recursive=true)
87+
if ispath(builddir)
88+
@assert isdir(builddir)
89+
rm(builddir; recursive=true)
90+
end
10291
end

test/compile.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using CUDArt
2+
import CUDAdrv
3+
using Base.Test
4+
5+
dev = CUDAdrv.CuDevice(0)
6+
ctx = CUDAdrv.CuContext(dev)
7+
8+
CUDArt.clean_cache() # for deterministic testing purposes
9+
10+
11+
## basic compilation & execution
12+
13+
let
14+
CUDArt.@compile dev kernel """
15+
__global__ void kernel()
16+
{
17+
}
18+
"""
19+
20+
CUDAdrv.cudacall(kernel, 1, 1, ())
21+
end
22+
23+
@test_throws CompileError let
24+
CUDArt.@compile dev kernel """
25+
__global__ void kernel()
26+
{
27+
invalid code
28+
}
29+
"""
30+
end
31+
32+
@test_throws CUDAdrv.CuError let
33+
CUDArt.@compile dev wrongname """
34+
__global__ void kernel()
35+
{
36+
}
37+
"""
38+
end
39+
40+
41+
## argument passing
42+
43+
dims = (16, 16)
44+
len = prod(dims)
45+
46+
CUDArt.@compile dev kernel_copy """
47+
__global__ void kernel_copy(const float *input, float *output)
48+
{
49+
int i = blockIdx.x * blockDim.x + threadIdx.x;
50+
51+
output[i] = input[i];
52+
}
53+
"""
54+
55+
let
56+
input = round.(rand(Cfloat, dims) * 100)
57+
58+
input_dev = CUDAdrv.CuArray(input)
59+
output_dev = CUDAdrv.CuArray{Cfloat}(dims)
60+
61+
CUDAdrv.cudacall(kernel_copy, 1, len,
62+
Tuple{Ptr{Cfloat}, Ptr{Cfloat}},
63+
input_dev, output_dev)
64+
output = Array(output_dev)
65+
@test input ≈ output
66+
end
67+
68+
69+
CUDAdrv.destroy!(ctx)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
include("gc.jl")
22
include("test.jl")
3+
include("compile.jl")
34
include("examples.jl")

0 commit comments

Comments
 (0)