From 575e2aff5574550d10278d9a41fca2926a5b8409 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Tue, 25 May 2021 13:54:55 +0100 Subject: [PATCH] [VectorCombine] Use constant range info for index scalarization legality. We can only scalarize memory accesses if we know the index is valid. This patch adjusts canScalarizeAcceess to fall back to computeConstantRange to check if the index is known to be valid. Reviewed By: nlopes Differential Revision: https://reviews.llvm.org/D102476 --- llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 38 +++++++++++++++------- .../PhaseOrdering/AArch64/matrix-extract-insert.ll | 12 +++---- .../AArch64/load-extractelement-scalarization.ll | 13 ++++---- .../Transforms/VectorCombine/load-insert-store.ll | 15 ++++----- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 3913032..50916d0 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Vectorize/VectorCombine.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" @@ -60,8 +61,8 @@ namespace { class VectorCombine { public: VectorCombine(Function &F, const TargetTransformInfo &TTI, - const DominatorTree &DT, AAResults &AA) - : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {} + const DominatorTree &DT, AAResults &AA, AssumptionCache &AC) + : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {} bool run(); @@ -71,6 +72,7 @@ private: const TargetTransformInfo &TTI; const DominatorTree &DT; AAResults &AA; + AssumptionCache &AC; bool vectorizeLoadInsert(Instruction &I); ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0, @@ -774,8 +776,16 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin, /// Check if it is legal to scalarize a memory access to \p VecTy at index \p /// Idx. \p Idx must access a valid vector element. -static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) { - return Idx->getValue().ult(VecTy->getNumElements()); +static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx, + Instruction *CtxI, AssumptionCache &AC) { + if (auto *C = dyn_cast(Idx)) + return C->getValue().ult(VecTy->getNumElements()); + + APInt Zero(Idx->getType()->getScalarSizeInBits(), 0); + APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements()); + ConstantRange ValidIndices(Zero, MaxElts); + ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0); + return ValidIndices.contains(IdxRange); } // Combine patterns like: @@ -796,10 +806,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { // TargetTransformInfo. Instruction *Source; Value *NewElement; - ConstantInt *Idx; + Value *Idx; if (!match(SI->getValueOperand(), m_InsertElt(m_Instruction(Source), m_Value(NewElement), - m_ConstantInt(Idx)))) + m_Value(Idx)))) return false; if (auto *Load = dyn_cast(Source)) { @@ -810,7 +820,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { // modified between, vector type matches store size, and index is inbounds. if (!Load->isSimple() || Load->getParent() != SI->getParent() || !DL.typeSizeEqualsStoreSize(Load->getType()) || - !canScalarizeAccess(VecTy, Idx) || + !canScalarizeAccess(VecTy, Idx, Load, AC) || SrcAddr != SI->getPointerOperand()->stripPointerCasts() || isMemModifiedBetween(Load->getIterator(), SI->getIterator(), MemoryLocation::get(SI), AA)) @@ -835,8 +845,8 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { /// Try to scalarize vector loads feeding extractelement instructions. bool VectorCombine::scalarizeLoadExtract(Instruction &I) { Value *Ptr; - ConstantInt *Idx; - if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx)))) + Value *Idx; + if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx)))) return false; auto *LI = cast(I.getOperand(0)); @@ -848,7 +858,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (!FixedVT) return false; - if (!canScalarizeAccess(FixedVT, Idx)) + if (!canScalarizeAccess(FixedVT, Idx, &I, AC)) return false; InstructionCost OriginalCost = TTI.getMemoryOpCost( @@ -971,6 +981,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); @@ -985,10 +996,11 @@ public: bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; + auto &AC = getAnalysis().getAssumptionCache(F); auto &TTI = getAnalysis().getTTI(F); auto &DT = getAnalysis().getDomTree(); auto &AA = getAnalysis().getAAResults(); - VectorCombine Combiner(F, TTI, DT, AA); + VectorCombine Combiner(F, TTI, DT, AA, AC); return Combiner.run(); } }; @@ -998,6 +1010,7 @@ char VectorCombineLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", "Optimize scalar/vector ops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", "Optimize scalar/vector ops", false, false) @@ -1007,10 +1020,11 @@ Pass *llvm::createVectorCombinePass() { PreservedAnalyses VectorCombinePass::run(Function &F, FunctionAnalysisManager &FAM) { + auto &AC = FAM.getResult(F); TargetTransformInfo &TTI = FAM.getResult(F); DominatorTree &DT = FAM.getResult(F); AAResults &AA = FAM.getResult(F); - VectorCombine Combiner(F, TTI, DT, AA); + VectorCombine Combiner(F, TTI, DT, AA, AC); if (!Combiner.run()) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll index 31d8a16..1089f54 100644 --- a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll +++ b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll @@ -13,8 +13,8 @@ define void @matrix_extract_insert_scalar(i32 %i, i32 %k, i32 %j, [225 x double] ; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 225 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP2]]) ; CHECK-NEXT: [[TMP3:%.*]] = bitcast [225 x double]* [[A:%.*]] to <225 x double>* -; CHECK-NEXT: [[TMP4:%.*]] = load <225 x double>, <225 x double>* [[TMP3]], align 8 -; CHECK-NEXT: [[MATRIXEXT:%.*]] = extractelement <225 x double> [[TMP4]], i64 [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP3]], i64 0, i64 [[TMP1]] +; CHECK-NEXT: [[MATRIXEXT:%.*]] = load double, double* [[TMP4]], align 8 ; CHECK-NEXT: [[CONV2:%.*]] = zext i32 [[I:%.*]] to i64 ; CHECK-NEXT: [[TMP5:%.*]] = add nuw nsw i64 [[TMP0]], [[CONV2]] ; CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[TMP5]], 225 @@ -25,8 +25,8 @@ define void @matrix_extract_insert_scalar(i32 %i, i32 %k, i32 %j, [225 x double] ; CHECK-NEXT: [[MUL:%.*]] = fmul double [[MATRIXEXT]], [[MATRIXEXT4]] ; CHECK-NEXT: [[MATRIXEXT7:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP1]] ; CHECK-NEXT: [[SUB:%.*]] = fsub double [[MATRIXEXT7]], [[MUL]] -; CHECK-NEXT: [[MATINS:%.*]] = insertelement <225 x double> [[TMP8]], double [[SUB]], i64 [[TMP1]] -; CHECK-NEXT: store <225 x double> [[MATINS]], <225 x double>* [[TMP7]], align 8 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP7]], i64 0, i64 [[TMP1]] +; CHECK-NEXT: store double [[SUB]], double* [[TMP9]], align 8 ; CHECK-NEXT: ret void ; entry: @@ -112,8 +112,8 @@ define void @matrix_extract_insert_loop(i32 %i, [225 x double]* nonnull align 8 ; CHECK-NEXT: [[TMP6:%.*]] = add nuw nsw i64 [[TMP2]], [[CONV_US]] ; CHECK-NEXT: [[TMP7:%.*]] = icmp ult i64 [[TMP6]], 225 ; CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP7]]) -; CHECK-NEXT: [[TMP8:%.*]] = load <225 x double>, <225 x double>* [[TMP0]], align 8 -; CHECK-NEXT: [[MATRIXEXT_US:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP6]] +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP0]], i64 0, i64 [[TMP6]] +; CHECK-NEXT: [[MATRIXEXT_US:%.*]] = load double, double* [[TMP8]], align 8 ; CHECK-NEXT: [[MATRIXEXT8_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP3]] ; CHECK-NEXT: [[MUL_US:%.*]] = fmul double [[MATRIXEXT_US]], [[MATRIXEXT8_US]] ; CHECK-NEXT: [[MATRIXEXT11_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP6]] diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll index 7179909..4ffbae6 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck --check-prefixes=CHECK,LIMIT-DEFAULT %s +; RUN: opt -vector-combine -enable-new-pm=false -mtriple=arm64-apple-darwinos -S %s | FileCheck --check-prefixes=CHECK,LIMIT-DEFAULT %s ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -vector-combine-max-scan-instrs=2 -S %s | FileCheck --check-prefixes=CHECK,LIMIT2 %s define i32 @load_extract_idx_0(<4 x i32>* %x) { @@ -90,9 +91,9 @@ define i32 @load_extract_idx_var_i64_known_valid_by_assume(<4 x i32>* %x, i64 %i ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16 ; CHECK-NEXT: call void @maythrow() -; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX]] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX]] +; CHECK-NEXT: [[R:%.*]] = load i32, i32* [[TMP0]], align 4 ; CHECK-NEXT: ret i32 [[R]] ; entry: @@ -146,8 +147,8 @@ define i32 @load_extract_idx_var_i64_known_valid_by_and(<4 x i32>* %x, i64 %idx) ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3 -; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16 -; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]] +; CHECK-NEXT: [[R:%.*]] = load i32, i32* [[TMP0]], align 4 ; CHECK-NEXT: ret i32 [[R]] ; entry: @@ -176,8 +177,8 @@ define i32 @load_extract_idx_var_i64_known_valid_by_urem(<4 x i32>* %x, i64 %idx ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4 -; CHECK-NEXT: [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16 -; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]] +; CHECK-NEXT: [[R:%.*]] = load i32, i32* [[TMP0]], align 4 ; CHECK-NEXT: ret i32 [[R]] ; entry: diff --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll index 611d6697..1b43834 100644 --- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll +++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll @@ -130,9 +130,8 @@ define void @insert_store_nonconst_index_known_valid_by_assume(<16 x i8>* %q, i8 ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IDX:%.*]], 4 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16 -; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX]] -; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX]] +; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP0]], align 1 ; CHECK-NEXT: ret void ; entry: @@ -191,10 +190,9 @@ declare void @llvm.assume(i1) define void @insert_store_nonconst_index_known_valid_by_and(<16 x i8>* %q, i8 zeroext %s, i32 %idx) { ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16 ; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7 -; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]] -; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]] +; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP0]], align 1 ; CHECK-NEXT: ret void ; entry: @@ -225,10 +223,9 @@ entry: define void @insert_store_nonconst_index_known_valid_by_urem(<16 x i8>* %q, i8 zeroext %s, i32 %idx) { ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16 ; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16 -; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]] -; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]] +; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP0]], align 1 ; CHECK-NEXT: ret void ; entry: -- 2.7.4