From: Fraser Cormack Date: Wed, 6 Jan 2021 08:07:41 +0000 (+0000) Subject: [SelectionDAG] Extend immAll(Ones|Zeros)V to handle ISD::SPLAT_VECTOR X-Git-Tag: llvmorg-13-init~1625 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=de373ef779880e923636d90cdb277e4db84c7479;p=platform%2Fupstream%2Fllvm.git [SelectionDAG] Extend immAll(Ones|Zeros)V to handle ISD::SPLAT_VECTOR The TableGen immAllOnesV and immAllZerosV helpers implicitly wrapped the ISD::isBuildVectorAll(Ones|Zeros) helper functions. This was inhibiting their use for targets such as RISC-V which use ISD::SPLAT_VECTOR. In particular, RISC-V had to define its own 'vnot' fragment. In order to extend the scope of these nodes to include support for ISD::SPLAT_VECTOR, two new ISD predicate functions have been introduced: ISD::isConstantSplatVectorAll(Ones|Zeros). These effectively supersede the older "isBuildVector" predicates, which are now simple wrappers for the new functions. They pass a defaulted boolean toggle which preserves the old behaviour. It is hoped that in time all call-sites can be ported to the "isConstantSplatVector" functions. While the use of ISD::isBuildVectorAll(Ones|Zeros) has not changed, the behaviour of the TableGen immAll(Ones|Zeros)V **has**. To test the new functionality, the custom RISC-V TableGen fragment has been removed and replaced with the built-in 'vnot'. To test their use as pattern-roots, two splat patterns have been updated accordingly. Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D94223 --- diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 5926c52..3d12240 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -85,29 +85,42 @@ namespace ISD { /// Node predicates - /// If N is a BUILD_VECTOR node whose elements are all the same constant or - /// undefined, return true and return the constant value in \p SplatValue. - bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); - - /// Return true if the specified node is a BUILD_VECTOR where all of the - /// elements are ~0 or undef. - bool isBuildVectorAllOnes(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR where all of the - /// elements are 0 or undef. - bool isBuildVectorAllZeros(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR node of all - /// ConstantSDNode or undef. - bool isBuildVectorOfConstantSDNodes(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR node of all - /// ConstantFPSDNode or undef. - bool isBuildVectorOfConstantFPSDNodes(const SDNode *N); - - /// Return true if the node has at least one operand and all operands of the - /// specified node are ISD::UNDEF. - bool allOperandsUndef(const SDNode *N); +/// If N is a BUILD_VECTOR or SPLAT_VECTOR node whose elements are all the +/// same constant or undefined, return true and return the constant value in +/// \p SplatValue. +bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); + +/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where +/// all of the elements are ~0 or undef. If \p BuildVectorOnly is set to +/// true, it only checks BUILD_VECTOR. +bool isConstantSplatVectorAllOnes(const SDNode *N, + bool BuildVectorOnly = false); + +/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where +/// all of the elements are 0 or undef. If \p BuildVectorOnly is set to true, it +/// only checks BUILD_VECTOR. +bool isConstantSplatVectorAllZeros(const SDNode *N, + bool BuildVectorOnly = false); + +/// Return true if the specified node is a BUILD_VECTOR where all of the +/// elements are ~0 or undef. +bool isBuildVectorAllOnes(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR where all of the +/// elements are 0 or undef. +bool isBuildVectorAllZeros(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR node of all +/// ConstantSDNode or undef. +bool isBuildVectorOfConstantSDNodes(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR node of all +/// ConstantFPSDNode or undef. +bool isBuildVectorOfConstantFPSDNodes(const SDNode *N); + +/// Return true if the node has at least one operand and all operands of the +/// specified node are ISD::UNDEF. +bool allOperandsUndef(const SDNode *N); } // end namespace ISD diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td index a1e961a..a09feca 100644 --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -909,11 +909,13 @@ class FPImmLeaf def vtInt : PatLeaf<(vt), [{ return N->getVT().isInteger(); }]>; def vtFP : PatLeaf<(vt), [{ return N->getVT().isFloatingPoint(); }]>; -// Use ISD::isBuildVectorAllOnes or ISD::isBuildVectorAllZeros to look for -// the corresponding build_vector. Will look through bitcasts except when used -// as a pattern root. -def immAllOnesV; // ISD::isBuildVectorAllOnes -def immAllZerosV; // ISD::isBuildVectorAllZeros +// Use ISD::isConstantSplatVectorAllOnes or ISD::isConstantSplatVectorAllZeros +// to look for the corresponding build_vector or splat_vector. Will look through +// bitcasts and check for either opcode, except when used as a pattern root. +// When used as a pattern root, only fixed-length build_vector and scalable +// splat_vector are supported. +def immAllOnesV; // ISD::isConstantSplatVectorAllOnes +def immAllZerosV; // ISD::isConstantSplatVectorAllZeros // Other helper fragments. def not : PatFrag<(ops node:$in), (xor node:$in, -1)>; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index b13cb4f..14496b5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -164,11 +164,16 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be // specializations of the more general isConstantSplatVector()? -bool ISD::isBuildVectorAllOnes(const SDNode *N) { +bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) { // Look through a bit convert. while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt SplatVal; + return isConstantSplatVector(N, SplatVal) && SplatVal.isAllOnesValue(); + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; unsigned i = 0, e = N->getNumOperands(); @@ -208,11 +213,16 @@ bool ISD::isBuildVectorAllOnes(const SDNode *N) { return true; } -bool ISD::isBuildVectorAllZeros(const SDNode *N) { +bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) { // Look through a bit convert. while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt SplatVal; + return isConstantSplatVector(N, SplatVal) && SplatVal.isNullValue(); + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; bool IsAllUndef = true; @@ -245,6 +255,14 @@ bool ISD::isBuildVectorAllZeros(const SDNode *N) { return true; } +bool ISD::isBuildVectorAllOnes(const SDNode *N) { + return isConstantSplatVectorAllOnes(N, /*BuildVectorOnly*/ true); +} + +bool ISD::isBuildVectorAllZeros(const SDNode *N) { + return isConstantSplatVectorAllZeros(N, /*BuildVectorOnly*/ true); +} + bool ISD::isBuildVectorOfConstantSDNodes(const SDNode *N) { if (N->getOpcode() != ISD::BUILD_VECTOR) return false; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index c26dbfa..7bae504 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3202,10 +3202,12 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, if (!::CheckOrImm(MatcherTable, MatcherIndex, N, *this)) break; continue; case OPC_CheckImmAllOnesV: - if (!ISD::isBuildVectorAllOnes(N.getNode())) break; + if (!ISD::isConstantSplatVectorAllOnes(N.getNode())) + break; continue; case OPC_CheckImmAllZerosV: - if (!ISD::isBuildVectorAllZeros(N.getNode())) break; + if (!ISD::isConstantSplatVectorAllZeros(N.getNode())) + break; continue; case OPC_CheckFoldableChainNode: { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index 208e501..e158b63 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -35,11 +35,6 @@ def SplatPat : ComplexPattern; def SplatPat_simm5 : ComplexPattern; def SplatPat_uimm5 : ComplexPattern; -// A mask-vector version of the standard 'vnot' fragment but using splat_vector -// rather than (the implicit) build_vector -def riscv_m_vnot : PatFrag<(ops node:$in), - (xor node:$in, (splat_vector (XLenVT 1)))>; - multiclass VPatUSLoadStoreSDNode("PseudoVMXOR_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (and VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (and VR:$rs1, VR:$rs2))), (!cast("PseudoVMNAND_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (or VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (or VR:$rs1, VR:$rs2))), (!cast("PseudoVMNOR_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (xor VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (xor VR:$rs1, VR:$rs2))), (!cast("PseudoVMXNOR_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (and VR:$rs1, (riscv_m_vnot VR:$rs2))), + def : Pat<(mti.Mask (and VR:$rs1, (vnot VR:$rs2))), (!cast("PseudoVMANDNOT_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (or VR:$rs1, (riscv_m_vnot VR:$rs2))), + def : Pat<(mti.Mask (or VR:$rs1, (vnot VR:$rs2))), (!cast("PseudoVMORNOT_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; } @@ -233,9 +228,9 @@ foreach vti = AllIntegerVectors in { } foreach mti = AllMasks in { - def : Pat<(mti.Mask (splat_vector (XLenVT 1))), + def : Pat<(mti.Mask immAllOnesV), (!cast("PseudoVMSET_M_"#mti.BX) VLMax, mti.SEW)>; - def : Pat<(mti.Mask (splat_vector (XLenVT 0))), + def : Pat<(mti.Mask immAllZerosV), (!cast("PseudoVMCLR_M_"#mti.BX) VLMax, mti.SEW)>; } } // Predicates = [HasStdExtV] diff --git a/llvm/utils/TableGen/DAGISelMatcher.h b/llvm/utils/TableGen/DAGISelMatcher.h index dca1865..3a920d1 100644 --- a/llvm/utils/TableGen/DAGISelMatcher.h +++ b/llvm/utils/TableGen/DAGISelMatcher.h @@ -763,8 +763,8 @@ private: } }; -/// CheckImmAllOnesVMatcher - This check if the current node is an build vector -/// of all ones. +/// CheckImmAllOnesVMatcher - This checks if the current node is a build_vector +/// or splat_vector of all ones. class CheckImmAllOnesVMatcher : public Matcher { public: CheckImmAllOnesVMatcher() : Matcher(CheckImmAllOnesV) {} @@ -779,8 +779,8 @@ private: bool isContradictoryImpl(const Matcher *M) const override; }; -/// CheckImmAllZerosVMatcher - This check if the current node is an build vector -/// of all zeros. +/// CheckImmAllZerosVMatcher - This checks if the current node is a +/// build_vector or splat_vector of all zeros. class CheckImmAllZerosVMatcher : public Matcher { public: CheckImmAllZerosVMatcher() : Matcher(CheckImmAllZerosV) {} diff --git a/llvm/utils/TableGen/DAGISelMatcherGen.cpp b/llvm/utils/TableGen/DAGISelMatcherGen.cpp index 792bf17..f7415b8 100644 --- a/llvm/utils/TableGen/DAGISelMatcherGen.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherGen.cpp @@ -282,7 +282,9 @@ void MatcherGen::EmitLeafMatchCode(const TreePatternNode *N) { // check to ensure that this gets folded into the normal top-level // OpcodeSwitch. if (N == Pattern.getSrcPattern()) { - const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector")); + MVT VT = N->getSimpleType(0); + StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector"; + const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name)); AddMatcher(new CheckOpcodeMatcher(NI)); } return AddMatcher(new CheckImmAllOnesVMatcher()); @@ -292,7 +294,9 @@ void MatcherGen::EmitLeafMatchCode(const TreePatternNode *N) { // check to ensure that this gets folded into the normal top-level // OpcodeSwitch. if (N == Pattern.getSrcPattern()) { - const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector")); + MVT VT = N->getSimpleType(0); + StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector"; + const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name)); AddMatcher(new CheckOpcodeMatcher(NI)); } return AddMatcher(new CheckImmAllZerosVMatcher());