Skip to content

Commit 1edeecc

Browse files
committed
[HLSL][DXIL][SPIRV] WavePrefixSum: added unsigned DXIL intrinsic variant
- Added missing int_dx_wave_prefix_usum to DXIL.td - NFC change to HLSL CodeGen that introduced getUnsignedIntrinsicVariant so it can be used by the new GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED macro to pick the correct intrinsic based on the unsigned condition bool.
1 parent 0e77b1a commit 1edeecc

File tree

8 files changed

+162
-90
lines changed

8 files changed

+162
-90
lines changed

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 15 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -240,61 +240,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
240240
return RT.getFirstBitUHighIntrinsic();
241241
}
242242

243-
// Return wave active sum that corresponds to the QT scalar type
244-
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
245-
CGHLSLRuntime &RT, QualType QT) {
246-
switch (Arch) {
247-
case llvm::Triple::spirv:
248-
return Intrinsic::spv_wave_reduce_sum;
249-
case llvm::Triple::dxil: {
250-
if (QT->isUnsignedIntegerType())
251-
return Intrinsic::dx_wave_reduce_usum;
252-
return Intrinsic::dx_wave_reduce_sum;
253-
}
254-
default:
255-
llvm_unreachable("Intrinsic WaveActiveSum"
256-
" not supported by target architecture");
257-
}
258-
}
259-
260-
// Return wave active max that corresponds to the QT scalar type
261-
static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
262-
CGHLSLRuntime &RT, QualType QT) {
263-
switch (Arch) {
264-
case llvm::Triple::spirv:
265-
if (QT->isUnsignedIntegerType())
266-
return Intrinsic::spv_wave_reduce_umax;
267-
return Intrinsic::spv_wave_reduce_max;
268-
case llvm::Triple::dxil: {
269-
if (QT->isUnsignedIntegerType())
270-
return Intrinsic::dx_wave_reduce_umax;
271-
return Intrinsic::dx_wave_reduce_max;
272-
}
273-
default:
274-
llvm_unreachable("Intrinsic WaveActiveMax"
275-
" not supported by target architecture");
276-
}
277-
}
278-
279-
// Return wave active min that corresponds to the QT scalar type
280-
static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
281-
CGHLSLRuntime &RT, QualType QT) {
282-
switch (Arch) {
283-
case llvm::Triple::spirv:
284-
if (QT->isUnsignedIntegerType())
285-
return Intrinsic::spv_wave_reduce_umin;
286-
return Intrinsic::spv_wave_reduce_min;
287-
case llvm::Triple::dxil: {
288-
if (QT->isUnsignedIntegerType())
289-
return Intrinsic::dx_wave_reduce_umin;
290-
return Intrinsic::dx_wave_reduce_min;
291-
}
292-
default:
293-
llvm_unreachable("Intrinsic WaveActiveMin"
294-
" not supported by target architecture");
295-
}
296-
}
297-
298243
// Returns the mangled name for a builtin function that the SPIR-V backend
299244
// will expand into a spec Constant.
300245
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
@@ -794,33 +739,33 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
794739
ArrayRef{OpExpr});
795740
}
796741
case Builtin::BI__builtin_hlsl_wave_active_sum: {
797-
// Due to the use of variadic arguments, explicitly retreive argument
742+
// Due to the use of variadic arguments, explicitly retrieve argument
798743
Value *OpExpr = EmitScalarExpr(E->getArg(0));
799-
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
800-
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
801-
E->getArg(0)->getType());
744+
QualType QT = E->getArg(0)->getType();
745+
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveSumIntrinsic(
746+
QT->isUnsignedIntegerType());
802747

803748
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
804749
&CGM.getModule(), IID, {OpExpr->getType()}),
805750
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
806751
}
807752
case Builtin::BI__builtin_hlsl_wave_active_max: {
808-
// Due to the use of variadic arguments, explicitly retreive argument
753+
// Due to the use of variadic arguments, explicitly retrieve argument
809754
Value *OpExpr = EmitScalarExpr(E->getArg(0));
810-
Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
811-
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
812-
E->getArg(0)->getType());
755+
QualType QT = E->getArg(0)->getType();
756+
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMaxIntrinsic(
757+
QT->isUnsignedIntegerType());
813758

