[NVPTX] Implement fma and imad contraction as target DAGCombiner patterns
authorJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:37 +0000 (18:35 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:37 +0000 (18:35 +0000)
This also introduces DAGCombiner patterns for mul.wide to multiply two smaller integers and produce a larger integer

llvm-svn: 211935

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/test/CodeGen/NVPTX/imad.ll [new file with mode: 0644]
llvm/test/CodeGen/NVPTX/mulwide.ll [new file with mode: 0644]

index 1ea47fc..21e4ba5 100644 (file)
@@ -24,11 +24,14 @@ using namespace llvm;
 
 #define DEBUG_TYPE "nvptx-isel"
 
-static cl::opt<int>
-FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden,
-                 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
-                          " 1: do it  2: do it aggressively"),
-                 cl::init(2));
+unsigned FMAContractLevel = 0;
+
+static cl::opt<unsigned, true>
+FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden,
+                    cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
+                             " 1: do it  2: do it aggressively"),
+                    cl::location(FMAContractLevel),
+                    cl::init(2));
 
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden,
index 60da5f1..09e5a61 100644 (file)
@@ -243,6 +243,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
 
+  // We have some custom DAG combine patterns for these nodes
+  setTargetDAGCombine(ISD::ADD);
+  setTargetDAGCombine(ISD::AND);
+  setTargetDAGCombine(ISD::FADD);
+  setTargetDAGCombine(ISD::MUL);
+  setTargetDAGCombine(ISD::SHL);
+
   // Now deduce the information based on the above mentioned
   // actions
   computeRegisterProperties();
@@ -334,6 +341,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     return "NVPTXISD::StoreV2";
   case NVPTXISD::StoreV4:
     return "NVPTXISD::StoreV4";
+  case NVPTXISD::FUN_SHFL_CLAMP:
+    return "NVPTXISD::FUN_SHFL_CLAMP";
+  case NVPTXISD::FUN_SHFR_CLAMP:
+    return "NVPTXISD::FUN_SHFR_CLAMP";
   case NVPTXISD::Tex1DFloatI32:        return "NVPTXISD::Tex1DFloatI32";
   case NVPTXISD::Tex1DFloatFloat:      return "NVPTXISD::Tex1DFloatFloat";
   case NVPTXISD::Tex1DFloatFloatLevel:
@@ -2475,6 +2486,406 @@ unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
   return 4;
 }
 
