Skip to content

Commit 3422b79

Browse files
committed
[LowerMemIntrinsics] Optimize memset lowering
This patch changes the memset lowering to match the optimized memcpy lowering. The memset lowering now queries TTI.getMemcpyLoopLoweringType for a preferred memory access type. If that type is larger than a byte, the memset is lowered into two loops: a main loop that stores a sufficiently wide vector splat of the SetValue with the preferred memory access type and a residual loop that covers the remaining bytes individually. If the memset size is statically known, the residual loop is replaced by a sequence of stores. This improves memset performance on gfx1030 (AMDGPU) in microbenchmarks by around 7-20x. I'm planning similar treatment for memset.pattern as a follow-up PR. For SWDEV-543208.
1 parent 8a88119 commit 3422b79

17 files changed

+4826
-301
lines changed

llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ LLVM_ABI bool expandMemMoveAsLoop(MemMoveInst *MemMove,
5959
const TargetTransformInfo &TTI);
6060

6161
/// Expand \p MemSet as a loop. \p MemSet is not deleted.
62-
LLVM_ABI void expandMemSetAsLoop(MemSetInst *MemSet);
62+
LLVM_ABI void expandMemSetAsLoop(MemSetInst *MemSet,
63+
const TargetTransformInfo &TTI);
6364

6465
/// Expand \p MemSetPattern as a loop. \p MemSet is not deleted.
6566
LLVM_ABI void expandMemSetPatternAsLoop(MemSetPatternInst *MemSet);

llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
369369
canEmitLibcall(TM, ParentFunc, RTLIB::MEMSET))
370370
break;
371371

372-
expandMemSetAsLoop(Memset);
372+
expandMemSetAsLoop(Memset, TTI);
373373
Changed = true;
374374
Memset->eraseFromParent();
375375
}
@@ -384,7 +384,9 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
384384
if (isa<ConstantInt>(Memset->getLength()))
385385
break;
386386

387-
expandMemSetAsLoop(Memset);
387+
Function *ParentFunc = Memset->getFunction();
388+
const TargetTransformInfo &TTI = LookupTTI(*ParentFunc);
389+
expandMemSetAsLoop(Memset, TTI);
388390
Changed = true;
389391
Memset->eraseFromParent();
390392
break;

llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,8 @@ bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
635635
MemSetInst &MSI) {
636636
if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
637637
return false;
638-
llvm::expandMemSetAsLoop(&MSI);
638+
llvm::expandMemSetAsLoop(&MSI,
639+
TM->getTargetTransformInfo(*MSI.getFunction()));
639640
MSI.eraseFromParent();
640641
return true;
641642
}

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ static cl::opt<size_t> InlineMaxBB(
8080
// Unroll factor (in units of 4x32-bit operations) applied to the main loop
// emitted when lowering a statically-sized memcpy/memmove/memset as a loop.
// Note the trailing space before the literal continuation: without it the
// concatenated help text would read "...memmove, ormemset...".
static cl::opt<unsigned> MemcpyLoopUnroll(
    "amdgpu-memcpy-loop-unroll",
    cl::desc("Unroll factor (affecting 4x32-bit operations) to use for memory "
             "operations when lowering statically-sized memcpy, memmove, or "
             "memset as a loop"),
    cl::init(16), cl::Hidden);
8586

8687
static bool dependsOnLocalPhi(const Loop *L, const Value *Cond,

llvm/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
128128
} else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
129129
expandMemMoveAsLoop(Memmove, TTI);
130130
} else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
131-
expandMemSetAsLoop(Memset);
131+
expandMemSetAsLoop(Memset, TTI);
132132
}
133133
MemCall->eraseFromParent();
134134
}

llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "SPIRVTargetMachine.h"
2424
#include "SPIRVUtils.h"
2525
#include "llvm/ADT/StringExtras.h"
26+
#include "llvm/Analysis/TargetTransformInfo.h"
2627
#include "llvm/Analysis/ValueTracking.h"
2728
#include "llvm/CodeGen/IntrinsicLowering.h"
2829
#include "llvm/IR/IRBuilder.h"
@@ -93,7 +94,8 @@ static Function *getOrCreateFunction(Module *M, Type *RetTy,
9394
return NewF;
9495
}
9596

