[RISCV] Slightly simplify coode in combineVWADD_W_VL_VWSUB_W_VL and combineMUL_VLToVW...
authorCraig Topper <craig.topper@sifive.com>
Wed, 31 Aug 2022 21:10:02 +0000 (14:10 -0700)
committerCraig Topper <craig.topper@sifive.com>
Wed, 31 Aug 2022 22:02:03 +0000 (15:02 -0700)
Use computeMaxSignificantBits instead of ComputeNumSignBits. Create
APInt as part of call to MaskedValueIsZero instead of creating
a named temporary.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

index e306622..32e794f 100644 (file)
@@ -8744,11 +8744,11 @@ static SDValue combineVWADD_W_VL_VWSUB_W_VL(SDNode *N, SelectionDAG &DAG) {
       return SDValue();
 
     if (IsSigned) {
-      if (DAG.ComputeNumSignBits(Op0) <= (ScalarBits - NarrowSize))
+      if (DAG.ComputeMaxSignificantBits(Op0) > NarrowSize)
         return SDValue();
     } else {
-      APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
-      if (!DAG.MaskedValueIsZero(Op0, Mask))
+      if (!DAG.MaskedValueIsZero(Op0,
+                                 APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
         return SDValue();
     }
 
@@ -8826,16 +8826,15 @@ static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
       return SDValue();
 
     // If the LHS is a sign extend, try to use vwmul.
-    if (IsSignExt && DAG.ComputeNumSignBits(Op1) > (ScalarBits - NarrowSize)) {
+    if (IsSignExt && DAG.ComputeMaxSignificantBits(Op1) <= NarrowSize) {
       // Can use vwmul.
-    } else {
-      // Otherwise try to use vwmulu or vwmulsu.
-      APInt Mask = APInt::getBitsSetFrom(ScalarBits, NarrowSize);
-      if (DAG.MaskedValueIsZero(Op1, Mask))
-        IsVWMULSU = IsSignExt;
-      else
-        return SDValue();
-    }
+    } else if (DAG.MaskedValueIsZero(
+                   Op1, APInt::getBitsSetFrom(ScalarBits, NarrowSize))) {
+      // Scalar is zero extended, if the vector is sign extended we can use
+      // vwmulsu. If the vector is zero extended we can use vwmulu.
+      IsVWMULSU = IsSignExt;
+    } else
+      return SDValue();
 
     Op1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                       DAG.getUNDEF(NarrowVT), Op1, VL);