Skip to content

Commit 49bc3f0

Browse files
committed
[LowerMemIntrinsics] Optimize memset lowering
This patch changes the memset lowering to match the optimized memcpy lowering. The memset lowering now queries TTI.getMemcpyLoopLoweringType for a preferred memory access type. If that type is larger than a byte, the memset is lowered into two loops: a main loop that stores a sufficiently wide vector splat of the SetValue with the preferred memory access type and a residual loop that covers the remaining bytes individually. If the memset size is statically known, the residual loop is replaced by a sequence of stores. This improves memset performance on gfx1030 (AMDGPU) in microbenchmarks by around 7-20x. I'm planning similar treatment for memset.pattern as a follow-up PR. For SWDEV-543208.
1 parent a17777e commit 49bc3f0

17 files changed

+4822
-301
lines changed

llvm/include/llvm/Transforms/Utils/LowerMemIntrinsics.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ LLVM_ABI bool expandMemMoveAsLoop(MemMoveInst *MemMove,
5959
const TargetTransformInfo &TTI);
6060

6161
/// Expand \p MemSet as a loop. \p MemSet is not deleted.
62-
LLVM_ABI void expandMemSetAsLoop(MemSetInst *MemSet);
62+
LLVM_ABI void expandMemSetAsLoop(MemSetInst *MemSet,
63+
const TargetTransformInfo &TTI);
6364

6465
/// Expand \p MemSetPattern as a loop. \p MemSet is not deleted.
6566
LLVM_ABI void expandMemSetPatternAsLoop(MemSetPatternInst *MemSet);

llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
369369
canEmitLibcall(TM, ParentFunc, RTLIB::MEMSET))
370370
break;
371371

372-
expandMemSetAsLoop(Memset);
372+
expandMemSetAsLoop(Memset, TTI);
373373
Changed = true;
374374
Memset->eraseFromParent();
375375
}
@@ -384,7 +384,9 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(
384384
if (isa<ConstantInt>(Memset->getLength()))
385385
break;
386386

387-
expandMemSetAsLoop(Memset);
387+
Function *ParentFunc = Memset->getFunction();
388+
const TargetTransformInfo &TTI = LookupTTI(*ParentFunc);
389+
expandMemSetAsLoop(Memset, TTI);
388390
Changed = true;
389391
Memset->eraseFromParent();
390392
break;

llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,8 @@ bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
635635
MemSetInst &MSI) {
636636
if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
637637
return false;
638-
llvm::expandMemSetAsLoop(&MSI);
638+
llvm::expandMemSetAsLoop(&MSI,
639+
TM->getTargetTransformInfo(*MSI.getFunction()));
639640
MSI.eraseFromParent();
640641
return true;
641642
}

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ static cl::opt<size_t> InlineMaxBB(
8080
static cl::opt<unsigned> MemcpyLoopUnroll(
8181
"amdgpu-memcpy-loop-unroll",
8282
cl::desc("Unroll factor (affecting 4x32-bit operations) to use for memory "
83-
"operations when lowering memcpy as a loop"),
83+
"operations when lowering statically-sized memcpy, memmove, or "
84+
"memset as a loop"),
8485
cl::init(16), cl::Hidden);
8586

8687
static bool dependsOnLocalPhi(const Loop *L, const Value *Cond,

llvm/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
128128
} else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
129129
expandMemMoveAsLoop(Memmove, TTI);
130130
} else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
131-
expandMemSetAsLoop(Memset);
131+
expandMemSetAsLoop(Memset, TTI);
132132
}
133133
MemCall->eraseFromParent();
134134
}

llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "SPIRVTargetMachine.h"
2424
#include "SPIRVUtils.h"
2525
#include "llvm/ADT/StringExtras.h"
26+
#include "llvm/Analysis/TargetTransformInfo.h"
2627
#include "llvm/Analysis/ValueTracking.h"
2728
#include "llvm/CodeGen/IntrinsicLowering.h"
2829
#include "llvm/IR/IRBuilder.h"
@@ -93,7 +94,8 @@ static Function *getOrCreateFunction(Module *M, Type *RetTy,
9394
return NewF;
9495
}
9596

