Skip to content

Commit ebe2aa0

Browse files
Merge pull request #2429 from KhronosGroup/fix-2411
HLSL: Fix lowering of arrayed clip/cull distance in mesh shaders.
2 parents d2478b2 + de515fe commit ebe2aa0

11 files changed

+176
-14
lines changed

reference/opt/shaders-hlsl/mesh/mesh-shader-basic-lines.spv14.vk.nocompat.mesh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT
5151

5252
groupshared float shared_float[16];
5353

54-
void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
54+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
5555
{
5656
SetMeshOutputCounts(24u, 22u);
5757
float3 _173 = float3(gl_GlobalInvocationID);

reference/opt/shaders-hlsl/mesh/mesh-shader-basic-triangle.spv14.vk.nocompat.mesh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT
5151

5252
groupshared float shared_float[16];
5353

54-
void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
54+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
5555
{
5656
SetMeshOutputCounts(24u, 22u);
5757
float3 _29 = float3(gl_GlobalInvocationID);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
static uint gl_LocalInvocationIndex;
2+
struct SPIRV_Cross_Input
3+
{
4+
uint gl_LocalInvocationIndex : SV_GroupIndex;
5+
};
6+
7+
struct gl_MeshPerVertexEXT
8+
{
9+
float4 gl_ClipDistance : SV_ClipDistance;
10+
};
11+
12+
struct gl_MeshPerPrimitiveEXT
13+
{
14+
};
15+
16+
void write_clip_distance(inout float v[4])
17+
{
18+
v[0] += 1.0f;
19+
v[1] += 2.0f;
20+
v[2] += 3.0f;
21+
v[3] += 4.0f;
22+
}
23+
24+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[3])
25+
{
26+
SetMeshOutputCounts(3u, 1u);
27+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0f;
28+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1] = 4.0f;
29+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2] = 4.0f;
30+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] = 4.0f;
31+
float _62[4] = { gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2], gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] };
32+
float param[4] = _62;
33+
write_clip_distance(param);
34+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance = float4(param[0], param[1], param[2], param[3]);
35+
}
36+
37+
[outputtopology("triangle")]
38+
[numthreads(1, 1, 1)]
39+
void main(SPIRV_Cross_Input stage_input, out vertices gl_MeshPerVertexEXT gl_MeshVerticesEXT[3])
40+
{
41+
gl_LocalInvocationIndex = stage_input.gl_LocalInvocationIndex;
42+
mesh_main(gl_MeshVerticesEXT);
43+
}

reference/shaders-hlsl-no-opt/mesh/mesh-shader-basic-triangle.spv14.vk.nocompat.nofxc.flip-vert-y.mesh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ float spvFlipVertY(float v)
5252
return -v;
5353
}
5454

55-
void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
55+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
5656
{
5757
SetMeshOutputCounts(24u, 22u);
5858
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = spvFlipVertY(float4(float3(gl_GlobalInvocationID), 1.0f));

reference/shaders-hlsl/mesh/mesh-shader-basic-lines.spv14.vk.nocompat.mesh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT
5151

5252
groupshared float shared_float[16];
5353

54-
void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22])
54+
void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22])
5555
{
5656
gl_PrimitiveLineIndicesEXT[gl_LocalInvocationIndex] = uint2(0u, 1u) + gl_LocalInvocationIndex.xx;
5757
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = int(gl_GlobalInvocationID.x);
@@ -61,7 +61,7 @@ void main3(inout uint2 gl_PrimitiveLineIndicesEXT[22], inout gl_MeshPerPrimitive
6161
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveShadingRateEXT = int(gl_GlobalInvocationID.x) + 3;
6262
}
6363

64-
void main2(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
64+
void main2(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
6565
{
6666
SetMeshOutputCounts(24u, 22u);
6767
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0f);
@@ -81,7 +81,7 @@ void main2(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPri
8181
}
8282
}
8383

84-
void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
84+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint2 gl_PrimitiveLineIndicesEXT[22])
8585
{
8686
main2(gl_MeshVerticesEXT, gl_MeshPrimitivesEXT, _payload, gl_PrimitiveLineIndicesEXT);
8787
}

reference/shaders-hlsl/mesh/mesh-shader-basic-triangle.spv14.vk.nocompat.mesh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct gl_MeshPerPrimitiveEXT
5151

5252
groupshared float shared_float[16];
5353

54-
void mesh_main(inout gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], inout gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
54+
void mesh_main(out gl_MeshPerVertexEXT gl_MeshVerticesEXT[24], out gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[22], TaskPayload _payload, inout uint3 gl_PrimitiveTriangleIndicesEXT[22])
5555
{
5656
SetMeshOutputCounts(24u, 22u);
5757
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0f);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#version 450
2+
#extension GL_EXT_mesh_shader : require
3+
layout(triangles, max_vertices = 3, max_primitives = 1) out;
4+
5+
out gl_MeshPerVertexEXT
6+
{
7+
float gl_ClipDistance[4];
8+
} gl_MeshVerticesEXT[];
9+
10+
void write_clip_distance(inout float v[4])
11+
{
12+
v[0] += 1.0;
13+
v[1] += 2.0;
14+
v[2] += 3.0;
15+
v[3] += 4.0;
16+
}
17+
18+
void main()
19+
{
20+
SetMeshOutputsEXT(3, 1);
21+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
22+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[1] = 4.0;
23+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[2] = 4.0;
24+
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[3] = 4.0;
25+
write_clip_distance(gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance);
26+
}

spirv_common.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,8 @@ struct AccessChainMeta
16041604
bool flattened_struct = false;
16051605
bool relaxed_precision = false;
16061606
bool access_meshlet_position_y = false;
1607+
bool chain_is_builtin = false;
1608+
spv::BuiltIn builtin = {};
16071609
};
16081610

