/// Node predicates
- /// If N is a BUILD_VECTOR node whose elements are all the same constant or
- /// undefined, return true and return the constant value in \p SplatValue.
- bool isConstantSplatVector(const SDNode *N, APInt &SplatValue);
-
- /// Return true if the specified node is a BUILD_VECTOR where all of the
- /// elements are ~0 or undef.
- bool isBuildVectorAllOnes(const SDNode *N);
-
- /// Return true if the specified node is a BUILD_VECTOR where all of the
- /// elements are 0 or undef.
- bool isBuildVectorAllZeros(const SDNode *N);
-
- /// Return true if the specified node is a BUILD_VECTOR node of all
- /// ConstantSDNode or undef.
- bool isBuildVectorOfConstantSDNodes(const SDNode *N);
-
- /// Return true if the specified node is a BUILD_VECTOR node of all
- /// ConstantFPSDNode or undef.
- bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
-
- /// Return true if the node has at least one operand and all operands of the
- /// specified node are ISD::UNDEF.
- bool allOperandsUndef(const SDNode *N);
+/// If N is a BUILD_VECTOR or SPLAT_VECTOR node whose elements are all the
+/// same constant or undefined, return true and return the constant value in
+/// \p SplatValue.
+bool isConstantSplatVector(const SDNode *N, APInt &SplatValue);
+
+/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where
+/// all of the elements are ~0 or undef. If \p BuildVectorOnly is set to
+/// true, it only checks BUILD_VECTOR.
+bool isConstantSplatVectorAllOnes(const SDNode *N,
+ bool BuildVectorOnly = false);
+
+/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where
+/// all of the elements are 0 or undef. If \p BuildVectorOnly is set to true, it
+/// only checks BUILD_VECTOR.
+bool isConstantSplatVectorAllZeros(const SDNode *N,
+ bool BuildVectorOnly = false);
+
+/// Return true if the specified node is a BUILD_VECTOR where all of the
+/// elements are ~0 or undef.
+bool isBuildVectorAllOnes(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR where all of the
+/// elements are 0 or undef.
+bool isBuildVectorAllZeros(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR node of all
+/// ConstantSDNode or undef.
+bool isBuildVectorOfConstantSDNodes(const SDNode *N);
+
+/// Return true if the specified node is a BUILD_VECTOR node of all
+/// ConstantFPSDNode or undef.
+bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
+
+/// Return true if the node has at least one operand and all operands of the
+/// specified node are ISD::UNDEF.
+bool allOperandsUndef(const SDNode *N);
} // end namespace ISD
def vtInt : PatLeaf<(vt), [{ return N->getVT().isInteger(); }]>;
def vtFP : PatLeaf<(vt), [{ return N->getVT().isFloatingPoint(); }]>;
-// Use ISD::isBuildVectorAllOnes or ISD::isBuildVectorAllZeros to look for
-// the corresponding build_vector. Will look through bitcasts except when used
-// as a pattern root.
-def immAllOnesV; // ISD::isBuildVectorAllOnes
-def immAllZerosV; // ISD::isBuildVectorAllZeros
+// Use ISD::isConstantSplatVectorAllOnes or ISD::isConstantSplatVectorAllZeros
+// to look for the corresponding build_vector or splat_vector. Will look through
+// bitcasts and check for either opcode, except when used as a pattern root.
+// When used as a pattern root, only fixed-length build_vector and scalable
+// splat_vector are supported.
+def immAllOnesV; // ISD::isConstantSplatVectorAllOnes
+def immAllZerosV; // ISD::isConstantSplatVectorAllZeros
// Other helper fragments.
def not : PatFrag<(ops node:$in), (xor node:$in, -1)>;
// FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be
// specializations of the more general isConstantSplatVector()?
-bool ISD::isBuildVectorAllOnes(const SDNode *N) {
+bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) {
// Look through a bit convert.
while (N->getOpcode() == ISD::BITCAST)
N = N->getOperand(0).getNode();
+ if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
+ APInt SplatVal;
+ return isConstantSplatVector(N, SplatVal) && SplatVal.isAllOnesValue();
+ }
+
if (N->getOpcode() != ISD::BUILD_VECTOR) return false;
unsigned i = 0, e = N->getNumOperands();
return true;
}
-bool ISD::isBuildVectorAllZeros(const SDNode *N) {
+bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) {
// Look through a bit convert.
while (N->getOpcode() == ISD::BITCAST)
N = N->getOperand(0).getNode();
+ if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) {
+ APInt SplatVal;
+ return isConstantSplatVector(N, SplatVal) && SplatVal.isNullValue();
+ }
+
if (N->getOpcode() != ISD::BUILD_VECTOR) return false;
bool IsAllUndef = true;
return true;
}
+bool ISD::isBuildVectorAllOnes(const SDNode *N) {
+ return isConstantSplatVectorAllOnes(N, /*BuildVectorOnly*/ true);
+}
+
+bool ISD::isBuildVectorAllZeros(const SDNode *N) {
+ return isConstantSplatVectorAllZeros(N, /*BuildVectorOnly*/ true);
+}
+
bool ISD::isBuildVectorOfConstantSDNodes(const SDNode *N) {
if (N->getOpcode() != ISD::BUILD_VECTOR)
return false;
if (!::CheckOrImm(MatcherTable, MatcherIndex, N, *this)) break;
continue;
case OPC_CheckImmAllOnesV:
- if (!ISD::isBuildVectorAllOnes(N.getNode())) break;
+ if (!ISD::isConstantSplatVectorAllOnes(N.getNode()))
+ break;
continue;
case OPC_CheckImmAllZerosV:
- if (!ISD::isBuildVectorAllZeros(N.getNode())) break;
+ if (!ISD::isConstantSplatVectorAllZeros(N.getNode()))
+ break;
continue;
case OPC_CheckFoldableChainNode: {
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", []>;
def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", []>;
-// A mask-vector version of the standard 'vnot' fragment but using splat_vector
-// rather than (the implicit) build_vector
-def riscv_m_vnot : PatFrag<(ops node:$in),
- (xor node:$in, (splat_vector (XLenVT 1)))>;
-
multiclass VPatUSLoadStoreSDNode<LLVMType type,
LLVMType mask_type,
int sew,
(!cast<Instruction>("PseudoVMXOR_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (riscv_m_vnot (and VR:$rs1, VR:$rs2))),
+ def : Pat<(mti.Mask (vnot (and VR:$rs1, VR:$rs2))),
(!cast<Instruction>("PseudoVMNAND_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (riscv_m_vnot (or VR:$rs1, VR:$rs2))),
+ def : Pat<(mti.Mask (vnot (or VR:$rs1, VR:$rs2))),
(!cast<Instruction>("PseudoVMNOR_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (riscv_m_vnot (xor VR:$rs1, VR:$rs2))),
+ def : Pat<(mti.Mask (vnot (xor VR:$rs1, VR:$rs2))),
(!cast<Instruction>("PseudoVMXNOR_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (and VR:$rs1, (riscv_m_vnot VR:$rs2))),
+ def : Pat<(mti.Mask (and VR:$rs1, (vnot VR:$rs2))),
(!cast<Instruction>("PseudoVMANDNOT_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (or VR:$rs1, (riscv_m_vnot VR:$rs2))),
+ def : Pat<(mti.Mask (or VR:$rs1, (vnot VR:$rs2))),
(!cast<Instruction>("PseudoVMORNOT_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>;
}
}
foreach mti = AllMasks in {
- def : Pat<(mti.Mask (splat_vector (XLenVT 1))),
+ def : Pat<(mti.Mask immAllOnesV),
(!cast<Instruction>("PseudoVMSET_M_"#mti.BX) VLMax, mti.SEW)>;
- def : Pat<(mti.Mask (splat_vector (XLenVT 0))),
+ def : Pat<(mti.Mask immAllZerosV),
(!cast<Instruction>("PseudoVMCLR_M_"#mti.BX) VLMax, mti.SEW)>;
}
} // Predicates = [HasStdExtV]
}
};
-/// CheckImmAllOnesVMatcher - This check if the current node is an build vector
-/// of all ones.
+/// CheckImmAllOnesVMatcher - This checks if the current node is a build_vector
+/// or splat_vector of all ones.
class CheckImmAllOnesVMatcher : public Matcher {
public:
CheckImmAllOnesVMatcher() : Matcher(CheckImmAllOnesV) {}
bool isContradictoryImpl(const Matcher *M) const override;
};
-/// CheckImmAllZerosVMatcher - This check if the current node is an build vector
-/// of all zeros.
+/// CheckImmAllZerosVMatcher - This checks if the current node is a
+/// build_vector or splat_vector of all zeros.
class CheckImmAllZerosVMatcher : public Matcher {
public:
CheckImmAllZerosVMatcher() : Matcher(CheckImmAllZerosV) {}
// check to ensure that this gets folded into the normal top-level
// OpcodeSwitch.
if (N == Pattern.getSrcPattern()) {
- const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector"));
+ MVT VT = N->getSimpleType(0);
+ StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector";
+ const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name));
AddMatcher(new CheckOpcodeMatcher(NI));
}
return AddMatcher(new CheckImmAllOnesVMatcher());
// check to ensure that this gets folded into the normal top-level
// OpcodeSwitch.
if (N == Pattern.getSrcPattern()) {
- const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector"));
+ MVT VT = N->getSimpleType(0);
+ StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector";
+ const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name));
AddMatcher(new CheckOpcodeMatcher(NI));
}
return AddMatcher(new CheckImmAllZerosVMatcher());