From d77753381fe024434ae8ffaaacfe4b9ed9d4d760 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 21 Jan 2021 14:54:03 -0500 Subject: [PATCH] [SLP] simplify reduction matching This is NFC-intended and removes the "OperationData" class which had become nothing more than a recurrence (reduction) type. I adjusted the matching logic to distinguish instructions from non-instructions - that's all that the "IsLeafValue" member was keeping track of. --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 189 ++++++++++-------------- 1 file changed, 78 insertions(+), 111 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 2597f88..7326001 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6401,44 +6401,9 @@ class HorizontalReduction { SmallVector ReducedVals; // Use map vector to make stable output. MapVector ExtraArgs; - - /// This wraps functionality around a RecurKind (reduction kind). - /// TODO: Remove this class if callers can use the 'Kind' value directly? - class OperationData { - /// Kind of the reduction operation. - RecurKind Kind = RecurKind::None; - bool IsLeafValue = false; - - public: - explicit OperationData() = default; - - /// Constructor for reduced values. They are identified by the bool only. - explicit OperationData(Instruction &I) { IsLeafValue = true; } - - /// Constructor for reduction operations with opcode and type. - OperationData(RecurKind RdxKind) : Kind(RdxKind) { - assert(Kind != RecurKind::None && "Expected reduction operation."); - } - - explicit operator bool() const { - return IsLeafValue || Kind != RecurKind::None; - } - - /// Checks if two operation data are both a reduction op or both a reduced - /// value. - bool operator==(const OperationData &OD) const { - return Kind == OD.Kind && IsLeafValue == OD.IsLeafValue; - } - bool operator!=(const OperationData &OD) const { return !(*this == OD); } - - /// Get kind of reduction data. - RecurKind getKind() const { return Kind; } - }; - WeakTrackingVH ReductionRoot; - - /// The operation data of the reduction operation. - OperationData RdxTreeInst; + /// The type of reduction operation. + RecurKind RdxKind; /// Checks if instruction is associative and can be vectorized. static bool isVectorizable(RecurKind Kind, Instruction *I) { @@ -6471,8 +6436,8 @@ class HorizontalReduction { // in this case. // Do not perform analysis of remaining operands of ParentStackElem.first // instruction, this whole instruction is an extra argument. - OperationData OpData = getOperationData(ParentStackElem.first); - ParentStackElem.second = getNumberOfOperands(OpData.getKind()); + RecurKind RdxKind = getRdxKind(ParentStackElem.first); + ParentStackElem.second = getNumberOfOperands(RdxKind); } else { // We ran into something like: // ParentStackElem.first += ... + ExtraArg + ... @@ -6550,39 +6515,37 @@ class HorizontalReduction { return Op; } - static OperationData getOperationData(Instruction *I) { - if (!I) - return OperationData(); - + static RecurKind getRdxKind(Instruction *I) { + assert(I && "Expected instruction for reduction matching"); TargetTransformInfo::ReductionFlags RdxFlags; if (match(I, m_Add(m_Value(), m_Value()))) - return OperationData(RecurKind::Add); + return RecurKind::Add; if (match(I, m_Mul(m_Value(), m_Value()))) - return OperationData(RecurKind::Mul); + return RecurKind::Mul; if (match(I, m_And(m_Value(), m_Value()))) - return OperationData(RecurKind::And); + return RecurKind::And; if (match(I, m_Or(m_Value(), m_Value()))) - return OperationData(RecurKind::Or); + return RecurKind::Or; if (match(I, m_Xor(m_Value(), m_Value()))) - return OperationData(RecurKind::Xor); + return RecurKind::Xor; if (match(I, m_FAdd(m_Value(), m_Value()))) - return OperationData(RecurKind::FAdd); + return RecurKind::FAdd; if (match(I, m_FMul(m_Value(), m_Value()))) - return OperationData(RecurKind::FMul); + return RecurKind::FMul; if (match(I, m_Intrinsic(m_Value(), m_Value()))) - return OperationData(RecurKind::FMax); + return RecurKind::FMax; if (match(I, m_Intrinsic(m_Value(), m_Value()))) - return OperationData(RecurKind::FMin); + return RecurKind::FMin; if (match(I, m_SMax(m_Value(), m_Value()))) - return OperationData(RecurKind::SMax); + return RecurKind::SMax; if (match(I, m_SMin(m_Value(), m_Value()))) - return OperationData(RecurKind::SMin); + return RecurKind::SMin; if (match(I, m_UMax(m_Value(), m_Value()))) - return OperationData(RecurKind::UMax); + return RecurKind::UMax; if (match(I, m_UMin(m_Value(), m_Value()))) - return OperationData(RecurKind::UMin); + return RecurKind::UMin; if (auto *Select = dyn_cast(I)) { // Try harder: look for min/max pattern based on instructions producing @@ -6608,39 +6571,39 @@ class HorizontalReduction { if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) { if (!isa(RHS) || !L2->isIdenticalTo(cast(RHS))) - return OperationData(*I); + return RecurKind::None; } else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) { if (!isa(LHS) || !L1->isIdenticalTo(cast(LHS))) - return OperationData(*I); + return RecurKind::None; } else { if (!isa(LHS) || !isa(RHS)) - return OperationData(*I); + return RecurKind::None; if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) || !L1->isIdenticalTo(cast(LHS)) || !L2->isIdenticalTo(cast(RHS))) - return OperationData(*I); + return RecurKind::None; } TargetTransformInfo::ReductionFlags RdxFlags; switch (Pred) { default: - return OperationData(*I); + return RecurKind::None; case CmpInst::ICMP_SGT: case CmpInst::ICMP_SGE: - return OperationData(RecurKind::SMax); + return RecurKind::SMax; case CmpInst::ICMP_SLT: case CmpInst::ICMP_SLE: - return OperationData(RecurKind::SMin); + return RecurKind::SMin; case CmpInst::ICMP_UGT: case CmpInst::ICMP_UGE: - return OperationData(RecurKind::UMax); + return RecurKind::UMax; case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: - return OperationData(RecurKind::UMin); + return RecurKind::UMin; } } - return OperationData(*I); + return RecurKind::None; } /// Return true if this operation is a cmp+select idiom. @@ -6724,24 +6687,28 @@ public: assert((!Phi || is_contained(Phi->operands(), B)) && "Phi needs to use the binary operator"); - RdxTreeInst = getOperationData(B); + RdxKind = getRdxKind(B); // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. if (Phi) { - if (getLHS(RdxTreeInst.getKind(), B) == Phi) { + if (getLHS(RdxKind, B) == Phi) { Phi = nullptr; - B = dyn_cast(getRHS(RdxTreeInst.getKind(), B)); - RdxTreeInst = getOperationData(B); - } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) { + B = dyn_cast(getRHS(RdxKind, B)); + if (!B) + return false; + RdxKind = getRdxKind(B); + } else if (getRHS(RdxKind, B) == Phi) { Phi = nullptr; - B = dyn_cast(getLHS(RdxTreeInst.getKind(), B)); - RdxTreeInst = getOperationData(B); + B = dyn_cast(getLHS(RdxKind, B)); + if (!B) + return false; + RdxKind = getRdxKind(B); } } - if (!isVectorizable(RdxTreeInst.getKind(), B)) + if (!isVectorizable(RdxKind, B)) return false; // Analyze "regular" integer/FP types for reductions - no target-specific @@ -6761,18 +6728,16 @@ public: // Post order traverse the reduction tree starting at B. We only handle true // trees containing only binary operators. SmallVector, 32> Stack; - Stack.push_back( - std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind()))); - initReductionOps(RdxTreeInst.getKind()); + Stack.push_back(std::make_pair(B, getFirstOperandIndex(RdxKind))); + initReductionOps(RdxKind); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVisit = Stack.back().second++; - const OperationData OpData = getOperationData(TreeN); - bool IsReducedValue = OpData != RdxTreeInst; + const RecurKind TreeRdxKind = getRdxKind(TreeN); + bool IsReducedValue = TreeRdxKind != RdxKind; // Postorder visit. - if (IsReducedValue || - EdgeToVisit == getNumberOfOperands(OpData.getKind())) { + if (IsReducedValue || EdgeToVisit == getNumberOfOperands(TreeRdxKind)) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -6790,7 +6755,7 @@ public: markExtraArg(Stack[Stack.size() - 2], TreeN); ExtraArgs.erase(TreeN); } else - addReductionOps(RdxTreeInst.getKind(), TreeN); + addReductionOps(RdxKind, TreeN); } // Retract. Stack.pop_back(); @@ -6798,9 +6763,15 @@ public: } // Visit left or right. - Value *NextV = TreeN->getOperand(EdgeToVisit); - auto *I = dyn_cast(NextV); - const OperationData EdgeOpData = getOperationData(I); + Value *EdgeVal = TreeN->getOperand(EdgeToVisit); + auto *I = dyn_cast(EdgeVal); + if (!I) { + // Edge value is not a reduction instruction or a leaf instruction. + // (It may be a constant, function argument, or something else.) + markExtraArg(Stack.back(), EdgeVal); + continue; + } + RecurKind EdgeRdxKind = getRdxKind(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a leaf value. If the leaf value opcode is not set, // the first met operation != reduction operation is considered as the @@ -6808,14 +6779,14 @@ public: // Only handle trees in the current basic block. // Each tree node needs to have minimal number of users except for the // ultimate reduction. - const bool IsRdxInst = EdgeOpData == RdxTreeInst; - if (I && I != Phi && I != B && - hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) && - hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) && + const bool IsRdxInst = EdgeRdxKind == RdxKind; + if (I != Phi && I != B && + hasSameParent(RdxKind, I, B->getParent(), IsRdxInst) && + hasRequiredNumberOfUses(RdxKind, I, IsRdxInst) && (!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) { if (IsRdxInst) { // We need to be able to reassociate the reduction operations. - if (!isVectorizable(EdgeOpData.getKind(), I)) { + if (!isVectorizable(EdgeRdxKind, I)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; @@ -6823,12 +6794,11 @@ public: } else if (!LeafOpcode) { LeafOpcode = I->getOpcode(); } - Stack.push_back( - std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind()))); + Stack.push_back(std::make_pair(I, getFirstOperandIndex(EdgeRdxKind))); continue; } - // NextV is an extra argument for TreeN (its parent operation). - markExtraArg(Stack.back(), NextV); + // I is an extra argument for TreeN (its parent operation). + markExtraArg(Stack.back(), I); } return true; } @@ -6922,7 +6892,7 @@ public: } if (V.isTreeTinyAndNotFullyVectorizable()) break; - if (V.isLoadCombineReductionCandidate(RdxTreeInst.getKind())) + if (V.isLoadCombineReductionCandidate(RdxKind)) break; V.computeMinimumValueSizes(); @@ -6965,7 +6935,7 @@ public: // Emit a reduction. If the root is a select (min/max idiom), the insert // point is the compare condition of that select. Instruction *RdxRootInst = cast(ReductionRoot); - if (isCmpSel(RdxTreeInst.getKind())) + if (isCmpSel(RdxKind)) Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst)); else Builder.SetInsertPoint(RdxRootInst); @@ -6979,9 +6949,8 @@ public: } else { // Update the final value in the reduction. Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = - createOp(Builder, RdxTreeInst.getKind(), VectorizedTree, - ReducedSubTree, "op.rdx", ReductionOps); + VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, + ReducedSubTree, "op.rdx", ReductionOps); } i += ReduxWidth; ReduxWidth = PowerOf2Floor(NumReducedVals - i); @@ -6992,15 +6961,15 @@ public: for (; i < NumReducedVals; ++i) { auto *I = cast(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = createOp(Builder, RdxTreeInst.getKind(), - VectorizedTree, I, "", ReductionOps); + VectorizedTree = + createOp(Builder, RdxKind, VectorizedTree, I, "", ReductionOps); } for (auto &Pair : ExternallyUsedValues) { // Add each externally used value to the final reduction. for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = createOp(Builder, RdxTreeInst.getKind(), - VectorizedTree, Pair.first, "op.extra", I); + VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, + Pair.first, "op.extra", I); } } @@ -7008,7 +6977,7 @@ public: // select, we also have to RAUW for the compare instruction feeding the // reduction root. That's because the original compare may have extra uses // besides the final select of the reduction. - if (isCmpSel(RdxTreeInst.getKind())) { + if (isCmpSel(RdxKind)) { if (auto *VecSelect = dyn_cast(VectorizedTree)) { Instruction *ScalarCmp = getCmpForMinMaxReduction(cast(ReductionRoot)); @@ -7032,10 +7001,8 @@ private: unsigned ReduxWidth) { Type *ScalarTy = FirstReducedVal->getType(); FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth); - - RecurKind Kind = RdxTreeInst.getKind(); int VectorCost, ScalarCost; - switch (Kind) { + switch (RdxKind) { case RecurKind::Add: case RecurKind::Mul: case RecurKind::Or: @@ -7043,7 +7010,7 @@ private: case RecurKind::Xor: case RecurKind::FAdd: case RecurKind::FMul: { - unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); + unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind); VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, /*IsPairwiseForm=*/false); ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy); @@ -7066,7 +7033,8 @@ private: case RecurKind::UMax: case RecurKind::UMin: { auto *VecCondTy = cast(CmpInst::makeCmpResultType(VectorTy)); - bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin; + bool IsUnsigned = + RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, /*IsPairwiseForm=*/false, IsUnsigned); @@ -7098,8 +7066,7 @@ private: // FIXME: The builder should use an FMF guard. It should not be hard-coded // to 'fast'. assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF"); - return createSimpleTargetReduction(Builder, TTI, VectorizedValue, - RdxTreeInst.getKind(), + return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind, ReductionOps.back()); } }; -- 2.7.4