16091611
enum ExtendedDecorations

spirv_glsl.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10210,6 +10210,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
1021010210
bool pending_array_enclose = false;
1021110211
bool dimension_flatten = false;
1021210212
bool access_meshlet_position_y = false;
10213+
bool chain_is_builtin = false;
10214+
spv::BuiltIn chained_builtin = {};
1021310215

1021410216
if (auto *base_expr = maybe_get<SPIRExpression>(base))
1021510217
{
@@ -10367,6 +10369,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
1036710369
auto builtin = ir.meta[base].decoration.builtin_type;
1036810370
bool mesh_shader = get_execution_model() == ExecutionModelMeshEXT;
1036910371

10372+
chain_is_builtin = true;
10373+
chained_builtin = builtin;
10374+
1037010375
switch (builtin)
1037110376
{
1037210377
case BuiltInCullDistance:
@@ -10502,6 +10507,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
1050210507
{
1050310508
access_meshlet_position_y = true;
1050410509
}
10510+
10511+
chain_is_builtin = true;
10512+
chained_builtin = builtin;
1050510513
}
1050610514
else
1050710515
{
@@ -10721,6 +10729,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
1072110729
meta->storage_physical_type = physical_type;
1072210730
meta->relaxed_precision = relaxed_precision;
1072310731
meta->access_meshlet_position_y = access_meshlet_position_y;
10732+
meta->chain_is_builtin = chain_is_builtin;
10733+
meta->builtin = chained_builtin;
1072410734
}
1072510735

1072610736
return expr;
@@ -12336,6 +12346,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
1233612346
flattened_structs[ops[1]] = true;
1233712347
if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis)
1233812348
set_decoration(ops[1], DecorationRelaxedPrecision);
12349+
if (meta.chain_is_builtin)
12350+
set_decoration(ops[1], DecorationBuiltIn, meta.builtin);
1233912351

1234012352
// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
1234112353
// temporary which could be subject to invalidation.
@@ -15679,7 +15691,16 @@ string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg)
1567915691

