[X86][AVX] Convert X86ISD::VBROADCAST demanded elts combine to use SimplifyDemandedVe...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 20 Jul 2018 13:26:51 +0000 (13:26 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 20 Jul 2018 13:26:51 +0000 (13:26 +0000)
This is an early step towards using SimplifyDemandedVectorElts for target shuffle combining - this merely moves the existing X86ISD::VBROADCAST simplification code to use the SimplifyDemandedVectorElts mechanism.

Adds X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode to handle X86ISD::VBROADCAST - in time we can support all target shuffles (and other ops) here.

llvm-svn: 337547

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/lib/Target/X86/X86ISelLowering.h

index b1b6b7c..d9b42f6 100644 (file)
@@ -30635,24 +30635,13 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
 
   switch (Opcode) {
   case X86ISD::VBROADCAST: {
-    // If broadcasting from another shuffle, attempt to simplify it.
     // TODO - we really need a general SimplifyDemandedVectorElts mechanism.
-    SDValue Src = N.getOperand(0);
-    SDValue BC = peekThroughBitcasts(Src);
-    EVT SrcVT = Src.getValueType();
-    EVT BCVT = BC.getValueType();
-    if (isTargetShuffle(BC.getOpcode()) &&
-        VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits() == 0) {
-      unsigned Scale = VT.getScalarSizeInBits() / BCVT.getScalarSizeInBits();
-      SmallVector<int, 16> DemandedMask(BCVT.getVectorNumElements(),
-                                        SM_SentinelUndef);
-      for (unsigned i = 0; i != Scale; ++i)
-        DemandedMask[i] = i;
-      if (SDValue Res = combineX86ShufflesRecursively(
-              {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 1,
-              /*HasVarMask*/ false, DAG, Subtarget))
-        return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
-                           DAG.getBitcast(SrcVT, Res));
+    APInt KnownUndef, KnownZero;
+    APInt DemandedMask(APInt::getAllOnesValue(VT.getVectorNumElements()));
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    if (TLI.SimplifyDemandedVectorElts(N, DemandedMask, KnownUndef, KnownZero,
+                                       DCI)) {
+      return SDValue(N.getNode(), 0);
     }
     return SDValue();
   }
@@ -31298,6 +31287,41 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
+    SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
+    TargetLoweringOpt &TLO, unsigned Depth) const {
+
+  if (X86ISD::VBROADCAST != Op.getOpcode())
+    return false;
+
+  EVT VT = Op.getValueType();
+  SDValue Src = Op.getOperand(0);
+  SDValue BC = peekThroughBitcasts(Src);
+  EVT SrcVT = Src.getValueType();
+  EVT BCVT = BC.getValueType();
+
+  if (!isTargetShuffle(BC.getOpcode()) ||
+      (VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits()) != 0)
+    return false;
+
+  unsigned Scale = VT.getScalarSizeInBits() / BCVT.getScalarSizeInBits();
+  SmallVector<int, 16> DemandedMask(BCVT.getVectorNumElements(),
+                                    SM_SentinelUndef);
+  for (unsigned i = 0; i != Scale; ++i)
+    DemandedMask[i] = i;
+
+  if (SDValue Res = combineX86ShufflesRecursively(
+          {BC}, 0, BC, DemandedMask, {}, Depth + 1, /*HasVarMask*/ false,
+          TLO.DAG, Subtarget)) {
+    SDLoc DL(Op);
+    Res = TLO.DAG.getNode(X86ISD::VBROADCAST, DL, VT,
+                          TLO.DAG.getBitcast(SrcVT, Res));
+    return TLO.CombineTo(Op, Res);
+  }
+
+  return false;
+}
+
 /// Check if a vector extract from a target-specific shuffle of a load can be
 /// folded into a single element load.
 /// Similar handling for VECTOR_SHUFFLE is performed by DAGCombiner, but
index 32215b1..623b95b 100644 (file)
@@ -866,6 +866,13 @@ namespace llvm {
                                              const SelectionDAG &DAG,
                                              unsigned Depth) const override;
 
+    bool SimplifyDemandedVectorEltsForTargetNode(SDValue Op,
+                                                 const APInt &DemandedElts,
+                                                 APInt &KnownUndef,
+                                                 APInt &KnownZero,
+                                                 TargetLoweringOpt &TLO,
+                                                 unsigned Depth) const override;
+
     SDValue unwrapAddress(SDValue N) const override;
 
     bool isGAPlusOffset(SDNode *N, const GlobalValue* &GA,