[X86][SSE] Lower vXi8 general shifts to SSE shifts directly. NFCI.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 21 Aug 2018 17:27:03 +0000 (17:27 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 21 Aug 2018 17:27:03 +0000 (17:27 +0000)
Most of these shifts are extended to vXi16 so we don't gain anything from forcing another round of generic shift lowering - we know these extended cases are legal constant splat shifts.

llvm-svn: 340307

llvm/lib/Target/X86/X86ISelLowering.cpp

index 2260bc2..842dbc6 100644 (file)
@@ -23844,7 +23844,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
     MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
-    unsigned ShiftOpcode = Opc;
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
       if (VT.is512BitVector()) {
@@ -23878,32 +23877,33 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     // We can safely do this using i16 shifts as we're only interested in
     // the 3 lower bits of each byte.
     Amt = DAG.getBitcast(ExtVT, Amt);
-    Amt = DAG.getNode(ISD::SHL, dl, ExtVT, Amt, DAG.getConstant(5, dl, ExtVT));
+    Amt = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ExtVT, Amt, 5, DAG);
     Amt = DAG.getBitcast(VT, Amt);
 
     if (Opc == ISD::SHL || Opc == ISD::SRL) {
       // r = VSELECT(r, shift(r, 4), a);
-      SDValue M =
-          DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(4, dl, VT));
+      SDValue M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(4, dl, VT));
       R = SignBitSelect(VT, Amt, M, R);
 
       // a += a
       Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
 
       // r = VSELECT(r, shift(r, 2), a);
-      M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(2, dl, VT));
+      M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(2, dl, VT));
       R = SignBitSelect(VT, Amt, M, R);
 
       // a += a
       Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
 
       // return VSELECT(r, shift(r, 1), a);
-      M = DAG.getNode(ShiftOpcode, dl, VT, R, DAG.getConstant(1, dl, VT));
+      M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(1, dl, VT));
       R = SignBitSelect(VT, Amt, M, R);
       return R;
     }
 
     if (Opc == ISD::SRA) {
+      unsigned X86Opc = getTargetVShiftUniformOpcode(Opc, false);
+
       // For SRA we need to unpack each byte to the higher byte of a i16 vector
       // so we can correctly sign extend. We don't care what happens to the
       // lower byte.
@@ -23917,10 +23917,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       RHi = DAG.getBitcast(ExtVT, RHi);
 
       // r = VSELECT(r, shift(r, 4), a);
-      SDValue MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
-                                DAG.getConstant(4, dl, ExtVT));
-      SDValue MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
-                                DAG.getConstant(4, dl, ExtVT));
+      SDValue MLo = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RLo, 4, DAG);
+      SDValue MHi = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RHi, 4, DAG);
       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
 
@@ -23929,10 +23927,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
 
       // r = VSELECT(r, shift(r, 2), a);
-      MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
-                        DAG.getConstant(2, dl, ExtVT));
-      MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
-                        DAG.getConstant(2, dl, ExtVT));
+      MLo = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RLo, 2, DAG);
+      MHi = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RHi, 2, DAG);
       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
 
@@ -23941,20 +23937,15 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
 
       // r = VSELECT(r, shift(r, 1), a);
-      MLo = DAG.getNode(ShiftOpcode, dl, ExtVT, RLo,
-                        DAG.getConstant(1, dl, ExtVT));
-      MHi = DAG.getNode(ShiftOpcode, dl, ExtVT, RHi,
-                        DAG.getConstant(1, dl, ExtVT));
+      MLo = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RLo, 1, DAG);
+      MHi = getTargetVShiftByConstNode(X86Opc, dl, ExtVT, RHi, 1, DAG);
       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
 
       // Logical shift the result back to the lower byte, leaving a zero upper
-      // byte
-      // meaning that we can safely pack with PACKUSWB.
-      RLo =
-          DAG.getNode(ISD::SRL, dl, ExtVT, RLo, DAG.getConstant(8, dl, ExtVT));
-      RHi =
-          DAG.getNode(ISD::SRL, dl, ExtVT, RHi, DAG.getConstant(8, dl, ExtVT));
+      // byte meaning that we can safely pack with PACKUSWB.
+      RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RLo, 8, DAG);
+      RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RHi, 8, DAG);
       return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
     }
   }
@@ -23972,8 +23963,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     RHi = DAG.getBitcast(ExtVT, RHi);
     SDValue Lo = DAG.getNode(Opc, dl, ExtVT, RLo, ALo);
     SDValue Hi = DAG.getNode(Opc, dl, ExtVT, RHi, AHi);
-    Lo = DAG.getNode(ISD::SRL, dl, ExtVT, Lo, DAG.getConstant(16, dl, ExtVT));
-    Hi = DAG.getNode(ISD::SRL, dl, ExtVT, Hi, DAG.getConstant(16, dl, ExtVT));
+    Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Lo, 16, DAG);
+    Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Hi, 16, DAG);
     return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi);
   }