Skip to content

Commit 6a62df6

Browse files
Merge pull request #2448 from andrejnau/fix_hlsl_mesh_shader_to_msl
MSL: Fix support for HLSL mesh shader
2 parents 8a92822 + 34fe989 commit 6a62df6

File tree

4 files changed

+333
-7
lines changed

4 files changed

+333
-7
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
7+
using namespace metal;
8+
9+
template<typename T, size_t Num>
10+
struct spvUnsafeArray
11+
{
12+
T elements[Num ? Num : 1];
13+
14+
thread T& operator [] (size_t pos) thread
15+
{
16+
return elements[pos];
17+
}
18+
constexpr const thread T& operator [] (size_t pos) const thread
19+
{
20+
return elements[pos];
21+
}
22+
23+
device T& operator [] (size_t pos) device
24+
{
25+
return elements[pos];
26+
}
27+
constexpr const device T& operator [] (size_t pos) const device
28+
{
29+
return elements[pos];
30+
}
31+
32+
constexpr const constant T& operator [] (size_t pos) const constant
33+
{
34+
return elements[pos];
35+
}
36+
37+
threadgroup T& operator [] (size_t pos) threadgroup
38+
{
39+
return elements[pos];
40+
}
41+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
42+
{
43+
return elements[pos];
44+
}
45+
46+
object_data T& operator [] (size_t pos) object_data
47+
{
48+
return elements[pos];
49+
}
50+
constexpr const object_data T& operator [] (size_t pos) const object_data
51+
{
52+
return elements[pos];
53+
}
54+
};
55+
56+
void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
57+
{
58+
if (gl_LocalInvocationIndex == 0)
59+
{
60+
spvMeshSizes.x = vertexCount;
61+
spvMeshSizes.y = primitiveCount;
62+
}
63+
}
64+
65+
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1, 1, 1);
66+
67+
struct spvPerVertex
68+
{
69+
float4 gl_Position [[position]];
70+
float3 out_var_COLOR0 [[user(locn0)]];
71+
};
72+
73+
using spvMesh_t = mesh<spvPerVertex, void, 3, 1, topology::triangle>;
74+
75+
static inline __attribute__((always_inline))
76+
void _1(threadgroup spvUnsafeArray<uint3, 1>& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray<float4, 3>& gl_Position, threadgroup spvUnsafeArray<float3, 3>& out_var_COLOR0, thread uint& gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes)
77+
{
78+
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 3u, 1u);
79+
gl_PrimitiveTriangleIndicesEXT[0] = uint3(0u, 1u, 2u);
80+
gl_Position[0] = float4(-0.5, -0.5, 0.0, 1.0);
81+
out_var_COLOR0[0] = float3(1.0, 0.0, 0.0);
82+
gl_Position[1] = float4(0.5, -0.5, 0.0, 1.0);
83+
out_var_COLOR0[1] = float3(0.0, 1.0, 0.0);
84+
gl_Position[2] = float4(0.0, 0.5, 0.0, 1.0);
85+
out_var_COLOR0[2] = float3(0.0, 0.0, 1.0);
86+
}
87+
88+
[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], spvMesh_t spvMesh)
89+
{
90+
threadgroup uint2 spvMeshSizes;
91+
threadgroup spvUnsafeArray<uint3, 1> gl_PrimitiveTriangleIndicesEXT;
92+
threadgroup spvUnsafeArray<float4, 3> gl_Position;
93+
threadgroup spvUnsafeArray<float3, 3> out_var_COLOR0;
94+
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
95+
_1(gl_PrimitiveTriangleIndicesEXT, gl_Position, out_var_COLOR0, gl_LocalInvocationIndex, spvMeshSizes);
96+
threadgroup_barrier(mem_flags::mem_threadgroup);
97+
if (spvMeshSizes.y == 0)
98+
{
99+
return;
100+
}
101+
spvMesh.set_primitive_count(spvMeshSizes.y);
102+
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
103+
for (uint spvVI = gl_LocalInvocationIndex; spvVI < spvMeshSizes.x; spvVI += spvThreadCount)
104+
{
105+
spvPerVertex spvV = {};
106+
spvV.gl_Position = gl_Position[spvVI];
107+
spvV.out_var_COLOR0 = out_var_COLOR0[spvVI];
108+
spvMesh.set_vertex(spvVI, spvV);
109+
}
110+
const uint spvPI = gl_LocalInvocationIndex;
111+
if (gl_LocalInvocationIndex < spvMeshSizes.y)
112+
{
113+
spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);
114+
spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);
115+
spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);
116+
}
117+
}
118+
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
7+
using namespace metal;
8+
9+
template<typename T, size_t Num>
10+
struct spvUnsafeArray
11+
{
12+
T elements[Num ? Num : 1];
13+
14+
thread T& operator [] (size_t pos) thread
15+
{
16+
return elements[pos];
17+
}
18+
constexpr const thread T& operator [] (size_t pos) const thread
19+
{
20+
return elements[pos];
21+
}
22+
23+
device T& operator [] (size_t pos) device
24+
{
25+
return elements[pos];
26+
}
27+
constexpr const device T& operator [] (size_t pos) const device
28+
{
29+
return elements[pos];
30+
}
31+
32+
constexpr const constant T& operator [] (size_t pos) const constant
33+
{
34+
return elements[pos];
35+
}
36+
37+
threadgroup T& operator [] (size_t pos) threadgroup
38+
{
39+
return elements[pos];
40+
}
41+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
42+
{
43+
return elements[pos];
44+
}
45+
46+
object_data T& operator [] (size_t pos) object_data
47+
{
48+
return elements[pos];
49+
}
50+
constexpr const object_data T& operator [] (size_t pos) const object_data
51+
{
52+
return elements[pos];
53+
}
54+
};
55+
56+
void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
57+
{
58+
if (gl_LocalInvocationIndex == 0)
59+
{
60+
spvMeshSizes.x = vertexCount;
61+
spvMeshSizes.y = primitiveCount;
62+
}
63+
}
64+
65+
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1, 1, 1);
66+
67+
struct spvPerVertex
68+
{
69+
float4 gl_Position [[position]];
70+
float3 out_var_COLOR0 [[user(locn0)]];
71+
};
72+
73+
using spvMesh_t = mesh<spvPerVertex, void, 3, 1, topology::triangle>;
74+
75+
static inline __attribute__((always_inline))
76+
void _1(threadgroup spvUnsafeArray<uint3, 1>& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray<float4, 3>& gl_Position, threadgroup spvUnsafeArray<float3, 3>& out_var_COLOR0, thread uint& gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes)
77+
{
78+
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 3u, 1u);
79+
gl_PrimitiveTriangleIndicesEXT[0] = uint3(0u, 1u, 2u);
80+
gl_Position[0] = float4(-0.5, -0.5, 0.0, 1.0);
81+
out_var_COLOR0[0] = float3(1.0, 0.0, 0.0);
82+
gl_Position[1] = float4(0.5, -0.5, 0.0, 1.0);
83+
out_var_COLOR0[1] = float3(0.0, 1.0, 0.0);
84+
gl_Position[2] = float4(0.0, 0.5, 0.0, 1.0);
85+
out_var_COLOR0[2] = float3(0.0, 0.0, 1.0);
86+
}
87+
88+
[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], spvMesh_t spvMesh)
89+
{
90+
threadgroup uint2 spvMeshSizes;
91+
threadgroup spvUnsafeArray<uint3, 1> gl_PrimitiveTriangleIndicesEXT;
92+
threadgroup spvUnsafeArray<float4, 3> gl_Position;
93+
threadgroup spvUnsafeArray<float3, 3> out_var_COLOR0;
94+
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
95+
_1(gl_PrimitiveTriangleIndicesEXT, gl_Position, out_var_COLOR0, gl_LocalInvocationIndex, spvMeshSizes);
96+
threadgroup_barrier(mem_flags::mem_threadgroup);
97+
if (spvMeshSizes.y == 0)
98+
{
99+
return;
100+
}
101+
spvMesh.set_primitive_count(spvMeshSizes.y);
102+
const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
103+
for (uint spvVI = gl_LocalInvocationIndex; spvVI < spvMeshSizes.x; spvVI += spvThreadCount)
104+
{
105+
spvPerVertex spvV = {};
106+
spvV.gl_Position = gl_Position[spvVI];
107+
spvV.out_var_COLOR0 = out_var_COLOR0[spvVI];
108+
spvMesh.set_vertex(spvVI, spvV);
109+
}
110+
const uint spvPI = gl_LocalInvocationIndex;
111+
if (gl_LocalInvocationIndex < spvMeshSizes.y)
112+
{
113+
spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);
114+
spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);
115+
spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);
116+
}
117+
}
118+
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: Google spiregg; 0
4+
; Bound: 48
5+
; Schema: 0
6+
OpCapability MeshShadingEXT
7+
OpExtension "SPV_EXT_mesh_shader"
8+
OpMemoryModel Logical GLSL450
9+
OpEntryPoint MeshEXT %main "main" %2 %gl_Position %out_var_COLOR0
10+
OpExecutionMode %main LocalSize 1 1 1
11+
OpExecutionMode %main OutputTrianglesEXT
12+
OpExecutionMode %main OutputVertices 3
13+
OpExecutionMode %main OutputPrimitivesEXT 1
14+
OpSource HLSL 650
15+
OpName %out_var_COLOR0 "out.var.COLOR0"
16+
OpName %main "main"
17+
OpDecorate %2 BuiltIn PrimitiveTriangleIndicesEXT
18+
OpDecorate %gl_Position BuiltIn Position
19+
OpDecorate %out_var_COLOR0 Location 0
20+
%uint = OpTypeInt 32 0
21+
%uint_3 = OpConstant %uint 3
22+
%uint_1 = OpConstant %uint 1
23+
%uint_0 = OpConstant %uint 0
24+
%uint_2 = OpConstant %uint 2
25+
%v3uint = OpTypeVector %uint 3
26+
%11 = OpConstantComposite %v3uint %uint_0 %uint_1 %uint_2
27+
%int = OpTypeInt 32 1
28+
%int_0 = OpConstant %int 0
29+
%float = OpTypeFloat 32
30+
%float_n0_5 = OpConstant %float -0.5
31+
%float_0 = OpConstant %float 0
32+
%float_1 = OpConstant %float 1
33+
%v4float = OpTypeVector %float 4
34+
%19 = OpConstantComposite %v4float %float_n0_5 %float_n0_5 %float_0 %float_1
35+
%v3float = OpTypeVector %float 3
36+
%21 = OpConstantComposite %v3float %float_1 %float_0 %float_0
37+
%float_0_5 = OpConstant %float 0.5
38+
%23 = OpConstantComposite %v4float %float_0_5 %float_n0_5 %float_0 %float_1
39+
%int_1 = OpConstant %int 1
40+
%25 = OpConstantComposite %v3float %float_0 %float_1 %float_0
41+
%26 = OpConstantComposite %v4float %float_0 %float_0_5 %float_0 %float_1
42+
%int_2 = OpConstant %int 2
43+
%28 = OpConstantComposite %v3float %float_0 %float_0 %float_1
44+
%_arr_v3uint_uint_1 = OpTypeArray %v3uint %uint_1
45+
%_ptr_Output__arr_v3uint_uint_1 = OpTypePointer Output %_arr_v3uint_uint_1
46+
%_arr_v4float_uint_3 = OpTypeArray %v4float %uint_3
47+
%_ptr_Output__arr_v4float_uint_3 = OpTypePointer Output %_arr_v4float_uint_3
48+
%_arr_v3float_uint_3 = OpTypeArray %v3float %uint_3
49+
%_ptr_Output__arr_v3float_uint_3 = OpTypePointer Output %_arr_v3float_uint_3
50+
%void = OpTypeVoid
51+
%36 = OpTypeFunction %void
52+
%_ptr_Output_v3uint = OpTypePointer Output %v3uint
53+
%_ptr_Output_v4float = OpTypePointer Output %v4float
54+
%_ptr_Output_v3float = OpTypePointer Output %v3float
55+
%2 = OpVariable %_ptr_Output__arr_v3uint_uint_1 Output
56+
%gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_3 Output
57+
%out_var_COLOR0 = OpVariable %_ptr_Output__arr_v3float_uint_3 Output
58+
%main = OpFunction %void None %36
59+
%40 = OpLabel
60+
OpSetMeshOutputsEXT %uint_3 %uint_1
61+
%41 = OpAccessChain %_ptr_Output_v3uint %2 %int_0
62+
OpStore %41 %11
63+
%42 = OpAccessChain %_ptr_Output_v4float %gl_Position %int_0
64+
OpStore %42 %19
65+
%43 = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_0
66+
OpStore %43 %21
67+
%44 = OpAccessChain %_ptr_Output_v4float %gl_Position %int_1
68+
OpStore %44 %23
69+
%45 = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_1
70+
OpStore %45 %25
71+
%46 = OpAccessChain %_ptr_Output_v4float %gl_Position %int_2
72+
OpStore %46 %26
73+
%47 = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_2
74+
OpStore %47 %28
75+
OpReturn
76+
OpFunctionEnd

