Skip to content

Commit 4331fac

Browse files
Merge pull request #2545 from KhronosGroup/fix-2538
Improve handling of subgroup ops with extended types.
2 parents b26ac3f + 599d35f commit 4331fac

File tree

11 files changed

+710
-0
lines changed

11 files changed

+710
-0
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
3+
#include <metal_stdlib>
4+
#include <simd/simd.h>
5+
6+
using namespace metal;
7+
8+
template<typename T>
9+
inline T spvSubgroupShuffle(T value, ushort lane)
10+
{
11+
return simd_shuffle(value, lane);
12+
}
13+
14+
template<>
15+
inline bool spvSubgroupShuffle(bool value, ushort lane)
16+
{
17+
return !!simd_shuffle((ushort)value, lane);
18+
}
19+
20+
template<uint N>
21+
inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
22+
{
23+
return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);
24+
}
25+
26+
template<>
27+
inline ulong spvSubgroupShuffle(ulong value, ushort lane)
28+
{
29+
return as_type<ulong>(spvSubgroupShuffle(as_type<uint2>(value), lane));
30+
}
31+
32+
template<>
33+
inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane)
34+
{
35+
return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane));
36+
}
37+
38+
inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane)
39+
{
40+
return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane));
41+
}
42+
43+
inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane)
44+
{
45+
return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane));
46+
}
47+
48+
template<uint N>
49+
inline vec<long, N> spvSubgroupShuffle(vec<long, N> value, ushort lane)
50+
{
51+
return vec<long, N>(spvSubgroupShuffle(vec<ulong, N>(value), lane));
52+
}
53+
54+
struct SSBO
55+
{
56+
uchar u8_1[4];
57+
char2 i8_2[4];
58+
uchar3 u8_3[4];
59+
char4 i8_4[4];
60+
ushort u16_1[4];
61+
short2 i16_2[4];
62+
ushort3 u16_3[4];
63+
short4 i16_4[4];
64+
half f16_1[4];
65+
half2 f16_2[4];
66+
half3 f16_3[4];
67+
half4 f16_4[4];
68+
ulong u64_1[4];
69+
long2 i64_2[4];
70+
ulong3 u64_3[4];
71+
long4 i64_4[4];
72+
};
73+
74+
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
75+
76+
kernel void main0(device SSBO& _53 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
77+
{
78+
uint _232 = gl_SubgroupInvocationID ^ 1u;
79+
_53.u8_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u8_1[gl_LocalInvocationIndex], _232);
80+
_53.i8_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i8_2[gl_LocalInvocationIndex], _232);
81+
_53.u8_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u8_3[gl_LocalInvocationIndex], _232);
82+
_53.i8_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i8_4[gl_LocalInvocationIndex], _232);
83+
_53.u16_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u16_1[gl_LocalInvocationIndex], _232);
84+
_53.i16_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i16_2[gl_LocalInvocationIndex], _232);
85+
_53.u16_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u16_3[gl_LocalInvocationIndex], _232);
86+
_53.i16_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i16_4[gl_LocalInvocationIndex], _232);
87+
_53.u64_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u64_1[gl_LocalInvocationIndex], _232);
88+
_53.i64_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i64_2[gl_LocalInvocationIndex], _232);
89+
_53.u64_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u64_3[gl_LocalInvocationIndex], _232);
90+
_53.i64_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i64_4[gl_LocalInvocationIndex], _232);
91+
_53.f16_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_1[gl_LocalInvocationIndex], _232);
92+
_53.f16_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_2[gl_LocalInvocationIndex], _232);
93+
_53.f16_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_3[gl_LocalInvocationIndex], _232);
94+
_53.f16_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_4[gl_LocalInvocationIndex], _232);
95+
}
96+
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#version 450
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
3+
#extension GL_EXT_shader_8bit_storage : require
4+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
5+
#extension GL_EXT_shader_16bit_storage : require
6+
#if defined(GL_AMD_gpu_shader_half_float)
7+
#extension GL_AMD_gpu_shader_half_float : require
8+
#elif defined(GL_EXT_shader_explicit_arithmetic_types_float16)
9+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
10+
#else
11+
#error No extension available for FP16.
12+
#endif
13+
#if defined(GL_ARB_gpu_shader_int64)
14+
#extension GL_ARB_gpu_shader_int64 : require
15+
#else
16+
#error No extension available for 64-bit integers.
17+
#endif
18+
#extension GL_KHR_shader_subgroup_basic : require
19+
#extension GL_EXT_shader_subgroup_extended_types_int8 : require
20+
#extension GL_KHR_shader_subgroup_shuffle : require
21+
#extension GL_EXT_shader_subgroup_extended_types_int16 : require
22+
#extension GL_EXT_shader_subgroup_extended_types_int64 : require
23+
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
24+
layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in;
25+
26+
layout(set = 0, binding = 0, std430) buffer SSBO
27+
{
28+
uint8_t u8_1[4];
29+
i8vec2 i8_2[4];
30+
u8vec3 u8_3[4];
31+
i8vec4 i8_4[4];
32+
uint16_t u16_1[4];
33+
i16vec2 i16_2[4];
34+
u16vec3 u16_3[4];
35+
i16vec4 i16_4[4];
36+
float16_t f16_1[4];
37+
f16vec2 f16_2[4];
38+
f16vec3 f16_3[4];
39+
f16vec4 f16_4[4];
40+
double f64_1[4];
41+
double f64_2[4];
42+
double f64_3[4];
43+
double f64_4[4];
44+
uint64_t u64_1[4];
45+
i64vec2 i64_2[4];
46+
u64vec3 u64_3[4];
47+
i64vec4 i64_4[4];
48+
} _60;
49+
50+
void main()
51+
{
52+
mediump uint _277 = gl_SubgroupInvocationID ^ 1u;
53+
_60.u8_1[gl_LocalInvocationIndex] = subgroupShuffle(_60.u8_1[gl_LocalInvocationIndex], _277);
54+
_60.i8_2[gl_LocalInvocationIndex] = subgroupShuffle(_60.i8_2[gl_LocalInvocationIndex], _277);
55+
_60.u8_3[gl_LocalInvocationIndex] = subgroupShuffle(_60.u8_3[gl_LocalInvocationIndex], _277);
56+
_60.i8_4[gl_LocalInvocationIndex] = subgroupShuffle(_60.i8_4[gl_LocalInvocationIndex], _277);
57+
_60.u16_1[gl_LocalInvocationIndex] = subgroupShuffle(_60.u16_1[gl_LocalInvocationIndex], _277);
58+
_60.i16_2[gl_LocalInvocationIndex] = subgroupShuffle(_60.i16_2[gl_LocalInvocationIndex], _277);
59+
_60.u16_3[gl_LocalInvocationIndex] = subgroupShuffle(_60.u16_3[gl_LocalInvocationIndex], _277);
60+
_60.i16_4[gl_LocalInvocationIndex] = subgroupShuffle(_60.i16_4[gl_LocalInvocationIndex], _277);
61+
_60.u64_1[gl_LocalInvocationIndex] = subgroupShuffle(_60.u64_1[gl_LocalInvocationIndex], _277);
62+
_60.i64_2[gl_LocalInvocationIndex] = subgroupShuffle(_60.i64_2[gl_LocalInvocationIndex], _277);
63+
_60.u64_3[gl_LocalInvocationIndex] = subgroupShuffle(_60.u64_3[gl_LocalInvocationIndex], _277);
64+
_60.i64_4[gl_LocalInvocationIndex] = subgroupShuffle(_60.i64_4[gl_LocalInvocationIndex], _277);
65+
_60.f16_1[gl_LocalInvocationIndex] = subgroupShuffle(_60.f16_1[gl_LocalInvocationIndex], _277);
66+
_60.f16_2[gl_LocalInvocationIndex] = subgroupShuffle(_60.f16_2[gl_LocalInvocationIndex], _277);
67+
_60.f16_3[gl_LocalInvocationIndex] = subgroupShuffle(_60.f16_3[gl_LocalInvocationIndex], _277);
68+
_60.f16_4[gl_LocalInvocationIndex] = subgroupShuffle(_60.f16_4[gl_LocalInvocationIndex], _277);
69+
_60.f64_1[gl_LocalInvocationIndex] = subgroupShuffle(_60.f64_1[gl_LocalInvocationIndex], _277);
70+
_60.f64_2[gl_LocalInvocationIndex] = subgroupShuffle(_60.f64_2[gl_LocalInvocationIndex], _277);
71+
_60.f64_3[gl_LocalInvocationIndex] = subgroupShuffle(_60.f64_3[gl_LocalInvocationIndex], _277);
72+
_60.f64_4[gl_LocalInvocationIndex] = subgroupShuffle(_60.f64_4[gl_LocalInvocationIndex], _277);
73+
}
74+