96-
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
97+
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic,
98+
const TargetTransformInfo &TTI) {
9799
// For @llvm.memset.* intrinsic cases with constant value and length arguments
98100
// are emulated via "storing" a constant array to the destination. For other
99101
// cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
@@ -137,7 +139,7 @@ static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
137139
auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
138140
MSI->isVolatile());
139141
IRB.CreateRetVoid();
140-
expandMemSetAsLoop(cast<MemSetInst>(MemSet));
142+
expandMemSetAsLoop(cast<MemSetInst>(MemSet), TTI);
141143
MemSet->eraseFromParent();
142144
break;
143145
}
@@ -399,6 +401,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
399401
bool Changed = false;
400402
const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
401403
SmallVector<Instruction *> EraseFromParent;
404+
const TargetTransformInfo &TTI = TM.getTargetTransformInfo(*F);
402405
for (BasicBlock &BB : *F) {
403406
for (Instruction &I : make_early_inc_range(BB)) {
404407
auto Call = dyn_cast<CallInst>(&I);
@@ -411,7 +414,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
411414
switch (II->getIntrinsicID()) {
412415
case Intrinsic::memset:
413416
case Intrinsic::bswap:
414-
Changed |= lowerIntrinsicToFunction(II);
417+
Changed |= lowerIntrinsicToFunction(II, TTI);
415418
break;
416419
case Intrinsic::fshl:
417420
case Intrinsic::fshr:
@@ -459,7 +462,7 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
459462
return false;
460463
return II->getCalledFunction()->getName().starts_with(Prefix);
461464
}))
462-
Changed |= lowerIntrinsicToFunction(II);
465+
Changed |= lowerIntrinsicToFunction(II, TTI);
463466
break;
464467
}
465468
}

llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp

Lines changed: 197 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -929,9 +929,187 @@ static void createMemMoveLoopKnownSize(Instruction *InsertBefore,
929929
}
930930
}
931931