spirv_msl.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,7 +1613,7 @@ string CompilerMSL::compile()
16131613
backend.nonuniform_qualifier = "";
16141614
backend.support_small_type_sampling_result = true;
16151615
backend.force_merged_mesh_block = false;
1616-
backend.force_gl_in_out_block = get_execution_model() == ExecutionModelMeshEXT;
1616+
backend.force_gl_in_out_block = false;
16171617
backend.supports_empty_struct = true;
16181618
backend.support_64bit_switch = true;
16191619
backend.boolean_in_struct_remapped_type = SPIRType::Short;
@@ -2325,7 +2325,14 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
23252325
v.storage = StorageClassWorkgroup;
23262326

23272327
// Ensure the existing variable has a valid name and the new variable has all the same meta info
2328-
set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
2328+
if (ir.meta[arg_id].decoration.builtin)
2329+
{
2330+
set_name(arg_id, builtin_to_glsl(bi_type, var.storage));
2331+
}
2332+
else
2333+
{
2334+
set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
2335+
}
23292336
ir.meta[next_id] = ir.meta[arg_id];
23302337
}
23312338
else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
@@ -7923,8 +7930,16 @@ void CompilerMSL::emit_specialization_constants_and_structs()
79237930
{
79247931
SpecializationConstant wg_x, wg_y, wg_z;
79257932
ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
7926-
bool emitted = false;
7933+
if (workgroup_size_id == 0 && is_mesh_shader())
7934+
{
7935+
auto &execution = get_entry_point();
7936+
statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
7937+
" [[maybe_unused]] = ", "uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
7938+
execution.workgroup_size.z, ");");
7939+
statement("");
7940+
}
79277941

