Skip to content

Commit c678f94

Browse files
Merge pull request #2518 from KhronosGroup/fix-2507
MSL: Fix issues with fp16 trancendentals.
2 parents a92254a + b9c63db commit c678f94

File tree

5 files changed

+100
-16
lines changed

5 files changed

+100
-16
lines changed

reference/opt/shaders-msl/frag/fp16-trancendentals.frag

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ struct main0_out
77
{
88
half C [[color(0)]];
99
half D [[color(1)]];
10+
half3 o3 [[color(2)]];
1011
};
1112

1213
struct main0_in
1314
{
1415
half A [[user(locn0)]];
1516
half B [[user(locn1)]];
17+
half3 v3 [[user(locn2)]];
1618
};
1719

1820
fragment main0_out main0(main0_in in [[stage_in]])
@@ -31,19 +33,19 @@ fragment main0_out main0(main0_in in [[stage_in]])
3133
out.D += out.C;
3234
out.C = clamp(atan(in.A), in.A, in.B);
3335
out.D += out.C;
34-
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
36+
out.C = clamp(half(fast::sinh(float(in.A))), in.A, in.B);
3537
out.D += out.C;
36-
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
38+
out.C = clamp(half(fast::cosh(float(in.A))), in.A, in.B);
3739
out.D += out.C;
38-
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
40+
out.C = clamp(half(fast::tanh(float(in.A))), in.A, in.B);
3941
out.D += out.C;
4042
out.C = clamp(asinh(in.A), in.A, in.B);
4143
out.D += out.C;
4244
out.C = clamp(acosh(in.A), in.A, in.B);
4345
out.D += out.C;
4446
out.C = clamp(atanh(in.A), in.A, in.B);
4547
out.D += out.C;
46-
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
48+
out.C = clamp(half(fast::atan2(float(in.A), float(in.B))), in.A, in.B);
4749
out.D += out.C;
4850
out.C = clamp(powr(in.A, in.B), in.A, in.B);
4951
out.D += out.C;
@@ -59,6 +61,26 @@ fragment main0_out main0(main0_in in [[stage_in]])
5961
out.D += out.C;
6062
out.C = clamp(rsqrt(in.A), in.A, in.B);
6163
out.D += out.C;
64+
out.o3 = half3(fast::sinh(float3(in.v3)));
65+
out.o3 += half3(fast::cosh(float3(in.v3)));
66+
out.o3 += half3(fast::tanh(float3(in.v3)));
67+
out.o3 += sin(in.v3);
68+
out.o3 += cos(in.v3);
69+
out.o3 += tan(in.v3);
70+
out.o3 += asin(in.v3);
71+
out.o3 += acos(in.v3);
72+
out.o3 += atan(in.v3);
73+
out.o3 += asinh(in.v3);
74+
out.o3 += acosh(in.v3);
75+
out.o3 += atanh(in.v3);
76+
out.o3 += half3(fast::atan2(float3(out.o3), float3(in.v3)));
77+
out.o3 += powr(out.o3, in.v3);
78+
out.o3 += exp(in.v3);
79+
out.o3 += exp2(in.v3);
80+
out.o3 += log(in.v3);
81+
out.o3 += log2(in.v3);
82+
out.o3 += sqrt(in.v3);
83+
out.o3 += rsqrt(in.v3);
6284
return out;
6385
}
6486

reference/shaders-msl-no-opt/frag/fp16.desktop.invalid.frag

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ void test_builtins(thread half4& v4, thread half3& v3, thread half& v1)
100100
res = cos(v4);
101101
res = tan(v4);
102102
res = asin(v4);
103-
res = half(fast::atan2(v4, v3.xyzz));
103+
res = half4(fast::atan2(float4(v4), float4(v3.xyzz)));
104104
res = atan(v4);
105-
res = half(fast::sinh(v4));
106-
res = half(fast::cosh(v4));
107-
res = half(fast::tanh(v4));
105+
res = half4(fast::sinh(float4(v4)));
106+
res = half4(fast::cosh(float4(v4)));
107+
res = half4(fast::tanh(float4(v4)));
108108
res = asinh(v4);
109109
res = acosh(v4);
110110
res = atanh(v4);

