[X86] Add getAVX512Node helper. NFC.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sat, 13 Nov 2021 13:59:42 +0000 (13:59 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sat, 13 Nov 2021 13:59:42 +0000 (13:59 +0000)
For AVX512 targets without VLX, we have to widen 128/256-bit vectors to 512-bits to use some specific AVX512 instructions (or some other instructions with predicates etc.).

I've pulled out the widening code from LowerFunnelShift into the helper function, so we can convert some other widening patterns in the future.

llvm/lib/Target/X86/X86ISelLowering.cpp

index 71dfd09..0d152d6 100644 (file)
@@ -6393,6 +6393,37 @@ SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
 }
 
+// Helper function that extends a non-512-bit vector op to 512-bits on non-VLX
+// targets.
+static SDValue getAVX512Node(unsigned Opcode, const SDLoc &DL, MVT VT,
+                             ArrayRef<SDValue> Ops, SelectionDAG &DAG,
+                             const X86Subtarget &Subtarget) {
+  assert(Subtarget.hasAVX512() && "AVX512 target expected");
+
+  // If we have VLX or the type is already 512-bits, then create the node
+  // directly.
+  if (Subtarget.hasVLX() || VT.is512BitVector())
+    return DAG.getNode(Opcode, DL, VT, Ops);
+
+  // Widen the vector ops.
+  MVT SVT = VT.getScalarType();
+  MVT WideVT = MVT::getVectorVT(SVT, 512 / SVT.getSizeInBits());
+  SmallVector<SDValue> WideOps(Ops.begin(), Ops.end());
+  for (SDValue &Op : WideOps) {
+    MVT OpVT = Op.getSimpleValueType();
+    // Just pass through scalar operands.
+    if (!OpVT.isVector())
+      continue;
+    assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
+           "Vector size mismatch");
+    Op = widenSubVector(Op, false, Subtarget, DAG, DL, 512);
+  }
+
+  // Perform the 512-bit op then extract the bottom subvector.
+  SDValue Res = DAG.getNode(Opcode, DL, WideVT, WideOps);
+  return extractSubVector(Res, 0, DAG, DL, VT.getSizeInBits());
+}
+
 /// Insert i1-subvector to i1-vector.
 static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
@@ -29593,29 +29624,15 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
     if (IsFSHR)
       std::swap(Op0, Op1);
 
-    // With AVX512, but not VLX we need to widen to get a 512-bit result type.
-    if (!Subtarget.hasVLX() && !VT.is512BitVector()) {
-      Op0 = widenSubVector(Op0, false, Subtarget, DAG, DL, 512);
-      Op1 = widenSubVector(Op1, false, Subtarget, DAG, DL, 512);
-    }
-
-    SDValue Funnel;
     APInt APIntShiftAmt;
-    MVT ResultVT = Op0.getSimpleValueType();
     if (X86::isConstantSplat(Amt, APIntShiftAmt)) {
       uint64_t ShiftAmt = APIntShiftAmt.urem(VT.getScalarSizeInBits());
-      Funnel =
-          DAG.getNode(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, ResultVT, Op0,
-                      Op1, DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
-    } else {
-      if (!Subtarget.hasVLX() && !VT.is512BitVector())
-        Amt = widenSubVector(Amt, false, Subtarget, DAG, DL, 512);
-      Funnel = DAG.getNode(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL,
-                           ResultVT, Op0, Op1, Amt);
-    }
-    if (!Subtarget.hasVLX() && !VT.is512BitVector())
-      Funnel = extractSubVector(Funnel, 0, DAG, DL, VT.getSizeInBits());
-    return Funnel;
+      SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
+      return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
+                           {Op0, Op1, Imm}, DAG, Subtarget);
+    }
+    return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
+                         {Op0, Op1, Amt}, DAG, Subtarget);
   }
   assert(
       (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&