96-
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
97+
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic,
98+
const TargetTransformInfo &TTI) {
9799
// For @llvm.memset.* intrinsic cases with constant value and length arguments
98100
// are emulated via "storing" a constant array to the destination. For other
99101
// cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
@@ -137,7 +139,7 @@ static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
137139
auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
138140
MSI->isVolatile());
139141
IRB.CreateRetVoid();
140-
expandMemSetAsLoop(cast<MemSetInst>(MemSet));
142+
expandMemSetAsLoop(cast<MemSetInst>(MemSet), TTI);
141143
MemSet->eraseFromParent();
142144
break;
143145
}
@@ -399,6 +401,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
399401
bool Changed = false;
400402
const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
401403
SmallVector<Instruction *> EraseFromParent;
404+
const TargetTransformInfo &TTI = TM.getTargetTransformInfo(*F);
402405
for (BasicBlock &BB : *F) {
403406
for (Instruction &I : make_early_inc_range(BB)) {
404407
auto Call = dyn_cast<CallInst>(&I);
@@ -411,7 +414,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
411414
switch (II->getIntrinsicID()) {
412415
case Intrinsic::memset:
413416
case Intrinsic::bswap:
414-
Changed |= lowerIntrinsicToFunction(II);
417+
Changed |= lowerIntrinsicToFunction(II, TTI);
415418
break;
416419
case Intrinsic::fshl:
417420
case Intrinsic::fshr:
@@ -459,7 +462,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
459462
return false;
460463
return II->getCalledFunction()->getName().starts_with(Prefix);
461464
}))
462-
Changed |= lowerIntrinsicToFunction(II);
465+
Changed |= lowerIntrinsicToFunction(II, TTI);
463466
break;
464467
}
465468
}

llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp

Lines changed: 197 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,187 @@ static void createMemMoveLoopKnownSize(Instruction *InsertBefore,
930930
}
931931
}
932932