+//===----------------------------------------------------------------------===//
+//                         NVPTX DAG Combining
+//===----------------------------------------------------------------------===//
+
+extern unsigned FMAContractLevel;
+
+/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
+/// operands N0 and N1.  This is a helper for PerformADDCombine that is
+/// called with the default operands, and if that fails, with commuted
+/// operands.
+static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                             const NVPTXSubtarget &Subtarget,
+                                             CodeGenOpt::Level OptLevel) {
+  SelectionDAG  &DAG = DCI.DAG;
+  // Skip non-integer, non-scalar case
+  EVT VT=N0.getValueType();
+  if (VT.isVector())
+    return SDValue();
+
+  // fold (add (mul a, b), c) -> (mad a, b, c)
+  //
+  if (N0.getOpcode() == ISD::MUL) {
+    assert (VT.isInteger());
+    // For integer:
+    // Since integer multiply-add costs the same as integer multiply
+    // but is more costly than integer add, do the fusion only when
+    // the mul is only used in the add.
+    if (OptLevel==CodeGenOpt::None || VT != MVT::i32 ||
+        !N0.getNode()->hasOneUse())
+      return SDValue();
+
+    // Do the folding
+    return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
+                       N0.getOperand(0), N0.getOperand(1), N1);
+  }
+  else if (N0.getOpcode() == ISD::FMUL) {
+    if (VT == MVT::f32 || VT == MVT::f64) {
+      if (FMAContractLevel == 0)
+        return SDValue();
+
+      // For floating point:
+      // Do the fusion only when the mul has less than 5 uses and all
+      // are add.
+      // The heuristic is that if a use is not an add, then that use
+      // cannot be fused into fma, therefore mul is still needed anyway.
+      // If there are more than 4 uses, even if they are all add, fusing
+      // them will increase register pressue.
+      //
+      int numUses = 0;
+      int nonAddCount = 0;
+      for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
+           UE = N0.getNode()->use_end();
+           UI != UE; ++UI) {
+        numUses++;
+        SDNode *User = *UI;
+        if (User->getOpcode() != ISD::FADD)
+          ++nonAddCount;
+      }
+      if (numUses >= 5)
+        return SDValue();
+      if (nonAddCount) {
+        int orderNo = N->getIROrder();
+        int orderNo2 = N0.getNode()->getIROrder();
+        // simple heuristics here for considering potential register
+        // pressure, the logics here is that the differnce are used
+        // to measure the distance between def and use, the longer distance
+        // more likely cause register pressure.
+        if (orderNo - orderNo2 < 500)
+          return SDValue();
+
+        // Now, check if at least one of the FMUL's operands is live beyond the node N,
+        // which guarantees that the FMA will not increase register pressure at node N.
+        bool opIsLive = false;
+        const SDNode *left = N0.getOperand(0).getNode();
+        const SDNode *right = N0.getOperand(1).getNode();
+
+        if (dyn_cast<ConstantSDNode>(left) || dyn_cast<ConstantSDNode>(right))
+          opIsLive = true;
+
+        if (!opIsLive)
+          for (SDNode::use_iterator UI = left->use_begin(), UE = left->use_end(); UI != UE; ++UI) {
+            SDNode *User = *UI;
+            int orderNo3 = User->getIROrder();
+            if (orderNo3 > orderNo) {
+              opIsLive = true;
+              break;
+            }
+          }
+
+        if (!opIsLive)
+          for (SDNode::use_iterator UI = right->use_begin(), UE = right->use_end(); UI != UE; ++UI) {
+            SDNode *User = *UI;
+            int orderNo3 = User->getIROrder();
+            if (orderNo3 > orderNo) {
+              opIsLive = true;
+              break;
+            }
+          }
+
+        if (!opIsLive)
+          return SDValue();
+      }
+
+      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
+                         N0.getOperand(0), N0.getOperand(1), N1);
+    }
+  }
+
+  return SDValue();
+}
+
+/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
+///
+static SDValue PerformADDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const NVPTXSubtarget &Subtarget,
+                                 CodeGenOpt::Level OptLevel) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // First try with the default operand order.
+  SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget,
+                                                 OptLevel);
+  if (Result.getNode())
+    return Result;
+
+  // If that didn't work, try again with the operands commuted.
+  return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
+}
+
+static SDValue PerformANDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // The type legalizer turns a vector load of i8 values into a zextload to i16
+  // registers, optionally ANY_EXTENDs it (if target type is integer),
+  // and ANDs off the high 8 bits. Since we turn this load into a
+  // target-specific DAG node, the DAG combiner fails to eliminate these AND
+  // nodes. Do that here.
+  SDValue Val = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+
+  if (isa<ConstantSDNode>(Val)) {
+    std::swap(Val, Mask);
+  }
+
+  SDValue AExt;
+  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
+  if (Val.getOpcode() == ISD::ANY_EXTEND) {
+    AExt = Val;
+    Val = Val->getOperand(0);
+  }
+
+  if (Val->isMachineOpcode() && Val->getMachineOpcode() == NVPTX::IMOV16rr) {
+    Val = Val->getOperand(0);
+  }
+
+  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
+      Val->getOpcode() == NVPTXISD::LoadV4) {
+    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+    if (!MaskCnst) {
+      // Not an AND with a constant
+      return SDValue();
+    }
+
+    uint64_t MaskVal = MaskCnst->getZExtValue();
+    if (MaskVal != 0xff) {
+      // Not an AND that chops off top 8 bits
+      return SDValue();
+    }
+
+    MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+    if (!Mem) {
+      // Not a MemSDNode?!?
+      return SDValue();
+    }
+
+    EVT MemVT = Mem->getMemoryVT();
+    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+      // We only handle the i8 case
+      return SDValue();
+    }
+
+    unsigned ExtType =
+      cast<ConstantSDNode>(Val->getOperand(Val->getNumOperands()-1))->
+        getZExtValue();
+    if (ExtType == ISD::SEXTLOAD) {
+      // If for some reason the load is a sextload, the and is needed to zero
+      // out the high 8 bits
+      return SDValue();
+    }
+
+    bool AddTo = false;
+    if (AExt.getNode() != 0) {
+      // Re-insert the ext as a zext.
+      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
+                            AExt.getValueType(), Val);
+      AddTo = true;
+    }
+
+    // If we get here, the AND is unnecessary.  Just replace it with the load
+    DCI.CombineTo(N, Val, AddTo);
+  }
+
+  return SDValue();
+}
+
+enum OperandSignedness {
+  Signed = 0,
+  Unsigned,
+  Unknown
+};
+
+/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
+/// that can be demoted to \p OptSize bits without loss of information. The
+/// signedness of the operand, if determinable, is placed in \p S.
+static bool IsMulWideOperandDemotable(SDValue Op,
+                                      unsigned OptSize,
+                                      OperandSignedness &S) {
+  S = Unknown;
+
+  if (Op.getOpcode() == ISD::SIGN_EXTEND ||
+      Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    EVT OrigVT = Op.getOperand(0).getValueType();
+    if (OrigVT.getSizeInBits() == OptSize) {
+      S = Signed;
+      return true;
+    }
+  } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
+    EVT OrigVT = Op.getOperand(0).getValueType();
+    if (OrigVT.getSizeInBits() == OptSize) {
+      S = Unsigned;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
+/// be demoted to \p OptSize bits without loss of information. If the operands
+/// contain a constant, it should appear as the RHS operand. The signedness of
+/// the operands is placed in \p IsSigned.
+static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
+                                        unsigned OptSize,
+                                        bool &IsSigned) {
+
+  OperandSignedness LHSSign;
+
+  // The LHS operand must be a demotable op
+  if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
+    return false;
+
+  // We should have been able to determine the signedness from the LHS
+  if (LHSSign == Unknown)
+    return false;
+
+  IsSigned = (LHSSign == Signed);
+
+  // The RHS can be a demotable op or a constant
+  if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
+    APInt Val = CI->getAPIntValue();
+    if (LHSSign == Unsigned) {
+      if (Val.isIntN(OptSize)) {
+        return true;
+      }
+      return false;
+    } else {
+      if (Val.isSignedIntN(OptSize)) {
+        return true;
+      }
+      return false;
+    }
+  } else {
+    OperandSignedness RHSSign;
+    if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
+      return false;
+
+    if (LHSSign != RHSSign)
+      return false;
+
+    return true;
+  }
+}
+
+/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
+/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
+/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
+/// amount.
+static SDValue TryMULWIDECombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  EVT MulType = N->getValueType(0);
+  if (MulType != MVT::i32 && MulType != MVT::i64) {
+    return SDValue();
+  }
+
+  unsigned OptSize = MulType.getSizeInBits() >> 1;
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // Canonicalize the multiply so the constant (if any) is on the right
+  if (N->getOpcode() == ISD::MUL) {
+    if (isa<ConstantSDNode>(LHS)) {
+      std::swap(LHS, RHS);
+    }
+  }
+
+  // If we have a SHL, determine the actual multiply amount
+  if (N->getOpcode() == ISD::SHL) {
+    ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
+    if (!ShlRHS) {
+      return SDValue();
+    }
+
+    APInt ShiftAmt = ShlRHS->getAPIntValue();
+    unsigned BitWidth = MulType.getSizeInBits();
+    if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
+      APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
+      RHS = DCI.DAG.getConstant(MulVal, MulType);
+    } else {
+      return SDValue();
+    }
+  }
+
+  bool Signed;
+  // Verify that our operands are demotable
+  if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
+    return SDValue();
+  }
+
+  EVT DemotedVT;
+  if (MulType == MVT::i32) {
+    DemotedVT = MVT::i16;
+  } else {
+    DemotedVT = MVT::i32;
+  }
+
+  // Truncate the operands to the correct size. Note that these are just for
+  // type consistency and will (likely) be eliminated in later phases.
+  SDValue TruncLHS =
+    DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, LHS);
+  SDValue TruncRHS =
+    DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, RHS);
+
+  unsigned Opc;
+  if (Signed) {
+    Opc = NVPTXISD::MUL_WIDE_SIGNED;
+  } else {
+    Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
+  }
+
+  return DCI.DAG.getNode(Opc, SDLoc(N), MulType, TruncLHS, TruncRHS);
+}
+
+/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
+static SDValue PerformMULCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 CodeGenOpt::Level OptLevel) {
+  if (OptLevel > 0) {
+    // Try mul.wide combining at OptLevel > 0
+    SDValue Ret = TryMULWIDECombine(N, DCI);
+    if (Ret.getNode())
+      return Ret;
+  }
+
+  return SDValue();
+}
+
+/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
+static SDValue PerformSHLCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 CodeGenOpt::Level OptLevel) {
+  if (OptLevel > 0) {
+    // Try mul.wide combining at OptLevel > 0
+    SDValue Ret = TryMULWIDECombine(N, DCI);
+    if (Ret.getNode())
+      return Ret;
+  }
+
+  return SDValue();
+}
+
+SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
+                                               DAGCombinerInfo &DCI) const {
+  // FIXME: Get this from the DAG somehow
+  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
+  switch (N->getOpcode()) {
+    default: break;
+    case ISD::ADD:
+    case ISD::FADD:
+      return PerformADDCombine(N, DCI, nvptxSubtarget, OptLevel);
+    case ISD::MUL:
+      return PerformMULCombine(N, DCI, OptLevel);
+    case ISD::SHL:
+      return PerformSHLCombine(N, DCI, OptLevel);
+    case ISD::AND:
+      return PerformANDCombine(N, DCI);
+  }
+  return SDValue();
+}
+
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results) {
index 4f5e025..e909e0a 100644 (file)
@@ -49,6 +49,9 @@ enum NodeType {
   CallSeqBegin,
   CallSeqEnd,
   CallPrototype,
+  MUL_WIDE_SIGNED,
+  MUL_WIDE_UNSIGNED,
+  IMAD,
   Dummy,
 
   LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -258,6 +261,7 @@ private:
 
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
+  SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
 
   unsigned getArgumentAlignment(SDValue Callee, const ImmutableCallSite *CS,
                                 Type *Ty, unsigned Idx) const;
index 725d6fc..99309f9 100644 (file)
@@ -464,33 +464,45 @@ def SHL2MUL16 : SDNodeXForm<imm, [{
   return CurDAG->getTargetConstant(temp.shl(v), MVT::i16);
 }]>;
 
-def MULWIDES64 : NVPTXInst<(outs Int64Regs:$dst),
-                           (ins Int32Regs:$a, Int32Regs:$b),
+def MULWIDES64
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
+              "mul.wide.s32 \t$dst, $a, $b;", []>;
+def MULWIDES64Imm
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
                            "mul.wide.s32 \t$dst, $a, $b;", []>;
-def MULWIDES64Imm : NVPTXInst<(outs Int64Regs:$dst),
-                            (ins Int32Regs:$a, i64imm:$b),
+def MULWIDES64Imm64
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
                            "mul.wide.s32 \t$dst, $a, $b;", []>;
 
-def MULWIDEU64 : NVPTXInst<(outs Int64Regs:$dst),
-                           (ins Int32Regs:$a, Int32Regs:$b),
+def MULWIDEU64
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
+              "mul.wide.u32 \t$dst, $a, $b;", []>;
+def MULWIDEU64Imm
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
                            "mul.wide.u32 \t$dst, $a, $b;", []>;
-def MULWIDEU64Imm : NVPTXInst<(outs Int64Regs:$dst),
-                            (ins Int32Regs:$a, i64imm:$b),
+def MULWIDEU64Imm64
+  : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
                            "mul.wide.u32 \t$dst, $a, $b;", []>;
 
-def MULWIDES32 : NVPTXInst<(outs Int32Regs:$dst),
-                            (ins Int16Regs:$a, Int16Regs:$b),
+def MULWIDES32
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
                            "mul.wide.s16 \t$dst, $a, $b;", []>;
-def MULWIDES32Imm : NVPTXInst<(outs Int32Regs:$dst),
-                            (ins Int16Regs:$a, i32imm:$b),
+def MULWIDES32Imm
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
+              "mul.wide.s16 \t$dst, $a, $b;", []>;
+def MULWIDES32Imm32
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
                            "mul.wide.s16 \t$dst, $a, $b;", []>;
 
-def MULWIDEU32 : NVPTXInst<(outs Int32Regs:$dst),
-                            (ins Int16Regs:$a, Int16Regs:$b),
-                           "mul.wide.u16 \t$dst, $a, $b;", []>;
-def MULWIDEU32Imm : NVPTXInst<(outs Int32Regs:$dst),
-                            (ins Int16Regs:$a, i32imm:$b),
+def MULWIDEU32
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
+              "mul.wide.u16 \t$dst, $a, $b;", []>;
+def MULWIDEU32Imm
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
                            "mul.wide.u16 \t$dst, $a, $b;", []>;
+def MULWIDEU32Imm32
+  : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
+                            "mul.wide.u16 \t$dst, $a, $b;", []>;
 
 def : Pat<(shl (sext Int32Regs:$a), (i32 Int5Const:$b)),
           (MULWIDES64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>,
@@ -510,25 +522,63 @@ def : Pat<(mul (sext Int32Regs:$a), (sext Int32Regs:$b)),
           (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>,
           Requires<[doMulWide]>;
 def : Pat<(mul (sext Int32Regs:$a), (i64 SInt32Const:$b)),
-          (MULWIDES64Imm Int32Regs:$a, (i64 SInt32Const:$b))>,
+          (MULWIDES64Imm64 Int32Regs:$a, (i64 SInt32Const:$b))>,
           Requires<[doMulWide]>;
 
 def : Pat<(mul (zext Int32Regs:$a), (zext Int32Regs:$b)),
-          (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>;
+          (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>,
+      Requires<[doMulWide]>;
 def : Pat<(mul (zext Int32Regs:$a), (i64 UInt32Const:$b)),
-          (MULWIDEU64Imm Int32Regs:$a, (i64 UInt32Const:$b))>,
+          (MULWIDEU64Imm64 Int32Regs:$a, (i64 UInt32Const:$b))>,
           Requires<[doMulWide]>;
 
 def : Pat<(mul (sext Int16Regs:$a), (sext Int16Regs:$b)),
-          (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
+          (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>,
+      Requires<[doMulWide]>;
 def : Pat<(mul (sext Int16Regs:$a), (i32 SInt16Const:$b)),
-          (MULWIDES32Imm Int16Regs:$a, (i32 SInt16Const:$b))>,
+          (MULWIDES32Imm32 Int16Regs:$a, (i32 SInt16Const:$b))>,
           Requires<[doMulWide]>;
 
 def : Pat<(mul (zext Int16Regs:$a), (zext Int16Regs:$b)),
-          (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
+          (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>,
+      Requires<[doMulWide]>;
 def : Pat<(mul (zext Int16Regs:$a), (i32 UInt16Const:$b)),
-          (MULWIDEU32Imm Int16Regs:$a, (i32 UInt16Const:$b))>,
+          (MULWIDEU32Imm32 Int16Regs:$a, (i32 UInt16Const:$b))>,
+          Requires<[doMulWide]>;
+
+
+def SDTMulWide
+  : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
+def mul_wide_signed
+  : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
+def mul_wide_unsigned
+  : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
+
+def : Pat<(i32 (mul_wide_signed Int16Regs:$a, Int16Regs:$b)),
+          (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>,
+      Requires<[doMulWide]>;
+def : Pat<(i32 (mul_wide_signed Int16Regs:$a, imm:$b)),
+          (MULWIDES32Imm Int16Regs:$a, imm:$b)>,
+          Requires<[doMulWide]>;
+def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, Int16Regs:$b)),
+          (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>,
+          Requires<[doMulWide]>;
+def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, imm:$b)),
+          (MULWIDEU32Imm Int16Regs:$a, imm:$b)>,
+          Requires<[doMulWide]>;
+
+
+def : Pat<(i64 (mul_wide_signed Int32Regs:$a, Int32Regs:$b)),
+          (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>,
+          Requires<[doMulWide]>;
+def : Pat<(i64 (mul_wide_signed Int32Regs:$a, imm:$b)),
+          (MULWIDES64Imm Int32Regs:$a, imm:$b)>,
+          Requires<[doMulWide]>;
+def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, Int32Regs:$b)),
+          (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>,
+          Requires<[doMulWide]>;
+def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, imm:$b)),
+          (MULWIDEU64Imm Int32Regs:$a, imm:$b)>,
           Requires<[doMulWide]>;
 
 defm MULT : I3<"mul.lo.s", mul>;
@@ -544,69 +594,75 @@ defm SREM : I3<"rem.s", srem>;
 defm UREM : I3<"rem.u", urem>;
 // The ri version will not be selected as DAGCombiner::visitUREM will lower it.
 
+def SDTIMAD
+  : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
+                         SDTCisInt<2>, SDTCisSameAs<0, 2>,
+                         SDTCisSameAs<0, 3>]>;
+def imad
+  : SDNode<"NVPTXISD::IMAD", SDTIMAD>;
+
 def MAD16rrr : NVPTXInst<(outs Int16Regs:$dst),
                       (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
                       "mad.lo.s16 \t$dst, $a, $b, $c;",
-                      [(set Int16Regs:$dst, (add
-                        (mul Int16Regs:$a, Int16Regs:$b), Int16Regs:$c))]>;
+                      [(set Int16Regs:$dst,
+                         (imad Int16Regs:$a, Int16Regs:$b, Int16Regs:$c))]>;
 def MAD16rri : NVPTXInst<(outs Int16Regs:$dst),
                       (ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
                       "mad.lo.s16 \t$dst, $a, $b, $c;",
-                      [(set Int16Regs:$dst, (add
-                        (mul Int16Regs:$a, Int16Regs:$b), imm:$c))]>;
+                      [(set Int16Regs:$dst,
+                         (imad Int16Regs:$a, Int16Regs:$b, imm:$c))]>;
 def MAD16rir : NVPTXInst<(outs Int16Regs:$dst),
                       (ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
                       "mad.lo.s16 \t$dst, $a, $b, $c;",
-                      [(set Int16Regs:$dst, (add
-                        (mul Int16Regs:$a, imm:$b), Int16Regs:$c))]>;
+                      [(set Int16Regs:$dst,
+                        (imad Int16Regs:$a, imm:$b, Int16Regs:$c))]>;
 def MAD16rii : NVPTXInst<(outs Int16Regs:$dst),
     (ins Int16Regs:$a, i16imm:$b, i16imm:$c),
                       "mad.lo.s16 \t$dst, $a, $b, $c;",
-                      [(set Int16Regs:$dst, (add (mul Int16Regs:$a, imm:$b),
-                        imm:$c))]>;
+                      [(set Int16Regs:$dst,
+                        (imad Int16Regs:$a, imm:$b, imm:$c))]>;
 
 def MAD32rrr : NVPTXInst<(outs Int32Regs:$dst),
                       (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
                       "mad.lo.s32 \t$dst, $a, $b, $c;",
-                      [(set Int32Regs:$dst, (add
-                        (mul Int32Regs:$a, Int32Regs:$b), Int32Regs:$c))]>;
+                      [(set Int32Regs:$dst,
+                        (imad Int32Regs:$a, Int32Regs:$b, Int32Regs:$c))]>;
 def MAD32rri : NVPTXInst<(outs Int32Regs:$dst),
                       (ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
                       "mad.lo.s32 \t$dst, $a, $b, $c;",
-                      [(set Int32Regs:$dst, (add
-                        (mul Int32Regs:$a, Int32Regs:$b), imm:$c))]>;
+                      [(set Int32Regs:$dst,
+                        (imad Int32Regs:$a, Int32Regs:$b, imm:$c))]>;
 def MAD32rir : NVPTXInst<(outs Int32Regs:$dst),
                       (ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
                       "mad.lo.s32 \t$dst, $a, $b, $c;",
-                      [(set Int32Regs:$dst, (add
-                        (mul Int32Regs:$a, imm:$b), Int32Regs:$c))]>;
+                      [(set Int32Regs:$dst,
+                        (imad Int32Regs:$a, imm:$b, Int32Regs:$c))]>;
 def MAD32rii : NVPTXInst<(outs Int32Regs:$dst),
                       (ins Int32Regs:$a, i32imm:$b, i32imm:$c),
                       "mad.lo.s32 \t$dst, $a, $b, $c;",
-                      [(set Int32Regs:$dst, (add
-                        (mul Int32Regs:$a, imm:$b), imm:$c))]>;
+                      [(set Int32Regs:$dst,
+                        (imad Int32Regs:$a, imm:$b, imm:$c))]>;
 
 def MAD64rrr : NVPTXInst<(outs Int64Regs:$dst),
                       (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
                       "mad.lo.s64 \t$dst, $a, $b, $c;",
-                      [(set Int64Regs:$dst, (add
-                        (mul Int64Regs:$a, Int64Regs:$b), Int64Regs:$c))]>;
+                      [(set Int64Regs:$dst,
+                        (imad Int64Regs:$a, Int64Regs:$b, Int64Regs:$c))]>;
 def MAD64rri : NVPTXInst<(outs Int64Regs:$dst),
                       (ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
                       "mad.lo.s64 \t$dst, $a, $b, $c;",
-                      [(set Int64Regs:$dst, (add
-                        (mul Int64Regs:$a, Int64Regs:$b), imm:$c))]>;
+                      [(set Int64Regs:$dst,
+                        (imad Int64Regs:$a, Int64Regs:$b, imm:$c))]>;
 def MAD64rir : NVPTXInst<(outs Int64Regs:$dst),
                       (ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
                       "mad.lo.s64 \t$dst, $a, $b, $c;",
-                      [(set Int64Regs:$dst, (add
-                        (mul Int64Regs:$a, imm:$b), Int64Regs:$c))]>;
+                      [(set Int64Regs:$dst,
+                        (imad Int64Regs:$a, imm:$b, Int64Regs:$c))]>;
 def MAD64rii : NVPTXInst<(outs Int64Regs:$dst),
                       (ins Int64Regs:$a, i64imm:$b, i64imm:$c),
                       "mad.lo.s64 \t$dst, $a, $b, $c;",
-                      [(set Int64Regs:$dst, (add
-                        (mul Int64Regs:$a, imm:$b), imm:$c))]>;
-
+                      [(set Int64Regs:$dst,
+                        (imad Int64Regs:$a, imm:$b, imm:$c))]>;
 
 def INEG16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
                      "neg.s16 \t$dst, $src;",
@@ -812,36 +868,26 @@ multiclass FPCONTRACT32<string OpcStr, Predicate Pred> {
    def rrr : NVPTXInst<(outs Float32Regs:$dst),
                       (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float32Regs:$dst, (fadd
-                        (fmul Float32Regs:$a, Float32Regs:$b),
-                        Float32Regs:$c))]>, Requires<[Pred]>;
-   // This is to WAR a weird bug in Tablegen that does not automatically
-   // generate the following permutated rule rrr2 from the above rrr.
-   // So we explicitly add it here. This happens to FMA32 only.
-   // See the comments at FMAD32 and FMA32 for more information.
-   def rrr2 : NVPTXInst<(outs Float32Regs:$dst),
-                        (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
-                      !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float32Regs:$dst, (fadd Float32Regs:$c,
-                        (fmul Float32Regs:$a, Float32Regs:$b)))]>,
+                      [(set Float32Regs:$dst,
+                        (fma Float32Regs:$a, Float32Regs:$b, Float32Regs:$c))]>,
                       Requires<[Pred]>;
    def rri : NVPTXInst<(outs Float32Regs:$dst),
                       (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float32Regs:$dst, (fadd
-                        (fmul Float32Regs:$a, Float32Regs:$b), fpimm:$c))]>,
+                      [(set Float32Regs:$dst,
+                        (fma Float32Regs:$a, Float32Regs:$b, fpimm:$c))]>,
                       Requires<[Pred]>;
    def rir : NVPTXInst<(outs Float32Regs:$dst),
                       (ins Float32Regs:$a, f32imm:$b, Float32Regs:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float32Regs:$dst, (fadd
-                        (fmul Float32Regs:$a, fpimm:$b), Float32Regs:$c))]>,
+                      [(set Float32Regs:$dst,
+                        (fma Float32Regs:$a, fpimm:$b, Float32Regs:$c))]>,
                       Requires<[Pred]>;
    def rii : NVPTXInst<(outs Float32Regs:$dst),
                       (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float32Regs:$dst, (fadd
-                        (fmul Float32Regs:$a, fpimm:$b), fpimm:$c))]>,
+                      [(set Float32Regs:$dst,
+                        (fma Float32Regs:$a, fpimm:$b, fpimm:$c))]>,
                       Requires<[Pred]>;
 }
 
@@ -849,73 +895,32 @@ multiclass FPCONTRACT64<string OpcStr, Predicate Pred> {
    def rrr : NVPTXInst<(outs Float64Regs:$dst),
                       (ins Float64Regs:$a, Float64Regs:$b, Float64Regs:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float64Regs:$dst, (fadd
-                        (fmul Float64Regs:$a, Float64Regs:$b),
-                        Float64Regs:$c))]>, Requires<[Pred]>;
+                      [(set Float64Regs:$dst,
+                        (fma Float64Regs:$a, Float64Regs:$b, Float64Regs:$c))]>,
+                      Requires<[Pred]>;
    def rri : NVPTXInst<(outs Float64Regs:$dst),
                       (ins Float64Regs:$a, Float64Regs:$b, f64imm:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float64Regs:$dst, (fadd (fmul Float64Regs:$a,
-                        Float64Regs:$b), fpimm:$c))]>, Requires<[Pred]>;
+                      [(set Float64Regs:$dst,
+                        (fma Float64Regs:$a, Float64Regs:$b, fpimm:$c))]>,
+                      Requires<[Pred]>;
    def rir : NVPTXInst<(outs Float64Regs:$dst),
                       (ins Float64Regs:$a, f64imm:$b, Float64Regs:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float64Regs:$dst, (fadd
-                        (fmul Float64Regs:$a, fpimm:$b), Float64Regs:$c))]>,
+                      [(set Float64Regs:$dst,
+                        (fma Float64Regs:$a, fpimm:$b, Float64Regs:$c))]>,
                       Requires<[Pred]>;
    def rii : NVPTXInst<(outs Float64Regs:$dst),
                       (ins Float64Regs:$a, f64imm:$b, f64imm:$c),
                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
-                      [(set Float64Regs:$dst, (fadd
-                        (fmul Float64Regs:$a, fpimm:$b), fpimm:$c))]>,
+                      [(set Float64Regs:$dst,
+                        (fma Float64Regs:$a, fpimm:$b, fpimm:$c))]>,
                       Requires<[Pred]>;
 }
 
-// Due to a unknown reason (most likely a bug in tablegen), tablegen does not
-// automatically generate the rrr2 rule from
-// the rrr rule (see FPCONTRACT32) for FMA32, though it does for FMAD32.
-// If we reverse the order of the following two lines, then rrr2 rule will be
-// generated for FMA32, but not for rrr.
-// Therefore, we manually write the rrr2 rule in FPCONTRACT32.
-defm FMA32_ftz  : FPCONTRACT32<"fma.rn.ftz.f32", doFMAF32_ftz>;
-defm FMA32  : FPCONTRACT32<"fma.rn.f32", doFMAF32>;
-defm FMA64  : FPCONTRACT64<"fma.rn.f64", doFMAF64>;
-
-// b*c-a => fmad(b, c, -a)
-multiclass FPCONTRACT32_SUB_PAT_MAD<NVPTXInst Inst, Predicate Pred> {
-  def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
-          (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
-          Requires<[Pred]>;
-}
-
-// a-b*c => fmad(-b,c, a)
-// - legal because a-b*c <=> a+(-b*c) <=> a+(-b)*c
-// b*c-a => fmad(b, c, -a)
-// - legal because b*c-a <=> b*c+(-a)
-multiclass FPCONTRACT32_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
-  def : Pat<(fsub Float32Regs:$a, (fmul Float32Regs:$b, Float32Regs:$c)),
-          (Inst (FNEGf32 Float32Regs:$b), Float32Regs:$c, Float32Regs:$a)>,
-          Requires<[Pred]>;
-  def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
-          (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
-          Requires<[Pred]>;
-}
-
-// a-b*c => fmad(-b,c, a)
-// b*c-a => fmad(b, c, -a)
-multiclass FPCONTRACT64_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
-  def : Pat<(fsub Float64Regs:$a, (fmul Float64Regs:$b, Float64Regs:$c)),
-          (Inst (FNEGf64 Float64Regs:$b), Float64Regs:$c, Float64Regs:$a)>,
-          Requires<[Pred]>;
-
-  def : Pat<(fsub (fmul Float64Regs:$b, Float64Regs:$c), Float64Regs:$a),
-          (Inst Float64Regs:$b, Float64Regs:$c, (FNEGf64 Float64Regs:$a))>,
-          Requires<[Pred]>;
-}
-
-defm FMAF32ext_ftz  : FPCONTRACT32_SUB_PAT<FMA32_ftzrrr, doFMAF32AGG_ftz>;
-defm FMAF32ext  : FPCONTRACT32_SUB_PAT<FMA32rrr, doFMAF32AGG>;
-defm FMAF64ext  : FPCONTRACT64_SUB_PAT<FMA64rrr, doFMAF64AGG>;
+defm FMA32_ftz  : FPCONTRACT32<"fma.rn.ftz.f32", doF32FTZ>;
+defm FMA32  : FPCONTRACT32<"fma.rn.f32", doNoF32FTZ>;
+defm FMA64  : FPCONTRACT64<"fma.rn.f64", doNoF32FTZ>;
 
 def SINF:  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
                       "sin.approx.f32 \t$dst, $src;",
diff --git a/llvm/test/CodeGen/NVPTX/imad.ll b/llvm/test/CodeGen/NVPTX/imad.ll
new file mode 100644 (file)
index 0000000..67421c7
--- /dev/null
@@ -0,0 +1,9 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+; CHECK: imad
+define i32 @imad(i32 %a, i32 %b, i32 %c) {
+; CHECK: mad.lo.s32
+  %val0 = mul i32 %a, %b
+  %val1 = add i32 %val0, %c
+  ret i32 %val1
+}
diff --git a/llvm/test/CodeGen/NVPTX/mulwide.ll b/llvm/test/CodeGen/NVPTX/mulwide.ll
new file mode 100644 (file)
index 0000000..927946c
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+; CHECK: mulwide16
+define i32 @mulwide16(i16 %a, i16 %b) {
+; CHECK: mul.wide.s16
+  %val0 = sext i16 %a to i32
+  %val1 = sext i16 %b to i32
+  %val2 = mul i32 %val0, %val1
+  ret i32 %val2
+}
+
+; CHECK: mulwideu16
+define i32 @mulwideu16(i16 %a, i16 %b) {
+; CHECK: mul.wide.u16
+  %val0 = zext i16 %a to i32
+  %val1 = zext i16 %b to i32
+  %val2 = mul i32 %val0, %val1
+  ret i32 %val2
+}
+
+; CHECK: mulwide32
+define i64 @mulwide32(i32 %a, i32 %b) {
+; CHECK: mul.wide.s32
+  %val0 = sext i32 %a to i64
+  %val1 = sext i32 %b to i64
+  %val2 = mul i64 %val0, %val1
+  ret i64 %val2
+}
+
+; CHECK: mulwideu32
+define i64 @mulwideu32(i32 %a, i32 %b) {
+; CHECK: mul.wide.u32
+  %val0 = zext i32 %a to i64
+  %val1 = zext i32 %b to i64
+  %val2 = mul i64 %val0, %val1
+  ret i64 %val2
+}