814759
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
815760
&CGM.getModule(), IID, {OpExpr->getType()}),
816761
ArrayRef{OpExpr}, "hlsl.wave.active.max");
817762
}
818763
case Builtin::BI__builtin_hlsl_wave_active_min: {
819-
// Due to the use of variadic arguments, explicitly retreive argument
764+
// Due to the use of variadic arguments, explicitly retrieve argument
820765
Value *OpExpr = EmitScalarExpr(E->getArg(0));
821-
Intrinsic::ID IID = getWaveActiveMinIntrinsic(
822-
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
823-
E->getArg(0)->getType());
766+
QualType QT = E->getArg(0)->getType();
767+
Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveMinIntrinsic(
768+
QT->isUnsignedIntegerType());
824769

825770
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
826771
&CGM.getModule(), IID, {OpExpr->getType()}),
@@ -866,7 +811,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
866811
}
867812
case Builtin::BI__builtin_hlsl_wave_prefix_sum: {
868813
Value *OpExpr = EmitScalarExpr(E->getArg(0));
869-
Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic();
814+
QualType QT = E->getArg(0)->getType();
815+
Intrinsic::ID IID = CGM.getHLSLRuntime().getWavePrefixSumIntrinsic(
816+
QT->isUnsignedIntegerType());
870817
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
871818
&CGM.getModule(), IID, {OpExpr->getType()}),
872819
ArrayRef{OpExpr}, "hlsl.wave.prefix.sum");

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,29 @@ llvm::Triple::ArchType CGHLSLRuntime::getArch() {
276276
return CGM.getTarget().getTriple().getArch();
277277
}
278278

279+
llvm::Intrinsic::ID
280+
CGHLSLRuntime::getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID) {
281+
switch (IID) {
282+
// DXIL intrinsics
283+
case Intrinsic::dx_wave_reduce_sum:
284+
return Intrinsic::dx_wave_reduce_usum;
285+
case Intrinsic::dx_wave_reduce_max:
286+
return Intrinsic::dx_wave_reduce_umax;
287+
case Intrinsic::dx_wave_reduce_min:
288+
return Intrinsic::dx_wave_reduce_umin;
289+
case Intrinsic::dx_wave_prefix_sum:
290+
return Intrinsic::dx_wave_prefix_usum;
291+
292+
// SPIR-V intrinsics
293+
case Intrinsic::spv_wave_reduce_max:
294+
return Intrinsic::spv_wave_reduce_umax;
295+
case Intrinsic::spv_wave_reduce_min:
296+
return Intrinsic::spv_wave_reduce_umin;
297+
default:
298+
return IID;
299+
}
300+
}
301+
279302
// Emits constant global variables for buffer constants declarations
280303
// and creates metadata linking the constant globals with the buffer global.
281304
void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,33 @@
4949
} \
5050
}
5151

52+
// A function generator macro for picking the right intrinsic for the target
53+
// backend given IsUnsigned boolean condition. If IsUnsigned == true, it calls
54+
// getUnsignedIntrinsicVariant(IID) to retrieve the unsigned variant of the
55+
// intrinsic else the regular intrinsic is returned. (NOTE:
56+
// getUnsignedIntrinsicVariant returns IID itself if there is no unsigned
57+
// variant).
58+
#define GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(FunctionName, \
59+
IntrinsicPostfix) \
60+
llvm::Intrinsic::ID get##FunctionName##Intrinsic(bool IsUnsigned) { \
61+
llvm::Triple::ArchType Arch = getArch(); \
62+
switch (Arch) { \
63+
case llvm::Triple::dxil: { \
64+
static constexpr llvm::Intrinsic::ID IID = \
65+
llvm::Intrinsic::dx_##IntrinsicPostfix; \
66+
return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \
67+
} \
68+
case llvm::Triple::spirv: { \
69+
static constexpr llvm::Intrinsic::ID IID = \
70+
llvm::Intrinsic::spv_##IntrinsicPostfix; \
71+
return IsUnsigned ? getUnsignedIntrinsicVariant(IID) : IID; \
72+
} \
73+
default: \
74+
llvm_unreachable("Intrinsic " #IntrinsicPostfix \
75+
" not supported by target architecture"); \
76+
} \
77+
}
78+
5279
using ResourceClass = llvm::dxil::ResourceClass;
5380