reference/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl22.ios.comp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,34 @@ inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
124124
return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);
125125
}
126126

127+
template<>
128+
inline ulong spvSubgroupShuffle(ulong value, ushort lane)
129+
{
130+
return as_type<ulong>(spvSubgroupShuffle(as_type<uint2>(value), lane));
131+
}
132+
133+
template<>
134+
inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane)
135+
{
136+
return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane));
137+
}
138+
139+
inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane)
140+
{
141+
return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane));
142+
}
143+
144+
inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane)
145+
{
146+
return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane));
147+
}
148+
149+
template<uint N>
150+
inline vec<long, N> spvSubgroupShuffle(vec<long, N> value, ushort lane)
151+
{
152+
return vec<long, N>(spvSubgroupShuffle(vec<ulong, N>(value), lane));
153+
}
154+
127155
template<typename T>
128156
inline T spvSubgroupShuffleXor(T value, ushort mask)
129157
{

reference/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl23.ios.simd.comp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,34 @@ inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
124124
return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);
125125
}
126126

127+
template<>
128+
inline ulong spvSubgroupShuffle(ulong value, ushort lane)
129+
{
130+
return as_type<ulong>(spvSubgroupShuffle(as_type<uint2>(value), lane));
131+
}
132+
133+
template<>
134+
inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane)
135+
{
136+
return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane));
137+
}
138+
139+
inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane)
140+
{
141+
return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane));
142+
}
143+
144+
inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane)
145+
{
146+
return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane));
147+
}
148+
149+
template<uint N>
150+
inline vec<long, N> spvSubgroupShuffle(vec<long, N> value, ushort lane)
151+
{
152+
return vec<long, N>(spvSubgroupShuffle(vec<ulong, N>(value), lane));
153+
}
154+
127155
template<typename T>
128156
inline T spvSubgroupShuffleXor(T value, ushort mask)
129157
{

reference/shaders-msl-no-opt/frag/subgroups.nocompat.vk.msl22.frag

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,34 @@ inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
129129
return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);
130130
}
131131