7942+
bool emitted = false;
79287943
unordered_set<uint32_t> declared_structs;
79297944
unordered_set<uint32_t> aligned_structs;
79307945

@@ -8038,7 +8053,7 @@ void CompilerMSL::emit_specialization_constants_and_structs()
80388053
statement("#endif");
80398054
statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
80408055
";");
8041-
8056+
80428057
// Record the usage of macro
80438058
constant_macro_ids.insert(constant_id);
80448059
}
@@ -15199,8 +15214,7 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
1519915214

1520015215
if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
1520115216
decl = join(cv_qualifier, type_to_glsl(type, arg.id));
15202-
else if (builtin && builtin_type != spv::BuiltInPrimitiveTriangleIndicesEXT &&
15203-
builtin_type != spv::BuiltInPrimitiveLineIndicesEXT && builtin_type != spv::BuiltInPrimitivePointIndicesEXT)
15217+
else if (builtin && !is_mesh_shader())
1520415218
{
1520515219
// Only use templated array for Clip/Cull distance when feasible.
1520615220
// In other scenarios, we need need to override array length for tess levels (if used as outputs),
@@ -19043,7 +19057,7 @@ void CompilerMSL::analyze_argument_buffers()
1904319057
set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
1904419058
member_index++;
1904519059
}
19046-
19060+
1904719061
if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
1904819062
{
1904919063
recursive_inputs.insert(type_id);

0 commit comments

Comments
 (0)