AVX-512, X86: Added lowering for shift operations for SKX.
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Tue, 12 May 2015 13:25:46 +0000 (13:25 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Tue, 12 May 2015 13:25:46 +0000 (13:25 +0000)
The other changes in the LowerShift() are not functional,
just to make the code more convenient.
So, the functional changes for SKX only.

llvm-svn: 237129

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx512-shift.ll

index 5ea17a4..512a9f9 100644 (file)
@@ -1507,6 +1507,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::AND,                MVT::v4i32, Legal);
     setOperationAction(ISD::OR,                 MVT::v4i32, Legal);
     setOperationAction(ISD::XOR,                MVT::v4i32, Legal);
+    setOperationAction(ISD::SRA,                MVT::v2i64, Custom);
+    setOperationAction(ISD::SRA,                MVT::v4i64, Custom);
   }
 
   // We want to custom lower some of our intrinsics.
@@ -16328,6 +16330,53 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget *Subtarget,
   return DAG.getMergeValues(Ops, dl);
 }
 
+// Return true if the requred (according to Opcode) shift-imm form is natively
+// supported by the Subtarget
+static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget *Subtarget, 
+                                        unsigned Opcode) {
+  if (VT.getScalarSizeInBits() < 16)
+    return false;
+  if (VT.is512BitVector() &&
+      (VT.getScalarSizeInBits() > 16 || Subtarget->hasBWI()))
+    return true;
+
+  bool LShift = VT.is128BitVector() || 
+    (VT.is256BitVector() && Subtarget->hasInt256());
+
+  bool AShift = LShift && (Subtarget->hasVLX() ||
+    (VT != MVT::v2i64 && VT != MVT::v4i64));
+  return (Opcode == ISD::SRA) ? AShift : LShift;
+}
+
+// The shift amount is a variable, but it is the same for all vector lanes.
+// These instrcutions are defined together with shift-immediate.
+static 
+bool SupportedVectorShiftWithBaseAmnt(MVT VT, const X86Subtarget *Subtarget, 
+                                      unsigned Opcode) {
+  return SupportedVectorShiftWithImm(VT, Subtarget, Opcode);
+}
+
+// Return true if the requred (according to Opcode) variable-shift form is
+// natively supported by the Subtarget
+static bool SupportedVectorVarShift(MVT VT, const X86Subtarget *Subtarget, 
+                                    unsigned Opcode) {
+
+  if (!Subtarget->hasInt256() || VT.getScalarSizeInBits() < 16)
+    return false;
+
+  // vXi16 supported only on AVX-512, BWI
+  if (VT.getScalarSizeInBits() == 16 && !Subtarget->hasBWI())
+    return false;
+
+  if (VT.is512BitVector() || Subtarget->hasVLX())
+    return true;
+
+  bool LShift = VT.is128BitVector() || VT.is256BitVector();
+  bool AShift = LShift &&  VT != MVT::v2i64 && VT != MVT::v4i64;
+  return (Opcode == ISD::SRA) ? AShift : LShift;
+}
+
 static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
                                          const X86Subtarget *Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -16335,26 +16384,16 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
 
+  unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+
   // Optimize shl/srl/sra with constant shift amount.
   if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
     if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
       uint64_t ShiftAmt = ShiftConst->getZExtValue();
 
-      if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
-          (Subtarget->hasInt256() &&
-           (VT == MVT::v4i64 || VT == MVT::v8i32 || VT == MVT::v16i16)) ||
-          (Subtarget->hasAVX512() &&
-           (VT == MVT::v8i64 || VT == MVT::v16i32))) {
-        if (Op.getOpcode() == ISD::SHL)
-          return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
-                                            DAG);
-        if (Op.getOpcode() == ISD::SRL)
-          return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
-                                            DAG);
-        if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
-          return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
-                                            DAG);
-      }
+      if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
+        return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
 
       if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
         unsigned NumElts = VT.getVectorNumElements();
@@ -16435,19 +16474,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
       if (ShAmt != ShiftAmt)
         return SDValue();
     }
-    switch (Op.getOpcode()) {
-    default:
-      llvm_unreachable("Unknown shift opcode!");
-    case ISD::SHL:
-      return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
-                                        DAG);
-    case ISD::SRL:
-      return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
-                                        DAG);
-    case ISD::SRA:
-      return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
-                                        DAG);
-    }
+    return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
   }
 
   return SDValue();
@@ -16460,12 +16487,13 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
 
-  if ((VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) ||
-      VT == MVT::v4i32 || VT == MVT::v8i16 ||
-      (Subtarget->hasInt256() &&
-       ((VT == MVT::v4i64 && Op.getOpcode() != ISD::SRA) ||
-        VT == MVT::v8i32 || VT == MVT::v16i16)) ||
-       (Subtarget->hasAVX512() && (VT == MVT::v8i64 || VT == MVT::v16i32))) {
+  unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+
+  unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
+
+  if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
     SDValue BaseShAmt;
     EVT EltVT = VT.getVectorElementType();
 
@@ -16509,47 +16537,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
       else if (EltVT.bitsLT(MVT::i32))
         BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt);
 