132+
template<>
133+
inline ulong spvSubgroupShuffle(ulong value, ushort lane)
134+
{
135+
return as_type<ulong>(spvSubgroupShuffle(as_type<uint2>(value), lane));
136+
}
137+
138+
template<>
139+
inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane)
140+
{
141+
return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane));
142+
}
143+
144+
inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane)
145+
{
146+
return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane));
147+
}
148+
149+
inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane)
150+
{
151+
return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane));
152+
}
153+
154+
template<uint N>
155+
inline vec<long, N> spvSubgroupShuffle(vec<long, N> value, ushort lane)
156+
{
157+
return vec<long, N>(spvSubgroupShuffle(vec<ulong, N>(value), lane));
158+
}
159+
132160
template<typename T>
133161
inline T spvSubgroupShuffleXor(T value, ushort mask)
134162
{
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
3+
#include <metal_stdlib>
4+
#include <simd/simd.h>
5+
6+
using namespace metal;
7+
8+
template<typename T>
9+
inline T spvSubgroupShuffle(T value, ushort lane)
10+
{
11+
return simd_shuffle(value, lane);
12+
}
13+
14+
template<>
15+
inline bool spvSubgroupShuffle(bool value, ushort lane)
16+
{
17+
return !!simd_shuffle((ushort)value, lane);
18+
}
19+
20+
template<uint N>
21+
inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)
22+
{
23+
return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);
24+
}
25+
26+
template<>
27+
inline ulong spvSubgroupShuffle(ulong value, ushort lane)
28+
{
29+
return as_type<ulong>(spvSubgroupShuffle(as_type<uint2>(value), lane));
30+
}
31+
32+
template<>
33+
inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane)
34+
{
35+
return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane));
36+
}
37+
38+
inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane)
39+
{
40+
return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane));
41+
}
42+
43+
inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane)
44+
{
45+
return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane));
46+
}
47+
48+
template<uint N>
49+
inline vec<long, N> spvSubgroupShuffle(vec<long, N> value, ushort lane)
50+
{
51+
return vec<long, N>(spvSubgroupShuffle(vec<ulong, N>(value), lane));
52+
}
53+
54+
struct SSBO
55+
{
56+
uchar u8_1[4];
57+
char2 i8_2[4];
58+
uchar3 u8_3[4];
59+
char4 i8_4[4];
60+
ushort u16_1[4];
61+
short2 i16_2[4];
62+
ushort3 u16_3[4];
63+
short4 i16_4[4];
64+
half f16_1[4];
65+
half2 f16_2[4];
66+
half3 f16_3[4];
67+
half4 f16_4[4];
68+
ulong u64_1[4];
69+
long2 i64_2[4];
70+
ulong3 u64_3[4];
71+
long4 i64_4[4];
72+
};
73+
74+
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
75+
76+
static inline __attribute__((always_inline))
77+
void int80(device SSBO& _53, thread uint& gl_LocalInvocationIndex, thread uint& gl_SubgroupInvocationID)
78+
{
79+
_53.u8_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u8_1[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
80+
_53.i8_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i8_2[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
81+
_53.u8_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u8_3[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
82+
_53.i8_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i8_4[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
83+
}
84+
85+
static inline __attribute__((always_inline))
86+
void int160(device SSBO& _53, thread uint& gl_LocalInvocationIndex, thread uint& gl_SubgroupInvocationID)
87+
{
88+
_53.u16_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u16_1[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
89+
_53.i16_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i16_2[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
90+
_53.u16_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u16_3[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
91+
_53.i16_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i16_4[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
92+
}
93+
94+
static inline __attribute__((always_inline))
95+
void int64(device SSBO& _53, thread uint& gl_LocalInvocationIndex, thread uint& gl_SubgroupInvocationID)
96+
{
97+
_53.u64_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u64_1[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
98+
_53.i64_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i64_2[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
99+
_53.u64_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.u64_3[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
100+
_53.i64_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.i64_4[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
101+
}
102+
103+
static inline __attribute__((always_inline))
104+
void fp16(device SSBO& _53, thread uint& gl_LocalInvocationIndex, thread uint& gl_SubgroupInvocationID)
105+
{
106+
_53.f16_1[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_1[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
107+
_53.f16_2[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_2[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
108+
_53.f16_3[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_3[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
109+
_53.f16_4[gl_LocalInvocationIndex] = spvSubgroupShuffle(_53.f16_4[gl_LocalInvocationIndex], gl_SubgroupInvocationID ^ 1u);
110+
}
111+
112+
kernel void main0(device SSBO& _53 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
113+
{
114+
int80(_53, gl_LocalInvocationIndex, gl_SubgroupInvocationID);
115+
int160(_53, gl_LocalInvocationIndex, gl_SubgroupInvocationID);
116+
int64(_53, gl_LocalInvocationIndex, gl_SubgroupInvocationID);
117+
fp16(_53, gl_LocalInvocationIndex, gl_SubgroupInvocationID);
118+
}
119+

0 commit comments

Comments
 (0)