reference/shaders-msl/frag/fp16-trancendentals.frag

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ struct main0_out
77
{
88
half C [[color(0)]];
99
half D [[color(1)]];
10+
half3 o3 [[color(2)]];
1011
};
1112

1213
struct main0_in
1314
{
1415
half A [[user(locn0)]];
1516
half B [[user(locn1)]];
17+
half3 v3 [[user(locn2)]];
1618
};
1719

1820
fragment main0_out main0(main0_in in [[stage_in]])
@@ -31,19 +33,19 @@ fragment main0_out main0(main0_in in [[stage_in]])
3133
out.D += out.C;
3234
out.C = clamp(atan(in.A), in.A, in.B);
3335
out.D += out.C;
34-
out.C = clamp(half(fast::sinh(in.A)), in.A, in.B);
36+
out.C = clamp(half(fast::sinh(float(in.A))), in.A, in.B);
3537
out.D += out.C;
36-
out.C = clamp(half(fast::cosh(in.A)), in.A, in.B);
38+
out.C = clamp(half(fast::cosh(float(in.A))), in.A, in.B);
3739
out.D += out.C;
38-
out.C = clamp(half(fast::tanh(in.A)), in.A, in.B);
40+
out.C = clamp(half(fast::tanh(float(in.A))), in.A, in.B);
3941
out.D += out.C;
4042
out.C = clamp(asinh(in.A), in.A, in.B);
4143
out.D += out.C;
4244
out.C = clamp(acosh(in.A), in.A, in.B);
4345
out.D += out.C;
4446
out.C = clamp(atanh(in.A), in.A, in.B);
4547
out.D += out.C;
46-
out.C = clamp(half(fast::atan2(in.A, in.B)), in.A, in.B);
48+
out.C = clamp(half(fast::atan2(float(in.A), float(in.B))), in.A, in.B);
4749
out.D += out.C;
4850
out.C = clamp(powr(in.A, in.B), in.A, in.B);
4951
out.D += out.C;
@@ -59,6 +61,26 @@ fragment main0_out main0(main0_in in [[stage_in]])
5961
out.D += out.C;
6062
out.C = clamp(rsqrt(in.A), in.A, in.B);
6163
out.D += out.C;
64+
out.o3 = half3(fast::sinh(float3(in.v3)));
65+
out.o3 += half3(fast::cosh(float3(in.v3)));
66+
out.o3 += half3(fast::tanh(float3(in.v3)));
67+
out.o3 += sin(in.v3);
68+
out.o3 += cos(in.v3);
69+
out.o3 += tan(in.v3);
70+
out.o3 += asin(in.v3);
71+
out.o3 += acos(in.v3);
72+
out.o3 += atan(in.v3);
73+
out.o3 += asinh(in.v3);
74+
out.o3 += acosh(in.v3);
75+
out.o3 += atanh(in.v3);
76+
out.o3 += half3(fast::atan2(float3(out.o3), float3(in.v3)));
77+
out.o3 += powr(out.o3, in.v3);
78+
out.o3 += exp(in.v3);
79+
out.o3 += exp2(in.v3);
80+
out.o3 += log(in.v3);
81+
out.o3 += log2(in.v3);
82+
out.o3 += sqrt(in.v3);
83+
out.o3 += rsqrt(in.v3);
6284
return out;
6385
}
6486

shaders-msl/frag/fp16-trancendentals.frag

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ layout(location = 1) in float16_t B;
66
layout(location = 0) out float16_t C;
77
layout(location = 1) out float16_t D;
88

