From 86497785d540e59eaca24bed4219ddec183cbc9b Mon Sep 17 00:00:00 2001
From: Florian Hahn
Date: Mon, 24 May 2021 09:19:40 +0100
Subject: [PATCH] [VectorCombine] Scalarize vector load/extract.

This patch adds a new combine that tries to scalarize chains of
`extractelement (load %ptr), %idx` to `load (gep %ptr, %idx)`. This is
profitable when extracting only a few elements out of a large vector.

At the moment, `store (extractelement (load %ptr), %idx), %ptr`
operations on large vectors result in huge code in the backend. This
can easily be triggered by using the matrix extension, e.g.
https://clang.godbolt.org/z/qsccPdPf4

This should complement D98240.
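As a sketch of the rewrite (modeled on the
load_multiple_extracts_with_constant_idx_profitable test below; the
value names are illustrative), extracts of a wide loaded vector such as

  %lv = load <8 x i32>, <8 x i32>* %x, align 16
  %e.0 = extractelement <8 x i32> %lv, i32 0
  %e.1 = extractelement <8 x i32> %lv, i32 6

become narrow scalar loads through inbounds GEPs. Index 0 keeps the
original load's alignment; other indices get the common alignment of
the original load and the scalar element type:

  %gep.0 = getelementptr inbounds <8 x i32>, <8 x i32>* %x, i32 0, i32 0
  %e.0 = load i32, i32* %gep.0, align 16
  %gep.1 = getelementptr inbounds <8 x i32>, <8 x i32>* %x, i32 0, i32 6
  %e.1 = load i32, i32* %gep.1, align 4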
Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D100273
---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp    | 100 ++++++++++++++++++++-
 .../AArch64/load-extractelement-scalarization.ll   |  99 +++++++++++++-------
 .../VectorCombine/X86/load-inseltpoison.ll         |   4 +-
 llvm/test/Transforms/VectorCombine/X86/load.ll     |   4 +-
 4 files changed, 168 insertions(+), 39 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index ed6709e..dc5b178 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -89,6 +89,7 @@ private:
   bool scalarizeBinopOrCmp(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
+  bool scalarizeLoadExtract(Instruction &I);
 };
 } // namespace
 
@@ -771,6 +772,12 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
   });
 }
 
+/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
+/// Idx. \p Idx must access a valid vector element.
+static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) {
+  return Idx->getValue().ult(VecTy->getNumElements());
+}
+
 // Combine patterns like:
 //   %0 = load <4 x i32>, <4 x i32>* %a
 //   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
@@ -803,7 +810,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        Idx->uge(VecTy->getNumElements()) ||
+        !canScalarizeAccess(VecTy, Idx) ||
         SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
@@ -825,6 +832,96 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
+  Value *Ptr;
+  ConstantInt *Idx;
+  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx))))
+    return false;
+
+  auto *LI = cast<LoadInst>(I.getOperand(0));
+  const DataLayout &DL = I.getModule()->getDataLayout();
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
+    return false;
+
+  auto *FixedVT = dyn_cast<FixedVectorType>(LI->getType());
+  if (!FixedVT)
+    return false;
+
+  if (!canScalarizeAccess(FixedVT, Idx))
+    return false;
+
+  InstructionCost OriginalCost = TTI.getMemoryOpCost(
+      Instruction::Load, LI->getType(), Align(LI->getAlignment()),
+      LI->getPointerAddressSpace());
+  InstructionCost ScalarizedCost = 0;
+
+  Instruction *LastCheckedInst = LI;
+  unsigned NumInstChecked = 0;
+  // Check if all users of the load are extracts with no memory modifications
+  // between the load and the extract. Compute the cost of both the original
+  // code and the scalarized version.
+  for (User *U : LI->users()) {
+    auto *UI = dyn_cast<ExtractElementInst>(U);
+    if (!UI || UI->getParent() != LI->getParent())
+      return false;
+
+    // Check if any instruction between the load and the extract may modify
+    // memory.
+    if (LastCheckedInst->comesBefore(UI)) {
+      for (Instruction &I :
+           make_range(std::next(LI->getIterator()), UI->getIterator())) {
+        // Bail out if we reached the check limit or the instruction may write
+        // to memory.
+        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
+          return false;
+        NumInstChecked++;
+      }
+    }
+
+    if (!LastCheckedInst)
+      LastCheckedInst = UI;
+    else if (LastCheckedInst->comesBefore(UI))
+      LastCheckedInst = UI;
+
+    auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
+    OriginalCost +=
+        TTI.getVectorInstrCost(Instruction::ExtractElement, LI->getType(),
+                               Index ? Index->getZExtValue() : -1);
+    ScalarizedCost +=
+        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+                            Align(1), LI->getPointerAddressSpace());
+    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+  }
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Replace extracts with narrow scalar loads.
+  for (User *U : LI->users()) {
+    auto *EI = cast<ExtractElementInst>(U);
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(EI);
+    Value *GEP = Builder.CreateInBoundsGEP(
+        FixedVT, Ptr, {Builder.getInt32(0), EI->getOperand(1)});
+    auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
+        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+
+    // Set the alignment for the new load. For index 0, we can use the original
+    // alignment. Otherwise choose the common alignment of the load's align and
+    // the alignment for the scalar type.
+    auto *ConstIdx = dyn_cast<ConstantInt>(EI->getOperand(1));
+    if (ConstIdx && ConstIdx->isNullValue())
+      NewLoad->setAlignment(LI->getAlign());
+    else
+      NewLoad->setAlignment(commonAlignment(
+          DL.getABITypeAlign(NewLoad->getType()), LI->getAlign()));
+    replaceValue(*EI, *NewLoad);
+  }
+
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -851,6 +948,7 @@ bool VectorCombine::run() {
       MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
       MadeChange |= foldSingleElementStore(I);
+      MadeChange |= scalarizeLoadExtract(I);
     }
   }
 
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 8e747a4..d71b589 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -1,10 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck --check-prefixes=CHECK,LIMIT-DEFAULT %s
+; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -vector-combine-max-scan-instrs=2 -S %s | FileCheck --check-prefixes=CHECK,LIMIT2 %s
 
 define i32 @load_extract_idx_0(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 3
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 3
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <4 x i32>, <4 x i32>* %x
@@ -12,10 +13,23 @@ define i32 @load_extract_idx_0(<4 x i32>* %x) {
   ret i32 %r
 }
 
+; If the original load had a smaller alignment than the scalar type, the
+; smaller alignment should be used.
+define i32 @load_extract_idx_0_small_alignment(<4 x i32>* %x) {
+; CHECK-LABEL: @load_extract_idx_0_small_alignment(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 3
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 2
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %lv = load <4 x i32>, <4 x i32>* %x, align 2
+  %r = extractelement <4 x i32> %lv, i32 3
+  ret i32 %r
+}
+
 define i32 @load_extract_idx_1(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_1(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 1
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <4 x i32>, <4 x i32>* %x
@@ -25,8 +39,8 @@ define i32 @load_extract_idx_1(<4 x i32>* %x) {
 
 define i32 @load_extract_idx_2(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <4 x i32>, <4 x i32>* %x
@@ -36,8 +50,8 @@ define i32 @load_extract_idx_2(<4 x i32>* %x) {
 
 define i32 @load_extract_idx_3(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_idx_3(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 3
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 3
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <4 x i32>, <4 x i32>* %x
@@ -202,8 +216,8 @@ declare void @clobber()
 define i32 @load_extract_clobber_call_before(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_clobber_call_before(
 ; CHECK-NEXT:    call void @clobber()
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   call void @clobber()
@@ -227,8 +241,8 @@ define i32 @load_extract_clobber_call_between(<4 x i32>* %x) {
 
 define i32 @load_extract_clobber_call_after(<4 x i32>* %x) {
 ; CHECK-LABEL: @load_extract_clobber_call_after(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    call void @clobber()
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
@@ -241,8 +255,8 @@ define i32 @load_extract_clobber_call_after(<4 x i32>* %x) {
 define i32 @load_extract_clobber_store_before(<4 x i32>* %x, i8* %y) {
 ; CHECK-LABEL: @load_extract_clobber_store_before(
 ; CHECK-NEXT:    store i8 0, i8* [[Y:%.*]], align 1
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   store i8 0, i8* %y
@@ -298,21 +312,37 @@ define i32 @load_extract_clobber_store_between_limit(<4 x i32>* %x, i8* %y, <8 x
 }
 
 define i32 @load_extract_clobber_store_after_limit(<4 x i32>* %x, i8* %y, <8 x i32> %z) {
-; CHECK-LABEL: @load_extract_clobber_store_after_limit(
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[Z_0:%.*]] = extractelement <8 x i32> [[Z:%.*]], i32 0
-; CHECK-NEXT:    [[Z_1:%.*]] = extractelement <8 x i32> [[Z]], i32 1
-; CHECK-NEXT:    [[ADD_0:%.*]] = add i32 [[Z_0]], [[Z_1]]
-; CHECK-NEXT:    [[Z_2:%.*]] = extractelement <8 x i32> [[Z]], i32 2
-; CHECK-NEXT:    [[ADD_1:%.*]] = add i32 [[ADD_0]], [[Z_2]]
-; CHECK-NEXT:    [[Z_3:%.*]] = extractelement <8 x i32> [[Z]], i32 3
-; CHECK-NEXT:    [[ADD_2:%.*]] = add i32 [[ADD_1]], [[Z_3]]
-; CHECK-NEXT:    [[Z_4:%.*]] = extractelement <8 x i32> [[Z]], i32 4
-; CHECK-NEXT:    [[ADD_3:%.*]] = add i32 [[ADD_2]], [[Z_4]]
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
-; CHECK-NEXT:    store i8 0, i8* [[Y:%.*]], align 1
-; CHECK-NEXT:    [[ADD_4:%.*]] = add i32 [[ADD_3]], [[R]]
-; CHECK-NEXT:    ret i32 [[ADD_4]]
+; LIMIT-DEFAULT-LABEL: @load_extract_clobber_store_after_limit(
+; LIMIT-DEFAULT-NEXT:    [[Z_0:%.*]] = extractelement <8 x i32> [[Z:%.*]], i32 0
+; LIMIT-DEFAULT-NEXT:    [[Z_1:%.*]] = extractelement <8 x i32> [[Z]], i32 1
+; LIMIT-DEFAULT-NEXT:    [[ADD_0:%.*]] = add i32 [[Z_0]], [[Z_1]]
+; LIMIT-DEFAULT-NEXT:    [[Z_2:%.*]] = extractelement <8 x i32> [[Z]], i32 2
+; LIMIT-DEFAULT-NEXT:    [[ADD_1:%.*]] = add i32 [[ADD_0]], [[Z_2]]
+; LIMIT-DEFAULT-NEXT:    [[Z_3:%.*]] = extractelement <8 x i32> [[Z]], i32 3
+; LIMIT-DEFAULT-NEXT:    [[ADD_2:%.*]] = add i32 [[ADD_1]], [[Z_3]]
+; LIMIT-DEFAULT-NEXT:    [[Z_4:%.*]] = extractelement <8 x i32> [[Z]], i32 4
+; LIMIT-DEFAULT-NEXT:    [[ADD_3:%.*]] = add i32 [[ADD_2]], [[Z_4]]
+; LIMIT-DEFAULT-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i32 2
+; LIMIT-DEFAULT-NEXT:    [[R:%.*]] = load i32, i32* [[TMP1]], align 4
+; LIMIT-DEFAULT-NEXT:    store i8 0, i8* [[Y:%.*]], align 1
+; LIMIT-DEFAULT-NEXT:    [[ADD_4:%.*]] = add i32 [[ADD_3]], [[R]]
+; LIMIT-DEFAULT-NEXT:    ret i32 [[ADD_4]]
+;
+; LIMIT2-LABEL: @load_extract_clobber_store_after_limit(
+; LIMIT2-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
+; LIMIT2-NEXT:    [[Z_0:%.*]] = extractelement <8 x i32> [[Z:%.*]], i32 0
+; LIMIT2-NEXT:    [[Z_1:%.*]] = extractelement <8 x i32> [[Z]], i32 1
+; LIMIT2-NEXT:    [[ADD_0:%.*]] = add i32 [[Z_0]], [[Z_1]]
+; LIMIT2-NEXT:    [[Z_2:%.*]] = extractelement <8 x i32> [[Z]], i32 2
+; LIMIT2-NEXT:    [[ADD_1:%.*]] = add i32 [[ADD_0]], [[Z_2]]
+; LIMIT2-NEXT:    [[Z_3:%.*]] = extractelement <8 x i32> [[Z]], i32 3
+; LIMIT2-NEXT:    [[ADD_2:%.*]] = add i32 [[ADD_1]], [[Z_3]]
+; LIMIT2-NEXT:    [[Z_4:%.*]] = extractelement <8 x i32> [[Z]], i32 4
+; LIMIT2-NEXT:    [[ADD_3:%.*]] = add i32 [[ADD_2]], [[Z_4]]
+; LIMIT2-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i32 2
+; LIMIT2-NEXT:    store i8 0, i8* [[Y:%.*]], align 1
+; LIMIT2-NEXT:    [[ADD_4:%.*]] = add i32 [[ADD_3]], [[R]]
+; LIMIT2-NEXT:    ret i32 [[ADD_4]]
 ;
   %lv = load <4 x i32>, <4 x i32>* %x
   %z.0 = extractelement <8 x i32> %z, i32 0
@@ -386,13 +416,14 @@ define i32 @load_multiple_extracts_with_constant_idx(<4 x i32>* %x) {
 ; because the large vector requires 2 vector registers.
 define i32 @load_multiple_extracts_with_constant_idx_profitable(<8 x i32>* %x) {
 ; CHECK-LABEL: @load_multiple_extracts_with_constant_idx_profitable(
-; CHECK-NEXT:    [[LV:%.*]] = load <8 x i32>, <8 x i32>* [[X:%.*]], align 32
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <8 x i32> [[LV]], i32 0
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <8 x i32> [[LV]], i32 6
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <8 x i32>, <8 x i32>* [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, i32* [[TMP1]], align 16
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <8 x i32>, <8 x i32>* [[X]], i32 0, i32 6
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, i32* [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
-  %lv = load <8 x i32>, <8 x i32>* %x
+  %lv = load <8 x i32>, <8 x i32>* %x, align 16
   %e.0 = extractelement <8 x i32> %lv, i32 0
   %e.1 = extractelement <8 x i32> %lv, i32 6
   %res = add i32 %e.0, %e.1
diff --git a/llvm/test/Transforms/VectorCombine/X86/load-inseltpoison.ll b/llvm/test/Transforms/VectorCombine/X86/load-inseltpoison.ll
index 561ba22..90bdf62 100644
--- a/llvm/test/Transforms/VectorCombine/X86/load-inseltpoison.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/load-inseltpoison.ll
@@ -630,8 +630,8 @@ define <8 x i32> @load_v1i32_extract_insert_v8i32_extra_use(<1 x i32>* align 16
 define <8 x i16> @gep1_load_v2i16_extract_insert_v8i16(<2 x i16>* align 1 dereferenceable(16) %p) {
 ; SSE2-LABEL: @gep1_load_v2i16_extract_insert_v8i16(
 ; SSE2-NEXT:    [[GEP:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[P:%.*]], i64 1
-; SSE2-NEXT:    [[L:%.*]] = load <2 x i16>, <2 x i16>* [[GEP]], align 8
-; SSE2-NEXT:    [[S:%.*]] = extractelement <2 x i16> [[L]], i32 0
+; SSE2-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[GEP]], i32 0, i32 0
+; SSE2-NEXT:    [[S:%.*]] = load i16, i16* [[TMP1]], align 8
 ; SSE2-NEXT:    [[R:%.*]] = insertelement <8 x i16> poison, i16 [[S]], i64 0
 ; SSE2-NEXT:    ret <8 x i16> [[R]]
 ;
diff --git a/llvm/test/Transforms/VectorCombine/X86/load.ll b/llvm/test/Transforms/VectorCombine/X86/load.ll
index 3bf8492..c2fa94c 100644
--- a/llvm/test/Transforms/VectorCombine/X86/load.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/load.ll
@@ -630,8 +630,8 @@ define <8 x i32> @load_v1i32_extract_insert_v8i32_extra_use(<1 x i32>* align 16
 define <8 x i16> @gep1_load_v2i16_extract_insert_v8i16(<2 x i16>* align 1 dereferenceable(16) %p) {
 ; SSE2-LABEL: @gep1_load_v2i16_extract_insert_v8i16(
 ; SSE2-NEXT:    [[GEP:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[P:%.*]], i64 1
-; SSE2-NEXT:    [[L:%.*]] = load <2 x i16>, <2 x i16>* [[GEP]], align 8
-; SSE2-NEXT:    [[S:%.*]] = extractelement <2 x i16> [[L]], i32 0
+; SSE2-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[GEP]], i32 0, i32 0
+; SSE2-NEXT:    [[S:%.*]] = load i16, i16* [[TMP1]], align 8
 ; SSE2-NEXT:    [[R:%.*]] = insertelement <8 x i16> undef, i16 [[S]], i64 0
 ; SSE2-NEXT:    ret <8 x i16> [[R]]
 ;
-- 
2.7.4