5481
namespace llvm {
@@ -141,10 +168,17 @@ class CGHLSLRuntime {
141168
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
142169
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
143170
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
171+
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveSum,
172+
wave_reduce_sum)
173+
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMax,
174+
wave_reduce_max)
175+
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WaveActiveMin,
176+
wave_reduce_min)
144177
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
145178
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
146179
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
147-
GENERATE_HLSL_INTRINSIC_FUNCTION(WavePrefixSum, wave_prefix_sum)
180+
GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED(WavePrefixSum,
181+
wave_prefix_sum)
148182
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
149183
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
150184
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitLow, firstbitlow)
@@ -247,6 +281,10 @@ class CGHLSLRuntime {
247281

248282
llvm::Triple::ArchType getArch();
249283

284+
// Returns the unsigned variant of the given intrinsic ID if possible,
285+
// otherwise, the original intrinsic ID is returned.
286+
llvm::Intrinsic::ID getUnsignedIntrinsicVariant(llvm::Intrinsic::ID IID);
287+
250288
llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
251289
unsigned SPIRVLastAssignedInputSemanticLocation = 0;
252290
};

clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ int test_int(int expr) {
2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
2323
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.prefix.sum.i64([[TY]] %[[#]])
24-
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.sum.i64([[TY]] %[[#]])
24+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.usum.i64([[TY]] %[[#]])
2525
// CHECK: ret [[TY]] %[[RET]]
2626
return WavePrefixSum(expr);
2727
}
2828

29-
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.sum.i64([[TY]]) #[[#attr:]]
29+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.prefix.usum.i64([[TY]]) #[[#attr:]]
3030
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.prefix.sum.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0
165165
def int_dx_wave_get_lane_count
166166
: DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent]>;
167167
def int_dx_wave_prefix_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
168+
def int_dx_wave_prefix_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
168169
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
169170
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
170171
def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>],

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,11 +1084,17 @@ def WavePrefixOp : DXILOp<121, wavePrefixOp> {
10841084
let intrinsics = [
10851085
IntrinSelect<int_dx_wave_prefix_sum,
10861086
[
1087-
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>
1088-
]>
1087+
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>,
1088+
IntrinArgI8<SignedOpKind_Signed>
1089+
]>,
1090+
IntrinSelect<int_dx_wave_prefix_usum,
1091+
[
1092+
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>,
1093+
IntrinArgI8<SignedOpKind_Unsigned>
1094+
]>,
10891095
];
10901096

1091-
let arguments = [OverloadTy, Int8Ty];
1097+
let arguments = [OverloadTy, Int8Ty, Int8Ty];
10921098
let result = OverloadTy;
10931099
let overloads = [
10941100
Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int16Ty, Int32Ty, Int64Ty]>

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
5656
case Intrinsic::dx_saturate:
5757
case Intrinsic::dx_splitdouble:
5858
case Intrinsic::dx_wave_readlane:
59-
case Intrinsic::dx_wave_prefix_sum:
6059
case Intrinsic::dx_wave_reduce_max:
6160
case Intrinsic::dx_wave_reduce_min:
6261
case Intrinsic::dx_wave_reduce_sum:
62+
case Intrinsic::dx_wave_prefix_sum:
6363
case Intrinsic::dx_wave_reduce_umax:
6464
case Intrinsic::dx_wave_reduce_umin:
6565
case Intrinsic::dx_wave_reduce_usum:
66+
case Intrinsic::dx_wave_prefix_usum:
6667
case Intrinsic::dx_imad:
6768
case Intrinsic::dx_umad:
6869
return true;

0 commit comments

Comments
 (0)