933+
/// Create a Value of \p DstType that consists of a sequence of copies of
934+
/// \p SetValue, using bitcasts and a vector splat.
935+
static Value *createMemSetSplat(const DataLayout &DL, IRBuilderBase &B,
936+
Value *SetValue, Type *DstType) {
937+
unsigned DstSize = DL.getTypeStoreSize(DstType);
938+
Type *SetValueType = SetValue->getType();
939+
unsigned SetValueSize = DL.getTypeStoreSize(SetValueType);
940+
assert(SetValueSize == DL.getTypeAllocSize(SetValueType) &&
941+
"Store size and alloc size of SetValue's type must match");
942+
assert(SetValueSize != 0 && DstSize % SetValueSize == 0 &&
943+
"DstType size must be a multiple of SetValue size");
944+
945+
Value *Result = SetValue;
946+
if (DstSize != SetValueSize) {
947+
if (!SetValueType->isIntegerTy() && !SetValueType->isFloatingPointTy()) {
948+
// If the type cannot be put into a vector, bitcast to iN first.
949+
LLVMContext &Ctx = SetValue->getContext();
950+
Result = B.CreateBitCast(Result, Type::getIntNTy(Ctx, SetValueSize * 8),
951+
"setvalue.toint");
952+
}
953+
// Form a sufficiently large vector consisting of SetValue, repeated.
954+
Result =
955+
B.CreateVectorSplat(DstSize / SetValueSize, Result, "setvalue.splat");
956+
}
957+
958+
// The value has the right size, but we might have to bitcast it to the right
959+
// type.
960+
if (Result->getType() != DstType) {
961+
Result = B.CreateBitCast(Result, DstType, "setvalue.splat.cast");
962+
}
963+
return Result;
964+
}
965+
966+
/// Lower a memset with a statically known length \p Len into a main loop that
/// stores a splat of \p SetValue with the target-preferred access type, plus
/// straight-line residual stores for the remaining bytes (no residual loop).
/// \p SetValue must be an i8; \p InsertBefore anchors all emitted IR.
static void createMemSetLoopKnownSize(Instruction *InsertBefore, Value *DstAddr,
                                      ConstantInt *Len, Value *SetValue,
                                      Align DstAlign, bool IsVolatile,
                                      const TargetTransformInfo &TTI) {
  // No need to expand zero length memsets.
  if (Len->isZero())
    return;

  BasicBlock *PreLoopBB = InsertBefore->getParent();
  Function *ParentFunc = PreLoopBB->getParent();
  const DataLayout &DL = ParentFunc->getDataLayout();
  LLVMContext &Ctx = PreLoopBB->getContext();

  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();

  Type *TypeOfLen = Len->getType();
  Type *Int8Type = Type::getInt8Ty(Ctx);
  assert(SetValue->getType() == Int8Type && "Can only set bytes");

  // Use the same memory access type as for a memcpy with the same Dst and Src
  // alignment and address space.
  Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
      Ctx, Len, DstAS, DstAS, DstAlign, DstAlign, std::nullopt);
  unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);

  // Bytes covered by whole LoopOpSize-sized stores; the rest is residual.
  uint64_t LoopEndCount = alignDown(Len->getZExtValue(), LoopOpSize);

  if (LoopEndCount != 0) {
    // Materialize the splat once, before the loop, so it is loop-invariant.
    Value *SplatSetValue = nullptr;
    {
      IRBuilder<> PreLoopBuilder(InsertBefore);
      SplatSetValue =
          createMemSetSplat(DL, PreLoopBuilder, SetValue, LoopOpType);
    }

    // Don't generate a residual loop, the remaining bytes are set with
    // straight-line code.
    LoopExpansionInfo LEI =
        insertLoopExpansion(InsertBefore, Len, LoopOpSize, 0, "static-memset");

    // Fill MainLoopBB
    IRBuilder<> MainLoopBuilder(LEI.MainLoopIP);
    // Alignment of each wide store: the common alignment of the destination
    // and the access width.
    Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));

    Value *DstGEP =
        MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex);

    MainLoopBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
                                       IsVolatile);

    assert(!LEI.ResidualLoopIP && !LEI.ResidualLoopIndex &&
           "No residual loop was requested");
  }

  uint64_t BytesSet = LoopEndCount;
  uint64_t RemainingBytes = Len->getZExtValue() - BytesSet;
  if (RemainingBytes == 0)
    return;

  // Emit straight-line stores for the residual bytes using the operand type
  // sequence the target prefers for memcpy residuals.
  IRBuilder<> RBuilder(InsertBefore);

  SmallVector<Type *, 5> RemainingOps;
  TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
                                        DstAS, DstAS, DstAlign, DstAlign,
                                        std::nullopt);

  Type *PreviousOpTy = nullptr;
  Value *SplatSetValue = nullptr;
  for (auto *OpTy : RemainingOps) {
    unsigned OperandSize = DL.getTypeStoreSize(OpTy);
    // Alignment at the current offset into the destination.
    Align PartDstAlign(commonAlignment(DstAlign, BytesSet));

    // Avoid recomputing the splat SetValue if it's the same as for the last
    // iteration.
    if (OpTy != PreviousOpTy)
      SplatSetValue = createMemSetSplat(DL, RBuilder, SetValue, OpTy);

    Value *DstGEP = RBuilder.CreateInBoundsGEP(
        Int8Type, DstAddr, ConstantInt::get(TypeOfLen, BytesSet));
    RBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
                                IsVolatile);
    BytesSet += OperandSize;
    PreviousOpTy = OpTy;
  }
  assert(BytesSet == Len->getZExtValue() &&
         "Bytes set should match size in the call!");
}
1053+
1054+
/// Lower a memset with a runtime-variable length \p Len into a main loop that
/// stores a splat of \p SetValue with the target-preferred access type,
/// followed by a byte-wise residual loop for the leftover bytes.
/// \p SetValue must be an i8; \p InsertBefore anchors all emitted IR.
static void createMemSetLoopUnknownSize(Instruction *InsertBefore,
                                        Value *DstAddr, Value *Len,
                                        Value *SetValue, Align DstAlign,
                                        bool IsVolatile,
                                        const TargetTransformInfo &TTI) {
  BasicBlock *PreLoopBB = InsertBefore->getParent();
  Function *ParentFunc = PreLoopBB->getParent();
  const DataLayout &DL = ParentFunc->getDataLayout();
  LLVMContext &Ctx = PreLoopBB->getContext();

  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();

  Type *Int8Type = Type::getInt8Ty(Ctx);
  assert(SetValue->getType() == Int8Type && "Can only set bytes");

  // Use the same memory access type as for a memcpy with the same Dst and Src
  // alignment and address space.
  Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
      Ctx, Len, DstAS, DstAS, DstAlign, DstAlign, std::nullopt);
  unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);

  // The residual loop sets the trailing bytes one i8 at a time.
  Type *ResidualLoopOpType = Int8Type;
  unsigned ResidualLoopOpSize = DL.getTypeStoreSize(ResidualLoopOpType);

  // Materialize the splat once, before the loop, so it is loop-invariant.
  // (Initialize to nullptr, matching createMemSetLoopKnownSize; the previous
  // SetValue initializer was dead, as it is always overwritten below.)
  Value *SplatSetValue = nullptr;
  {
    IRBuilder<> PreLoopBuilder(InsertBefore);
    SplatSetValue = createMemSetSplat(DL, PreLoopBuilder, SetValue, LoopOpType);
  }

  LoopExpansionInfo LEI = insertLoopExpansion(
      InsertBefore, Len, LoopOpSize, ResidualLoopOpSize, "dynamic-memset");

  // Fill MainLoopBB
  IRBuilder<> MainLoopBuilder(LEI.MainLoopIP);
  // Alignment of each wide store: the common alignment of the destination and
  // the access width.
  Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));

  Value *DstGEP =
      MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex);
  MainLoopBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
                                     IsVolatile);

  // Fill ResidualLoopBB; the expansion may omit it (e.g. when no residual
  // iterations are possible).
  if (!LEI.ResidualLoopIP)
    return;

  Align ResDstAlign(commonAlignment(PartDstAlign, ResidualLoopOpSize));

  IRBuilder<> ResLoopBuilder(LEI.ResidualLoopIP);

  Value *ResDstGEP = ResLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
                                                      LEI.ResidualLoopIndex);
  // Residual stores write the original (unsplatted) i8 SetValue.
  ResLoopBuilder.CreateAlignedStore(SetValue, ResDstGEP, ResDstAlign,
                                    IsVolatile);
}
1107+
9331108
static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr,
9341109
Value *CopyLen, Value *SetValue, Align DstAlign,
9351110
bool IsVolatile) {
1111+
// Currently no longer used for memset, only for memset.pattern.
1112+
// TODO: Update the memset.pattern lowering to also use the loop expansion
1113+
// framework and remove this function.
9361114
Type *TypeOfCopyLen = CopyLen->getType();
9371115
BasicBlock *OrigBB = InsertBefore->getParent();
9381116
Function *F = OrigBB->getParent();
@@ -1067,13 +1245,25 @@ bool llvm::expandMemMoveAsLoop(MemMoveInst *Memmove,
10671245
return true;
10681246
}
10691247

1070-
void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
1071-
createMemSetLoop(/* InsertBefore */ Memset,
1072-
/* DstAddr */ Memset->getRawDest(),
1073-
/* CopyLen */ Memset->getLength(),
1074-
/* SetValue */ Memset->getValue(),
1075-
/* Alignment */ Memset->getDestAlign().valueOrOne(),
1076-
Memset->isVolatile());
1248+
void llvm::expandMemSetAsLoop(MemSetInst *Memset,
                              const TargetTransformInfo &TTI) {
  // Operands shared by both lowering strategies.
  Value *Dst = Memset->getRawDest();
  Value *Len = Memset->getLength();
  Value *Val = Memset->getValue();
  Align DstAlign = Memset->getDestAlign().valueOrOne();
  bool Volatile = Memset->isVolatile();

  // A statically known length lets us replace the residual loop with
  // straight-line stores.
  if (auto *ConstLen = dyn_cast<ConstantInt>(Len)) {
    createMemSetLoopKnownSize(/* InsertBefore */ Memset, Dst, ConstLen, Val,
                              DstAlign, Volatile, TTI);
    return;
  }

  createMemSetLoopUnknownSize(/* InsertBefore */ Memset, Dst, Len, Val,
                              DstAlign, Volatile, TTI);
}
10781268

10791269
void llvm::expandMemSetPatternAsLoop(MemSetPatternInst *Memset) {

0 commit comments

Comments
 (0)