1568015692
if (type.pointer)
1568115693
{
15682-
if (arg.write_count && arg.read_count)
15694+
// If we're passing around block types to function, we really mean reference in a pointer sense,
15695+
// but DXC does not like inout for mesh blocks, so workaround that. out is technically not correct,
15696+
// but it works in practice due to legalization. It's ... not great, but you gotta do what you gotta do.
15697+
// GLSL will never hit this case since it's not valid.
15698+
if (type.storage == StorageClassOutput && get_execution_model() == ExecutionModelMeshEXT &&
15699+
has_decoration(type.self, DecorationBlock) && is_builtin_type(type) && arg.write_count)
15700+
{
15701+
direction = "out ";
15702+
}
15703+
else if (arg.write_count && arg.read_count)
1568315704
direction = "inout ";
1568415705
else if (arg.write_count)
1568515706
direction = "out ";

spirv_hlsl.cpp

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4775,13 +4775,13 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
47754775
{
47764776
auto ops = stream(instruction);
47774777

4778-
auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4778+
uint32_t result_type = ops[0];
4779+
uint32_t id = ops[1];
4780+
uint32_t ptr = ops[2];
4781+
4782+
auto *chain = maybe_get<SPIRAccessChain>(ptr);
47794783
if (chain)
47804784
{
4781-
uint32_t result_type = ops[0];
4782-
uint32_t id = ops[1];
4783-
uint32_t ptr = ops[2];
4784-
47854785
auto &type = get<SPIRType>(result_type);
47864786
bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
47874787

@@ -4819,7 +4819,35 @@ void CompilerHLSL::emit_load(const Instruction &instruction)
48194819
}
48204820
}
48214821
else
4822-
CompilerGLSL::emit_instruction(instruction);
4822+
{
4823+
// Very special case where we cannot rely on IO lowering.
4824+
// Mesh shader clip/cull arrays ... Cursed.
4825+
auto &res_type = get<SPIRType>(result_type);
4826+
if (get_execution_model() == ExecutionModelMeshEXT &&
4827+
has_decoration(ptr, DecorationBuiltIn) &&
4828+
(get_decoration(ptr, DecorationBuiltIn) == BuiltInClipDistance ||
4829+
get_decoration(ptr, DecorationBuiltIn) == BuiltInCullDistance) &&
4830+
is_array(res_type) && !is_array(get<SPIRType>(res_type.parent_type)))
4831+
{
4832+
track_expression_read(ptr);
4833+
string load_expr = "{ ";
4834+
uint32_t num_elements = to_array_size_literal(res_type);
4835+
for (uint32_t i = 0; i < num_elements; i++)
4836+
{
4837+
load_expr += join(to_expression(ptr), "[", i, "]");
4838+
if (i + 1 < num_elements)
4839+
load_expr += ", ";
4840+
}
4841+
load_expr += " }";
4842+
emit_op(result_type, id, load_expr, false);
4843+
register_read(id, ptr, false);
4844+
inherit_expression_dependencies(id, ptr);
4845+
}
4846+
else
4847+
{
4848+
CompilerGLSL::emit_instruction(instruction);
4849+
}
4850+
}
48234851
}
48244852

48254853
void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
@@ -6903,3 +6931,43 @@ bool CompilerHLSL::is_user_type_structured(uint32_t id) const
69036931
}
69046932
return false;
69056933
}
6934+
6935+
void CompilerHLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
6936+
{
6937+
// Loading a full array of ClipDistance needs special consideration in mesh shaders
6938+
// since we cannot lower them by wrapping the variables in global statics.
6939+
// Fortunately, clip/cull is a proper vector in HLSL so we can lower with simple rvalue casts.
6940+
if (get_execution_model() != ExecutionModelMeshEXT ||
6941+
!has_decoration(target_id, DecorationBuiltIn) ||
6942+
!is_array(expr_type))
6943+
{
6944+
CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
6945+
return;
6946+
}
6947+
6948+
auto builtin = BuiltIn(get_decoration(target_id, DecorationBuiltIn));
6949+
if (builtin != BuiltInClipDistance && builtin != BuiltInCullDistance)
6950+
{
6951+
CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
6952+
return;
6953+
}
6954+
6955+
// Array of array means one thread is storing clip distance for all vertices. Nonsensical?
6956+
if (is_array(get<SPIRType>(expr_type.parent_type)))
6957+
SPIRV_CROSS_THROW("Attempting to store all mesh vertices in one go. This is not supported.");
6958+
6959+
uint32_t num_clip = to_array_size_literal(expr_type);
6960+
if (num_clip > 4)
6961+
SPIRV_CROSS_THROW("Number of clip or cull distances exceeds 4, this will not work with mesh shaders.");
6962+
6963+
auto unrolled_expr = join("float", num_clip, "(");
6964+
for (uint32_t i = 0; i < num_clip; i++)
6965+
{
6966+
unrolled_expr += join(expr, "[", i, "]");
6967+
if (i + 1 < num_clip)
6968+
unrolled_expr += ", ";
6969+
}
6970+
6971+
unrolled_expr += ")";
6972+
expr = std::move(unrolled_expr);
6973+
}

0 commit comments

Comments
 (0)