932+
/// Create a Value of \p DstType that consists of a sequence of copies of
933+
/// \p SetValue, using bitcasts and a vector splat.
934+
static Value *createMemSetSplat(const DataLayout &DL, IRBuilderBase &B,
935+
Value *SetValue, Type *DstType) {
936+
unsigned DstSize = DL.getTypeStoreSize(DstType);
937+
Type *SetValueType = SetValue->getType();
938+
unsigned SetValueSize = DL.getTypeStoreSize(SetValueType);
939+
assert(SetValueSize == DL.getTypeAllocSize(SetValueType) &&
940+
"Store size and alloc size of SetValue's type must match");
941+
assert(SetValueSize != 0 && DstSize % SetValueSize == 0 &&
942+
"DstType size must be a multiple of SetValue size");
943+
944+
Value *Result = SetValue;
945+
if (DstSize != SetValueSize) {
946+
if (!SetValueType->isIntegerTy() && !SetValueType->isFloatingPointTy()) {
947+
// If the type cannot be put into a vector, bitcast to iN first.
948+
LLVMContext &Ctx = SetValue->getContext();
949+
Result = B.CreateBitCast(Result, Type::getIntNTy(Ctx, SetValueSize * 8),
950+
"setvalue.toint");
951+
}
952+
// Form a sufficiently large vector consisting of SetValue, repeated.
953+
Result =
954+
B.CreateVectorSplat(DstSize / SetValueSize, Result, "setvalue.splat");
955+
}
956+
957+
// The value has the right size, but we might have to bitcast it to the right
958+
// type.
959+
if (Result->getType() != DstType) {
960+
Result = B.CreateBitCast(Result, DstType, "setvalue.splat.cast");
961+
}
962+
return Result;
963+
}
964+
965+
static void createMemSetLoopKnownSize(Instruction *InsertBefore, Value *DstAddr,
966+
ConstantInt *Len, Value *SetValue,
967+
Align DstAlign, bool IsVolatile,
968+
const TargetTransformInfo &TTI) {
969+
// No need to expand zero length memsets.
970+
if (Len->isZero())
971+
return;
972+
973+
BasicBlock *PreLoopBB = InsertBefore->getParent();
974+
Function *ParentFunc = PreLoopBB->getParent();
975+
const DataLayout &DL = ParentFunc->getDataLayout();
976+
LLVMContext &Ctx = PreLoopBB->getContext();
977+
978+
unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
979+
980+
Type *TypeOfLen = Len->getType();
981+
Type *Int8Type = Type::getInt8Ty(Ctx);
982+
assert(SetValue->getType() == Int8Type && "Can only set bytes");
983+
984+
// Use the same memory access type as for a memcpy with the same Dst and Src
985+
// alignment and address space.
986+
Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
987+
Ctx, Len, DstAS, DstAS, DstAlign, DstAlign, std::nullopt);
988+
unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
989+
990+
uint64_t LoopEndCount = alignDown(Len->getZExtValue(), LoopOpSize);
991+
992+
if (LoopEndCount != 0) {
993+
Value *SplatSetValue = nullptr;
994+
{
995+
IRBuilder<> PreLoopBuilder(InsertBefore);
996+
SplatSetValue =
997+
createMemSetSplat(DL, PreLoopBuilder, SetValue, LoopOpType);
998+
}
999+
1000+
// Don't generate a residual loop, the remaining bytes are set with
1001+
// straight-line code.
1002+
LoopExpansionInfo LEI =
1003+
insertLoopExpansion(InsertBefore, Len, LoopOpSize, 0, "static-memset");
1004+
1005+
// Fill MainLoopBB
1006+
IRBuilder<> MainLoopBuilder(LEI.MainLoopIP);
1007+
Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));
1008+
1009+
Value *DstGEP =
1010+
MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex);
1011+
1012+
MainLoopBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
1013+
IsVolatile);
1014+
1015+
assert(!LEI.ResidualLoopIP && !LEI.ResidualLoopIndex &&
1016+
"No residual loop was requested");
1017+
}
1018+
1019+
uint64_t BytesSet = LoopEndCount;
1020+
uint64_t RemainingBytes = Len->getZExtValue() - BytesSet;
1021+
if (RemainingBytes == 0)
1022+
return;
1023+
1024+
IRBuilder<> RBuilder(InsertBefore);
1025+
1026+
SmallVector<Type *, 5> RemainingOps;
1027+
TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes,
1028+
DstAS, DstAS, DstAlign, DstAlign,
1029+
std::nullopt);
1030+
1031+
Type *PreviousOpTy = nullptr;
1032+
Value *SplatSetValue = nullptr;
1033+
for (auto *OpTy : RemainingOps) {
1034+
unsigned OperandSize = DL.getTypeStoreSize(OpTy);
1035+
Align PartDstAlign(commonAlignment(DstAlign, BytesSet));
1036+
1037+
// Avoid recomputing the splat SetValue if it's the same as for the last
1038+
// iteration.
1039+
if (OpTy != PreviousOpTy)
1040+
SplatSetValue = createMemSetSplat(DL, RBuilder, SetValue, OpTy);
1041+
1042+
Value *DstGEP = RBuilder.CreateInBoundsGEP(
1043+
Int8Type, DstAddr, ConstantInt::get(TypeOfLen, BytesSet));
1044+
RBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
1045+
IsVolatile);
1046+
BytesSet += OperandSize;
1047+
PreviousOpTy = OpTy;
1048+
}
1049+
assert(BytesSet == Len->getZExtValue() &&
1050+
"Bytes set should match size in the call!");
1051+
}
1052+
1053+
static void createMemSetLoopUnknownSize(Instruction *InsertBefore,
1054+
Value *DstAddr, Value *Len,
1055+
Value *SetValue, Align DstAlign,
1056+
bool IsVolatile,
1057+
const TargetTransformInfo &TTI) {
1058+
BasicBlock *PreLoopBB = InsertBefore->getParent();
1059+
Function *ParentFunc = PreLoopBB->getParent();
1060+
const DataLayout &DL = ParentFunc->getDataLayout();
1061+
LLVMContext &Ctx = PreLoopBB->getContext();
1062+
1063+
unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
1064+
1065+
Type *Int8Type = Type::getInt8Ty(Ctx);
1066+
assert(SetValue->getType() == Int8Type && "Can only set bytes");
1067+
1068+
Type *LoopOpType = TTI.getMemcpyLoopLoweringType(
1069+
Ctx, Len, DstAS, DstAS, DstAlign, DstAlign, std::nullopt);
1070+
unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType);
1071+
1072+
Type *ResidualLoopOpType = Int8Type;
1073+
unsigned ResidualLoopOpSize = DL.getTypeStoreSize(ResidualLoopOpType);
1074+
1075+
Value *SplatSetValue = SetValue;
1076+
{
1077+
IRBuilder<> PreLoopBuilder(InsertBefore);
1078+
SplatSetValue = createMemSetSplat(DL, PreLoopBuilder, SetValue, LoopOpType);
1079+
}
1080+
1081+
LoopExpansionInfo LEI = insertLoopExpansion(
1082+
InsertBefore, Len, LoopOpSize, ResidualLoopOpSize, "dynamic-memset");
1083+
1084+
// Fill MainLoopBB
1085+
IRBuilder<> MainLoopBuilder(LEI.MainLoopIP);
1086+
Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize));
1087+
1088+
Value *DstGEP =
1089+
MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex);
1090+
MainLoopBuilder.CreateAlignedStore(SplatSetValue, DstGEP, PartDstAlign,
1091+
IsVolatile);
1092+
1093+
// Fill ResidualLoopBB
1094+
if (!LEI.ResidualLoopIP)
1095+
return;
1096+
1097+
Align ResDstAlign(commonAlignment(PartDstAlign, ResidualLoopOpSize));
1098+
1099+
IRBuilder<> ResLoopBuilder(LEI.ResidualLoopIP);
1100+
1101+
Value *ResDstGEP = ResLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr,
1102+
LEI.ResidualLoopIndex);
1103+
ResLoopBuilder.CreateAlignedStore(SetValue, ResDstGEP, ResDstAlign,
1104+
IsVolatile);
1105+
}
1106+
9321107
static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr,
9331108
Value *CopyLen, Value *SetValue, Align DstAlign,
9341109
bool IsVolatile) {
1110+
// Currently no longer used for memset, only for memset.pattern.
1111+
// TODO: Update the memset.pattern lowering to also use the loop expansion
1112+
// framework and remove this function.
9351113
Type *TypeOfCopyLen = CopyLen->getType();
9361114
BasicBlock *OrigBB = InsertBefore->getParent();
9371115
Function *F = OrigBB->getParent();
@@ -1066,13 +1244,25 @@ bool llvm::expandMemMoveAsLoop(MemMoveInst *Memmove,
10661244
return true;
10671245
}
10681246

1069-
void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
1070-
createMemSetLoop(/* InsertBefore */ Memset,
1071-
/* DstAddr */ Memset->getRawDest(),
1072-
/* CopyLen */ Memset->getLength(),
1073-
/* SetValue */ Memset->getValue(),
1074-
/* Alignment */ Memset->getDestAlign().valueOrOne(),
1075-
Memset->isVolatile());
1247+
void llvm::expandMemSetAsLoop(MemSetInst *Memset,
1248+
const TargetTransformInfo &TTI) {
1249+
if (ConstantInt *CI = dyn_cast<ConstantInt>(Memset->getLength())) {
1250+
createMemSetLoopKnownSize(
1251+
/* InsertBefore */ Memset,
1252+
/* DstAddr */ Memset->getRawDest(),
1253+
/* Len */ CI,
1254+
/* SetValue */ Memset->getValue(),
1255+
/* DstAlign */ Memset->getDestAlign().valueOrOne(),
1256+
Memset->isVolatile(), TTI);
1257+
} else {
1258+
createMemSetLoopUnknownSize(
1259+
/* InsertBefore */ Memset,
1260+
/* DstAddr */ Memset->getRawDest(),
1261+
/* Len */ Memset->getLength(),
1262+
/* SetValue */ Memset->getValue(),
1263+
/* DstAlign */ Memset->getDestAlign().valueOrOne(),
1264+
Memset->isVolatile(), TTI);
1265+
}
10761266
}
10771267

10781268
void llvm::expandMemSetPatternAsLoop(MemSetPatternInst *Memset) {

0 commit comments

Comments (0)