[DAG] Pull out getTruncatedUSUBSAT helper from foldSubToUSubSat. NFCI.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 17 Feb 2021 12:17:08 +0000 (12:17 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 17 Feb 2021 12:17:08 +0000 (12:17 +0000)
This will simplify an incoming generic implementation of D25987.

I'll rebase D96703 shortly to support this.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

index 0d41a28..6a04ba7 100644 (file)
@@ -3130,6 +3130,34 @@ SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
   return SDValue();
 }
 
+// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
+// clamp/truncation if necessary.
+static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
+                                   SDValue RHS, SelectionDAG &DAG,
+                                   const SDLoc &DL) {
+  assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
+         "Illegal truncation");
+
+  if (DstVT == SrcVT)
+    return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
+
+  // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
+  // clamping RHS.
+  APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
+                                          DstVT.getScalarSizeInBits());
+  if (!DAG.MaskedValueIsZero(LHS, UpperBits))
+    return SDValue();
+
+  SDValue SatLimit =
+      DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
+                                           DstVT.getScalarSizeInBits()),
+                      DL, SrcVT);
+  RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
+  RHS = DAG.getZExtOrTrunc(RHS, DL, DstVT);
+  LHS = DAG.getZExtOrTrunc(LHS, DL, DstVT);
+  return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
+}
+
 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
 // usubsat(a,b), optionally as a truncated type.
 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
@@ -3140,30 +3168,6 @@ SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
   EVT SubVT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
-  assert(DstVT.getScalarSizeInBits() <= SubVT.getScalarSizeInBits() &&
-         "Illegal truncation");
-
-  auto TruncatedUSUBSAT = [&](SDValue LHS, SDValue RHS) {
-    SDLoc DL(N);
-    if (DstVT == SubVT)
-      return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
-
-    // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
-    // clamping RHS.
-    APInt UpperBits = APInt::getBitsSetFrom(SubVT.getScalarSizeInBits(),
-                                            DstVT.getScalarSizeInBits());
-    if (!DAG.MaskedValueIsZero(LHS, UpperBits))
-      return SDValue();
-
-    SDValue SatLimit =
-        DAG.getConstant(APInt::getLowBitsSet(SubVT.getScalarSizeInBits(),
-                                             DstVT.getScalarSizeInBits()),
-                        DL, SubVT);
-    RHS = DAG.getNode(ISD::UMIN, DL, SubVT, RHS, SatLimit);
-    RHS = DAG.getZExtOrTrunc(RHS, DL, DstVT);
-    LHS = DAG.getZExtOrTrunc(LHS, DL, DstVT);
-    return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
-  };
 
   // Try to find umax(a,b) - b or a - umin(a,b) patterns
   // they may be converted to usubsat(a,b).
@@ -3171,18 +3175,18 @@ SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
     SDValue MaxLHS = Op0.getOperand(0);
     SDValue MaxRHS = Op0.getOperand(1);
     if (MaxLHS == Op1)
-      return TruncatedUSUBSAT(MaxRHS, Op1);
+      return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
     if (MaxRHS == Op1)
-      return TruncatedUSUBSAT(MaxLHS, Op1);
+      return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
   }
 
   if (Op1.getOpcode() == ISD::UMIN) {
     SDValue MinLHS = Op1.getOperand(0);
     SDValue MinRHS = Op1.getOperand(1);
     if (MinLHS == Op0)
-      return TruncatedUSUBSAT(Op0, MinRHS);
+      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
     if (MinRHS == Op0)
-      return TruncatedUSUBSAT(Op0, MinLHS);
+      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
   }
 
   return SDValue();