From 710e34e1360710275662ad5b0bdc394570fb26d5 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Wed, 2 Nov 2022 17:10:20 -0400 Subject: [PATCH] [VectorCombine] move load safety checks to helper function; NFC These checks can be re-used with other potential transforms such as a load of a subvector-insert. --- llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 45 +++++++++++++++---------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index a21add2..bac72b8f 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -127,6 +127,27 @@ private: }; } // namespace +static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) { + // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan. + // The widened load may load data from dirty regions or create data races + // non-existent in the source. + if (!Load || !Load->isSimple() || !Load->hasOneUse() || + Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || + mustSuppressSpeculation(*Load)) + return false; + + // We are potentially transforming byte-sized (8-bit) memory accesses, so make + // sure we have all of our type-based constraints in place for this target. + Type *ScalarTy = Load->getType()->getScalarType(); + uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); + unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); + if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 || + ScalarSize % 8 != 0) + return false; + + return true; +} + bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // Match insert into fixed vector of scalar value. // TODO: Handle non-zero insert index. @@ -142,35 +163,22 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { if (!HasExtract) X = Scalar; - // Match source value as load of scalar or vector. 
- // Do not vectorize scalar load (widening) if atomic/volatile or under - // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions - // or create data races non-existent in the source. auto *Load = dyn_cast<LoadInst>(X); - if (!Load || !Load->isSimple() || !Load->hasOneUse() || - Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || - mustSuppressSpeculation(*Load)) + if (!canWidenLoad(Load, TTI)) return false; - const DataLayout &DL = I.getModule()->getDataLayout(); - Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); - assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); - - unsigned AS = Load->getPointerAddressSpace(); - - // We are potentially transforming byte-sized (8-bit) memory accesses, so make - // sure we have all of our type-based constraints in place for this target. Type *ScalarTy = Scalar->getType(); uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); - if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 || - ScalarSize % 8 != 0) - return false; // Check safety of replacing the scalar load with a larger vector load. // We use minimal alignment (maximum flexibility) because we only care about // the dereferenceable region. When calculating cost and creating a new op, // we may use a larger value based on alignment attributes. + const DataLayout &DL = I.getModule()->getDataLayout(); + Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); + assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); + unsigned MinVecNumElts = MinVectorSize / ScalarSize; auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false); unsigned OffsetEltIndex = 0; @@ -215,6 +223,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // Use the greater of the alignment on the load or its source pointer. 
Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment); Type *LoadTy = Load->getType(); + unsigned AS = Load->getPointerAddressSpace(); InstructionCost OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS); APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0); -- 2.7.4