9+
layout(location = 2) in f16vec3 v3;
10+
layout(location = 2) out f16vec3 o3;
11+
912
void main()
1013
{
1114
D = 0.0hf;
@@ -29,5 +32,27 @@ void main()
2932
C = clamp(log2(A), A, B); D += C;
3033
C = clamp(sqrt(A), A, B); D += C;
3134
C = clamp(inversesqrt(A), A, B); D += C;
35+
36+
// Vector variants since it tripped up overload resolution.
37+
o3 = sinh(v3);
38+
o3 += cosh(v3);
39+
o3 += tanh(v3);
40+
o3 += sin(v3);
41+
o3 += cos(v3);
42+
o3 += tan(v3);
43+
o3 += asin(v3);
44+
o3 += acos(v3);
45+
o3 += atan(v3);
46+
o3 += asinh(v3);
47+
o3 += acosh(v3);
48+
o3 += atanh(v3);
49+
o3 += atan(o3, v3);
50+
o3 += pow(o3, v3);
51+
o3 += exp(v3);
52+
o3 += exp2(v3);
53+
o3 += log(v3);
54+
o3 += log2(v3);
55+
o3 += sqrt(v3);
56+
o3 += inversesqrt(v3);
3257
}
3358

spirv_msl.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11244,8 +11244,11 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
1124411244
case GLSLstd450Sinh:
1124511245
if (restype.basetype == SPIRType::Half)
1124611246
{
11247+
auto ftype = restype;
11248+
ftype.basetype = SPIRType::Float;
11249+
1124711250
// MSL does not have overload for half. Force-cast back to half.
11248-
auto expr = join("half(fast::sinh(", to_unpacked_expression(args[0]), "))");
11251+
auto expr = join(type_to_glsl(restype), "(fast::sinh(", type_to_glsl(ftype), "(", to_unpacked_expression(args[0]), ")))");
1124911252
emit_op(result_type, id, expr, should_forward(args[0]));
1125011253
inherit_expression_dependencies(id, args[0]);
1125111254
}
@@ -11255,8 +11258,11 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
1125511258
case GLSLstd450Cosh:
1125611259
if (restype.basetype == SPIRType::Half)
1125711260
{
11261+
auto ftype = restype;
11262+
ftype.basetype = SPIRType::Float;
11263+
1125811264
// MSL does not have overload for half. Force-cast back to half.
11259-
auto expr = join("half(fast::cosh(", to_unpacked_expression(args[0]), "))");
11265+
auto expr = join(type_to_glsl(restype), "(fast::cosh(", type_to_glsl(ftype), "(", to_unpacked_expression(args[0]), ")))");
1126011266
emit_op(result_type, id, expr, should_forward(args[0]));
1126111267
inherit_expression_dependencies(id, args[0]);
1126211268
}
@@ -11266,8 +11272,11 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
1126611272
case GLSLstd450Tanh:
1126711273
if (restype.basetype == SPIRType::Half)
1126811274
{
11275+
auto ftype = restype;
11276+
ftype.basetype = SPIRType::Float;
11277+
1126911278
// MSL does not have overload for half. Force-cast back to half.
11270-
auto expr = join("half(fast::tanh(", to_unpacked_expression(args[0]), "))");
11279+
auto expr = join(type_to_glsl(restype), "(fast::tanh(", type_to_glsl(ftype), "(", to_unpacked_expression(args[0]), ")))");
1127111280
emit_op(result_type, id, expr, should_forward(args[0]));
1127211281
inherit_expression_dependencies(id, args[0]);
1127311282
}
@@ -11278,7 +11287,13 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
1127811287
if (restype.basetype == SPIRType::Half)
1127911288
{
1128011289
// MSL does not have overload for half. Force-cast back to half.
11281-
auto expr = join("half(fast::atan2(", to_unpacked_expression(args[0]), ", ", to_unpacked_expression(args[1]), "))");
11290+
auto ftype = restype;
11291+
ftype.basetype = SPIRType::Float;
11292+
11293+
auto expr = join(type_to_glsl(restype),
11294+
"(fast::atan2(",
11295+
type_to_glsl(ftype), "(", to_unpacked_expression(args[0]), "), ",
11296+
type_to_glsl(ftype), "(", to_unpacked_expression(args[1]), ")))");
1128211297
emit_op(result_type, id, expr, should_forward(args[0]) && should_forward(args[1]));
1128311298
inherit_expression_dependencies(id, args[0]);
1128411299
inherit_expression_dependencies(id, args[1]);

0 commit comments

Comments
 (0)