Skip to content

Commit e203246

Browse files
committed
Fix for issue #259
1 parent 0c9c8b7 commit e203246

File tree

1 file changed

+124
-1
lines changed

1 file changed

+124
-1
lines changed

src/compiler/compilation.jl

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,126 @@ function GPUCompiler.finish_module!(job::oneAPICompilerJob, mod::LLVM.Module,
3737
return entry
3838
end
3939

40+
# finish_ir! runs later in the pipeline, after the optimizations that create nested insertvalue instructions
41+
function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
                                entry::LLVM.Function)
    # Run the generic SPIR-V `finish_ir!` first; it may hand back a different entry.
    parent_sig = Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(mod), typeof(entry)}
    entry = invoke(GPUCompiler.finish_ir!, parent_sig, job, mod, entry)

    # Workaround for https://github.com/JuliaGPU/oneAPI.jl/issues/259:
    # Intel's SPIR-V runtime has a bug where OpCompositeInsert with nested
    # indices (e.g., "1 0") corrupts adjacent struct fields, so every
    # multi-index insertvalue in the module is rewritten into single-index
    # extractvalue/insertvalue chains.
    flatten_nested_insertvalue!(mod)

    return entry
end
55+
56+
# Flatten nested insertvalue instructions
57+
# This works around a bug in Intel's SPIR-V runtime where OpCompositeInsert
58+
# with nested array indices corrupts adjacent struct fields.
59+
"""
    flatten_nested_insertvalue!(mod::LLVM.Module) -> Bool

Rewrite every `insertvalue` instruction in `mod` that carries more than one
index by passing it to `flatten_insert!`.

Return `true` if at least one instruction was rewritten, `false` otherwise.

A failure while rewriting an individual instruction is logged with `@warn` and
does not abort the pass; the remaining instructions are still processed.
"""
function flatten_nested_insertvalue!(mod::LLVM.Module)
    changed = false

    # Declarations have no basic blocks, so they contribute nothing here.
    for f in functions(mod), bb in blocks(f)
        # Snapshot the candidates first: flatten_insert! inserts and erases
        # instructions, which must not happen while iterating the block.
        nested = LLVM.Instruction[]
        for inst in instructions(bb)
            opcode = LLVM.API.LLVMGetInstructionOpcode(inst)
            opcode == LLVM.API.LLVMInsertValue || continue
            if LLVM.API.LLVMGetNumIndices(inst) > 1
                push!(nested, inst)
            end
        end

        for inst in nested
            try
                flatten_insert!(inst)
                changed = true
            catch e
                # Best effort: keep going so one bad instruction does not
                # prevent the rest of the module from being rewritten.
                @warn "Failed to flatten nested insertvalue" exception=(e, catch_backtrace())
            end
        end
    end

    return changed
end
95+
96+
"""
    flatten_insert!(inst::LLVM.Instruction)

Replace a multi-index `insertvalue` with an equivalent chain of single-index
`extractvalue` / `insertvalue` instructions, then erase `inst`.

For `insertvalue %base, %val, i, j, k` the emitted sequence is

    %t1  = extractvalue %base, i
    %t2  = extractvalue %t1, j
    %t3  = insertvalue %t2, %val, k
    %t4  = insertvalue %t1, %t3, j
    %res = insertvalue %base, %t4, i

and all uses of `inst` are redirected to `%res`.

`inst` is expected to be an `insertvalue` instruction. If it has fewer than two
indices there is nothing to flatten and it is left untouched.
"""
function flatten_insert!(inst::LLVM.Instruction)
    composite = LLVM.operands(inst)[1]
    value = LLVM.operands(inst)[2]

    # The index array is owned by the instruction, which is erased below, so
    # take a private copy before touching the IR.
    num_indices = LLVM.API.LLVMGetNumIndices(inst)
    idx_ptr = LLVM.API.LLVMGetIndices(inst)
    indices = copy(unsafe_wrap(Array, idx_ptr, num_indices))

    # Already flat (or malformed): nothing to rewrite.
    length(indices) > 1 || return nothing

    builder = LLVM.IRBuilder()
    try
        LLVM.position!(builder, inst)

        # parents[k] is the aggregate that indices[k] indexes into. The element
        # type is the abstract LLVM.Value because the builder may return values
        # of different concrete types (e.g. folded constants vs. instructions).
        parents = LLVM.Value[composite]
        for k in 1:(length(indices) - 1)
            push!(parents, LLVM.extract_value!(builder, parents[end], indices[k]))
        end

        # Insert at the deepest level, then re-insert each modified aggregate
        # into its parent on the way back up to the original composite.
        result = value
        for k in length(indices):-1:1
            result = LLVM.insert_value!(builder, parents[k], result, indices[k])
        end

        LLVM.replace_uses!(inst, result)
        LLVM.API.LLVMInstructionEraseFromParent(inst)
    finally
        # Release the builder even when the rewrite throws; the caller catches
        # the exception and carries on, so a leak here would go unnoticed.
        LLVM.dispose(builder)
    end

    return nothing
end
159+
40160

41161
## compiler implementation (cache, configure, compile, and link)
42162

@@ -68,7 +188,10 @@ end
68188
supports_fp64 = oneL0.module_properties(device()).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64
69189

70190
# TODO: emit printf format strings in constant memory
71-
extensions = String["SPV_EXT_relaxed_printf_string_address_space"]
191+
extensions = String[
192+
"SPV_EXT_relaxed_printf_string_address_space",
193+
"SPV_EXT_shader_atomic_float_add"
194+
]
72195

73196
# create GPUCompiler objects
74197
target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, kwargs...)

0 commit comments

Comments
 (0)