[SelectionDAG] Add expansion and promotion of [US]MUL_LOHI
authorNicolai Haehnle <nhaehnle@gmail.com>
Thu, 8 Dec 2016 14:08:14 +0000 (14:08 +0000)
committerNicolai Haehnle <nhaehnle@gmail.com>
Thu, 8 Dec 2016 14:08:14 +0000 (14:08 +0000)
Summary:
Most targets set the action for these nodes to Expand even though there
isn't actually any code for them in ExpandNode. Instead, targets simply
relied on the fact that no code generates these nodes as long as the
nodes aren't legal or custom.

However, generating these nodes can be useful e.g. for divide-by-constant
in wider integer types.

Expand of [US]MUL_LOHI will use MULH[US] when legal or custom, and
a sequence of half-width multiplications otherwise. Promote uses a wider
multiply.

This patch intends to not change the generated code, but indirect effects
are possible since expansions/promotions that were previously done in
DAGCombine may now be done in LegalizeDAG.

See D24822 for a change that actually uses the new expansion.

Reviewers: spatel, bkramer, venkatra, efriedma, hfinkel, ast, nadav, tstellarAMD

Subscribers: arsenm, jyknight, nemanjai, wdng, nhaehnle, llvm-commits

Differential Revision: https://reviews.llvm.org/D24956

llvm-svn: 289050

llvm/include/llvm/Target/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

index f9a7697..c47a21e 100644 (file)
@@ -142,6 +142,13 @@ public:
     CmpXChg, // Expand the instruction into cmpxchg; used by at least X86.
   };
 
+  /// Enum that specifies when a multiplication should be expanded.
+  enum class MulExpansionKind {
+    Always,            // Always expand the instruction.
+    OnlyLegalOrCustom, // Only expand when the resulting instructions are legal
+                       // or custom.
+  };
+
   static ISD::NodeType getExtendForContent(BooleanContent Content) {
     switch (Content) {
     case UndefinedBooleanContent:
@@ -3036,9 +3043,28 @@ public:
   // Legalization utility functions
   //
 
+  /// Expand a MUL or [US]MUL_LOHI of n-bit values into two or four nodes,
+  /// respectively, each computing an n/2-bit part of the result.
+  /// \param Result A vector that will be filled with the parts of the result
+  ///        in little-endian order.
+  /// \param HalfVT The value type to use for the result nodes.
+  /// \param OnlyLegalOrCustom Only legal or custom instructions are used.
+  /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
+  ///        if you want to control how low bits are extracted from the LHS.
+  /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
+  /// \param RL Low bits of the RHS of the MUL.  See LL for meaning
+  /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
+  /// \returns true if the node has been expanded, false if it has not
+  bool expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, SDValue LHS,
+                      SDValue RHS, SmallVectorImpl<SDValue> &Result, EVT HiLoVT,
+                      SelectionDAG &DAG, MulExpansionKind Kind,
+                      SDValue LL = SDValue(), SDValue LH = SDValue(),
+                      SDValue RL = SDValue(), SDValue RH = SDValue()) const;
+
   /// Expand a MUL into two nodes.  One that computes the high bits of
   /// the result and one that computes the low bits.
   /// \param HiLoVT The value type to use for the Lo and Hi nodes.
+  /// \param OnlyLegalOrCustom Only legal or custom instructions are used.
   /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
   ///        if you want to control how low bits are extracted from the LHS.
   /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
@@ -3046,9 +3072,9 @@ public:
   /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
   /// \returns true if the node has been expanded. false if it has not
   bool expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                 SelectionDAG &DAG, SDValue LL = SDValue(),
-                 SDValue LH = SDValue(), SDValue RL = SDValue(),
-                 SDValue RH = SDValue()) const;
+                 SelectionDAG &DAG, MulExpansionKind Kind,
+                 SDValue LL = SDValue(), SDValue LH = SDValue(),
+                 SDValue RL = SDValue(), SDValue RH = SDValue()) const;
 
   /// Expand float(f32) to SINT(i64) conversion
   /// \param N Node to expand