-      switch (Op.getOpcode()) {
-      default:
-        llvm_unreachable("Unknown shift opcode!");
-      case ISD::SHL:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v2i64:
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v4i64:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSHLI, dl, VT, R, BaseShAmt, DAG);
-        }
-      case ISD::SRA:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, R, BaseShAmt, DAG);
-        }
-      case ISD::SRL:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v2i64:
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v4i64:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSRLI, dl, VT, R, BaseShAmt, DAG);
-        }
-      }
+      return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, DAG);
     }
   }
 
@@ -16568,18 +16556,8 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
         if (Vals[j] != Amt.getOperand(i + j))
           return SDValue();
     }
-    switch (Op.getOpcode()) {
-    default:
-      llvm_unreachable("Unknown shift opcode!");
-    case ISD::SHL:
-      return DAG.getNode(X86ISD::VSHL, dl, VT, R, Op.getOperand(1));
-    case ISD::SRL:
-      return DAG.getNode(X86ISD::VSRL, dl, VT, R, Op.getOperand(1));
-    case ISD::SRA:
-      return DAG.getNode(X86ISD::VSRA, dl, VT, R, Op.getOperand(1));
-    }
+    return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
   }
-
   return SDValue();
 }
 
@@ -16599,23 +16577,9 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
   if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget))
       return V;
 
-  if (Subtarget->hasAVX512() && (VT == MVT::v16i32 || VT == MVT::v8i64))
+  if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode()))
     return Op;
 
-  // AVX2 has VPSLLV/VPSRAV/VPSRLV.
-  if (Subtarget->hasInt256()) {
-    if (Op.getOpcode() == ISD::SRL &&
-        (VT == MVT::v2i64 || VT == MVT::v4i32 ||
-         VT == MVT::v4i64 || VT == MVT::v8i32))
-      return Op;
-    if (Op.getOpcode() == ISD::SHL &&
-        (VT == MVT::v2i64 || VT == MVT::v4i32 ||
-         VT == MVT::v4i64 || VT == MVT::v8i32))
-      return Op;
-    if (Op.getOpcode() == ISD::SRA && (VT == MVT::v4i32 || VT == MVT::v8i32))
-      return Op;
-  }
-
   // 2i64 vector logical shifts can efficiently avoid scalarization - do the
   // shifts per-lane and then shuffle the partial results back together.
   if (VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) {
index 0636cd2..10883a5 100644 (file)
@@ -1,4 +1,5 @@
 ;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=knl | FileCheck %s
+;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=skx | FileCheck --check-prefix=SKX %s
 
 ;CHECK-LABEL: shift_16_i32
 ;CHECK: vpsrld
@@ -24,6 +25,18 @@ define <8 x i64> @shift_8_i64(<8 x i64> %a) {
    ret <8 x i64> %d;
 }
 
+;SKX-LABEL: shift_4_i64
+;SKX: vpsrlq
+;SKX: vpsllq
+;SKX: vpsraq
+;SKX: ret
+define <4 x i64> @shift_4_i64(<4 x i64> %a) {
+   %b = lshr <4 x i64> %a, <i64 1, i64 1, i64 1, i64 1>
+   %c = shl <4 x i64> %b,  <i64 12, i64 12, i64 12, i64 12>
+   %d = ashr <4 x i64> %c, <i64 12, i64 12, i64 12, i64 12>
+   ret <4 x i64> %d;
+}
+
 ; CHECK-LABEL: variable_shl4
 ; CHECK: vpsllvq %zmm
 ; CHECK: ret
@@ -72,6 +85,22 @@ define <8 x i64> @variable_sra2(<8 x i64> %x, <8 x i64> %y) {
   ret <8 x i64> %k
 }
 
+; SKX-LABEL: variable_sra3
+; SKX: vpsravq %ymm
+; SKX: ret
+define <4 x i64> @variable_sra3(<4 x i64> %x, <4 x i64> %y) {
+  %k = ashr <4 x i64> %x, %y
+  ret <4 x i64> %k
+}
+
+; SKX-LABEL: variable_sra4
+; SKX: vpsravw %xmm
+; SKX: ret
+define <8 x i16> @variable_sra4(<8 x i16> %x, <8 x i16> %y) {
+  %k = ashr <8 x i16> %x, %y
+  ret <8 x i16> %k
+}
+
 ; CHECK-LABEL: variable_sra01_load
 ; CHECK: vpsravd (%
 ; CHECK: ret