index 7a04e0d..3485e35 100644 (file)
@@ -3312,17 +3312,49 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   }
   case ISD::MULHU:
   case ISD::MULHS: {
-    unsigned ExpandOpcode = Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI :
-                                                              ISD::SMUL_LOHI;
+    unsigned ExpandOpcode =
+        Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI : ISD::SMUL_LOHI;
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
-    assert(TLI.isOperationLegalOrCustom(ExpandOpcode, VT) &&
-           "If this wasn't legal, it shouldn't have been created!");
+
     Tmp1 = DAG.getNode(ExpandOpcode, dl, VTs, Node->getOperand(0),
                        Node->getOperand(1));
     Results.push_back(Tmp1.getValue(1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    SDValue LHS = Node->getOperand(0);
+    SDValue RHS = Node->getOperand(1);
+    MVT VT = LHS.getSimpleValueType();
+    unsigned MULHOpcode =
+        Node->getOpcode() == ISD::UMUL_LOHI ? ISD::MULHU : ISD::MULHS;
+
+    if (TLI.isOperationLegalOrCustom(MULHOpcode, VT)) {
+      Results.push_back(DAG.getNode(ISD::MUL, dl, VT, LHS, RHS));
+      Results.push_back(DAG.getNode(MULHOpcode, dl, VT, LHS, RHS));
+      break;
+    }
+
+    SmallVector<SDValue, 4> Halves;
+    EVT HalfType = EVT(VT).getHalfSizedIntegerVT(*DAG.getContext());
+    assert(TLI.isTypeLegal(HalfType));
+    if (TLI.expandMUL_LOHI(Node->getOpcode(), VT, Node, LHS, RHS, Halves,
+                           HalfType, DAG,
+                           TargetLowering::MulExpansionKind::Always)) {
+      for (unsigned i = 0; i < 2; ++i) {
+        SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Halves[2 * i]);
+        SDValue Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Halves[2 * i + 1]);
+        SDValue Shift = DAG.getConstant(
+            HalfType.getScalarSizeInBits(), dl,
+            TLI.getShiftAmountTy(HalfType, DAG.getDataLayout()));
+        Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
+        Results.push_back(DAG.getNode(ISD::OR, dl, VT, Lo, Hi));
+      }
+      break;
+    }
+    break;
+  }
   case ISD::MUL: {
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
@@ -3357,7 +3389,8 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
         TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND, VT) &&
         TLI.isOperationLegalOrCustom(ISD::SHL, VT) &&
         TLI.isOperationLegalOrCustom(ISD::OR, VT) &&
-        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG)) {
+        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG,
+                      TargetLowering::MulExpansionKind::OnlyLegalOrCustom)) {
       Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
       Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Hi);
       SDValue Shift =
@@ -4194,6 +4227,24 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     Results.push_back(DAG.getNode(TruncOp, dl, OVT, Tmp1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    // Promote to a multiply in a wider integer type.
+    unsigned ExtOp = Node->getOpcode() == ISD::UMUL_LOHI ? ISD::ZERO_EXTEND
+                                                         : ISD::SIGN_EXTEND;
+    Tmp1 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(0));
+    Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
+    Tmp1 = DAG.getNode(ISD::MUL, dl, NVT, Tmp1, Tmp2);
+
+    auto &DL = DAG.getDataLayout();
+    unsigned OriginalSize = OVT.getScalarSizeInBits();
+    Tmp2 = DAG.getNode(
+        ISD::SRL, dl, NVT, Tmp1,
+        DAG.getConstant(OriginalSize, dl, TLI.getScalarShiftAmountTy(DL, NVT)));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
+    break;
+  }
   case ISD::SELECT: {
     unsigned ExtOp, TruncOp;
     if (Node->getValueType(0).isVector() ||
index 12fe168..9d07371 100644 (file)
@@ -2189,7 +2189,9 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
   GetExpandedInteger(N->getOperand(0), LL, LH);
   GetExpandedInteger(N->getOperand(1), RL, RH);
 
-  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, LL, LH, RL, RH))
+  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG,
+                    TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
+                    LL, LH, RL, RH))
     return;
 
   // If nothing else, we can make a libcall.
index 8eba6a3..d4fa20f 100644 (file)
@@ -333,6 +333,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::SMAX:
   case ISD::UMIN:
   case ISD::UMAX:
+  case ISD::SMUL_LOHI:
+  case ISD::UMUL_LOHI:
     QueryType = Node->getValueType(0);
     break;
   case ISD::FP_ROUND_INREG:
index 6a7de14..4cc04f3 100644 (file)
@@ -3093,24 +3093,29 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
 // Legalization Utilities
 //===----------------------------------------------------------------------===//
 
-bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                               SelectionDAG &DAG, SDValue LL, SDValue LH,
-                               SDValue RL, SDValue RH) const {
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
-
-  bool HasMULHS = isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
-  bool HasMULHU = isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
-  bool HasSMUL_LOHI = isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
-  bool HasUMUL_LOHI = isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
+bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl,
+                                    SDValue LHS, SDValue RHS,
+                                    SmallVectorImpl<SDValue> &Result,
+                                    EVT HiLoVT, SelectionDAG &DAG,
+                                    MulExpansionKind Kind, SDValue LL,
+                                    SDValue LH, SDValue RL, SDValue RH) const {
+  assert(Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI ||
+         Opcode == ISD::SMUL_LOHI);
+
+  bool HasMULHS = (Kind == MulExpansionKind::Always) ||
+                  isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
+  bool HasMULHU = (Kind == MulExpansionKind::Always) ||
+                  isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
+  bool HasSMUL_LOHI = (Kind == MulExpansionKind::Always) ||
+                      isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
+  bool HasUMUL_LOHI = (Kind == MulExpansionKind::Always) ||
+                      isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
 
   if (!HasMULHU && !HasMULHS && !HasUMUL_LOHI && !HasSMUL_LOHI)
     return false;
 
   unsigned OuterBitSize = VT.getScalarSizeInBits();
   unsigned InnerBitSize = HiLoVT.getScalarSizeInBits();
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
   unsigned LHSSB = DAG.ComputeNumSignBits(LHS);
   unsigned RHSSB = DAG.ComputeNumSignBits(RHS);
 
@@ -3134,6 +3139,8 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
     return false;
   };
 
+  SDValue Lo, Hi;
+
   if (!LL.getNode() && !RL.getNode() &&
       isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
     LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LHS);
@@ -3144,20 +3151,41 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
     return false;
 
   APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize);
-  if (DAG.MaskedValueIsZero(N->getOperand(0), HighMask) &&
-      DAG.MaskedValueIsZero(N->getOperand(1), HighMask)) {
+  if (DAG.MaskedValueIsZero(LHS, HighMask) &&
+      DAG.MaskedValueIsZero(RHS, HighMask)) {
     // The inputs are both zero-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
+      if (Opcode != ISD::MUL) {
+        SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+        Result.push_back(Zero);
+        Result.push_back(Zero);
+      }
       return true;
+    }
   }
-  if (LHSSB > InnerBitSize && RHSSB > InnerBitSize) {
+
+  if (!VT.isVector() && Opcode == ISD::MUL && LHSSB > InnerBitSize &&
+      RHSSB > InnerBitSize) {
     // The input values are both sign-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true))
+    // TODO non-MUL case?
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
       return true;
+    }
   }
 
-  SDValue Shift = DAG.getConstant(OuterBitSize - InnerBitSize, dl,
-                                  getShiftAmountTy(VT, DAG.getDataLayout()));
+  unsigned ShiftAmount = OuterBitSize - InnerBitSize;
+  EVT ShiftAmountTy = getShiftAmountTy(VT, DAG.getDataLayout());
+  if (APInt::getMaxValue(ShiftAmountTy.getSizeInBits()).ult(ShiftAmount)) {
+    // FIXME getShiftAmountTy does not always return a sensible result when VT
+    // is an illegal type, and so the type may be too small to fit the shift
+    // amount. Override it with i32. The shift will have to be legalized.
+    ShiftAmountTy = MVT::i32;
+  }
+  SDValue Shift = DAG.getConstant(ShiftAmount, dl, ShiftAmountTy);
 
   if (!LH.getNode() && !RH.getNode() &&
       isOperationLegalOrCustom(ISD::SRL, VT) &&
@@ -3171,15 +3199,84 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
   if (!LH.getNode())
     return false;
 
-  if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+  if (!MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    return false;
+
+  Result.push_back(Lo);
+
+  if (Opcode == ISD::MUL) {
     RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH);
     LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH);
+    Result.push_back(Hi);
     return true;
   }
 
-  return false;
+  // Compute the full width result.
+  auto Merge = [&](SDValue Lo, SDValue Hi) -> SDValue {
+    Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
+    Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
+    Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
+    return DAG.getNode(ISD::OR, dl, VT, Lo, Hi);
+  };
+
+  SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
+  if (!MakeMUL_LOHI(LL, RH, Lo, Hi, false))
+    return false;
+
+  // This is effectively the add part of a multiply-add of half-sized operands,
+  // so it cannot overflow.
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false))
+    return false;
+
+  Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next,
+                     Merge(Lo, Hi));
+
+  SDValue Carry = Next.getValue(1);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+
+  if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI))
+    return false;
+
+  SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+  Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero,
+                   Carry);
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (Opcode == ISD::SMUL_LOHI) {
+    SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                                  DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL));
+    Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT);
+
+    NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                          DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL));
+    Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT);
+  }
+
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  return true;
+}
+
+bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
+                               SelectionDAG &DAG, MulExpansionKind Kind,
+                               SDValue LL, SDValue LH, SDValue RL,
+                               SDValue RH) const {
+  SmallVector<SDValue, 2> Result;
+  bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), N,
+                           N->getOperand(0), N->getOperand(1), Result, HiLoVT,
+                           DAG, Kind, LL, LH, RL, RH);
+  if (Ok) {
+    assert(Result.size() == 2);
+    Lo = Result[0];
+    Hi = Result[1];
+  }
+  return Ok;
 }
 
 bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,