[DAGCombiner][AArch64] Enhance to fold CSNEG into CSINC instruction
authorzhongyunde <zhongyunde@huawei.com>
Wed, 16 Feb 2022 01:35:15 +0000 (09:35 +0800)
committerguopeilin <guopeilin1@huawei.com>
Wed, 16 Feb 2022 01:39:38 +0000 (09:39 +0800)
Perform the scalar expression combine in the form of:
  CSNEG(1, c, cc) + b  =>  cc  ? b+1 : b-c => CSINC(b-c, b, !cc)
  CSNEG(c, -1, cc) + b =>  cc  ? b+c : b+1 => CSINC(b+c, b, cc)

Fix https://github.com/llvm/llvm-project/issues/53071

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D119105

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/aarch64-isel-csinc-type.ll
llvm/test/CodeGen/AArch64/aarch64-isel-csinc.ll

index 9890b3e..a371d3b 100644 (file)
@@ -14751,41 +14751,62 @@ static SDValue performAddUADDVCombine(SDNode *N, SelectionDAG &DAG) {
 }
 
 /// Perform the scalar expression combine in the form of:
-///   CSEL (c, 1, cc) + b => CSINC(b+c, b, cc)
+///   CSEL(c, 1, cc) + b => CSINC(b+c, b, cc)
+///   CSNEG(c, -1, cc) + b => CSINC(b+c, b, cc)
 static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD)
     return SDValue();
 
-  SDValue CSel = N->getOperand(0);
+  SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
   // Handle commutivity.
-  if (CSel.getOpcode() != AArch64ISD::CSEL) {
-    std::swap(CSel, RHS);
-    if (CSel.getOpcode() != AArch64ISD::CSEL) {
+  if (LHS.getOpcode() != AArch64ISD::CSEL &&
+      LHS.getOpcode() != AArch64ISD::CSNEG) {
+    std::swap(LHS, RHS);
+    if (LHS.getOpcode() != AArch64ISD::CSEL &&
+        LHS.getOpcode() != AArch64ISD::CSNEG) {
       return SDValue();
     }
   }
 
-  if (!CSel.hasOneUse())
+  if (!LHS.hasOneUse())
     return SDValue();
 
   AArch64CC::CondCode AArch64CC =
-      static_cast<AArch64CC::CondCode>(CSel.getConstantOperandVal(2));
+      static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
 
-  // The CSEL should include a const one operand.
-  ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(CSel.getOperand(0));
-  ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(CSel.getOperand(1));
-  if (!CTVal || !CFVal || (!CTVal->isOne() && !CFVal->isOne()))
+  // The CSEL should include a const one operand, and the CSNEG should include
+  // One or NegOne operand.
+  ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(LHS.getOperand(0));
+  ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(LHS.getOperand(1));
+  if (!CTVal || !CFVal)
     return SDValue();
 
-  // switch CSEL (1, c, cc)  to CSEL (c, 1, !cc)
-  if (CTVal->isOne() && !CFVal->isOne()) {
+  if (!(LHS.getOpcode() == AArch64ISD::CSEL &&
+        (CTVal->isOne() || CFVal->isOne())) &&
+      !(LHS.getOpcode() == AArch64ISD::CSNEG &&
+        (CTVal->isOne() || CFVal->isAllOnes())))
+    return SDValue();
+
+  // Switch CSEL(1, c, cc) to CSEL(c, 1, !cc)
+  if (LHS.getOpcode() == AArch64ISD::CSEL && CTVal->isOne() &&
+      !CFVal->isOne()) {
     std::swap(CTVal, CFVal);
     AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
   }
 
+  SDLoc DL(N);
+  // Switch CSNEG(1, c, cc) to CSNEG(-c, -1, !cc)
+  if (LHS.getOpcode() == AArch64ISD::CSNEG && CTVal->isOne() &&
+      !CFVal->isAllOnes()) {
+    APInt C = -1 * CFVal->getAPIntValue();
+    CTVal = cast<ConstantSDNode>(DAG.getConstant(C, DL, VT));
+    CFVal = cast<ConstantSDNode>(DAG.getAllOnesConstant(DL, VT));
+    AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
+  }
+
   // It might be neutral for larger constants, as the immediate need to be
   // materialized in a register.
   APInt ADDC = CTVal->getAPIntValue();
@@ -14793,12 +14814,13 @@ static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
   if (!TLI.isLegalAddImmediate(ADDC.getSExtValue()))
     return SDValue();
 
-  assert(CFVal->isOne() && "Unexpected constant value");
+  assert(((LHS.getOpcode() == AArch64ISD::CSEL && CFVal->isOne()) ||
+          (LHS.getOpcode() == AArch64ISD::CSNEG && CFVal->isAllOnes())) &&
+         "Unexpected constant value");
 
-  SDLoc DL(N);
   SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0));
   SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32);
-  SDValue Cmp = CSel.getOperand(3);
+  SDValue Cmp = LHS.getOperand(3);
 
   return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp);
 }
index 5ae1133..7706ca9 100644 (file)
@@ -7,7 +7,7 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-unknown-linux-gnu"
 
 ; char csinc1 (char a, char b) { return !a ? b+1 : b+3; }
-define dso_local i8 @csinc1(i8 %a, i8 %b) local_unnamed_addr #0 {
+define i8 @csinc1(i8 %a, i8 %b) local_unnamed_addr #0 {
 ; CHECK-LABEL: csinc1:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    tst w0, #0xff
@@ -22,7 +22,7 @@ entry:
 }
 
 ; short csinc2 (short a, short b) { return !a ? b+1 : b+3; }
-define dso_local i16 @csinc2(i16 %a, i16 %b) local_unnamed_addr #0 {
+define i16 @csinc2(i16 %a, i16 %b) local_unnamed_addr #0 {
 ; CHECK-LABEL: csinc2:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    tst w0, #0xffff
@@ -37,7 +37,7 @@ entry:
 }
 
 ; int csinc3 (int a, int b) { return !a ? b+1 : b+3; }
-define dso_local i32 @csinc3(i32 %a, i32 %b) local_unnamed_addr #0 {
+define i32 @csinc3(i32 %a, i32 %b) local_unnamed_addr #0 {
 ; CHECK-LABEL: csinc3:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cmp w0, #0
@@ -52,7 +52,7 @@ entry:
 }
 
 ; long long csinc4 (long long a, long long b) { return !a ? b+1 : b+3; }
-define dso_local i64 @csinc4(i64 %a, i64 %b) local_unnamed_addr #0 {
+define i64 @csinc4(i64 %a, i64 %b) local_unnamed_addr #0 {
 ; CHECK-LABEL: csinc4:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cmp x0, #0
@@ -65,3 +65,33 @@ entry:
   %cond = add nsw i64 %cond.v, %b
   ret i64 %cond
 }
+
+; long long csinc8 (long long a, long long b) { return a ? b-1 : b+1; }
+define i64 @csinc8(i64 %a, i64 %b) {
+; CHECK-LABEL: csinc8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub x8, x1, #1
+; CHECK-NEXT:    cmp x0, #0
+; CHECK-NEXT:    csinc x0, x8, x1, ne
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i64 %a, 0
+  %cond.v = select i1 %tobool.not, i64 1, i64 -1
+  %cond = add nsw i64 %cond.v, %b
+  ret i64 %cond
+}
+
+; long long csinc9 (long long a, long long b) { return a ? b+1 : b-1; }
+define i64 @csinc9(i64 %a, i64 %b) {
+; CHECK-LABEL: csinc9:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub x8, x1, #1
+; CHECK-NEXT:    cmp x0, #0
+; CHECK-NEXT:    csinc x0, x8, x1, eq
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i64 %a, 0
+  %cond.v = select i1 %tobool.not, i64 -1, i64 1
+  %cond = add nsw i64 %cond.v, %b
+  ret i64 %cond
+}
index 9ecc6f5..cbcd5ef 100644 (file)
@@ -7,7 +7,7 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-unknown-linux-gnu"
 
 ; int csinc1 (int a, int b) { return !a ? b+3 : b+1; }
-define dso_local i32 @csinc1(i32 %a, i32 %b) {
+define i32 @csinc1(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc1:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cmp w0, #0
@@ -22,7 +22,7 @@ entry:
 }
 
 ; int csinc2 (int a, int b) { return a ? b+3 : b+1; }
-define dso_local i32 @csinc2(i32 %a, i32 %b) {
+define i32 @csinc2(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc2:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cmp w0, #0
@@ -37,7 +37,7 @@ entry:
 }
 
 ; int csinc3 (int a, int b) { return !a ? b+1 : b-3; }
-define dso_local i32 @csinc3(i32 %a, i32 %b) {
+define i32 @csinc3(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc3:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sub w8, w1, #3
@@ -52,7 +52,7 @@ entry:
 }
 
 ; int csinc4 (int a, int b) { return a ? b+1 : b-3; }
-define dso_local i32 @csinc4(i32 %a, i32 %b) {
+define i32 @csinc4(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc4:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sub w8, w1, #3
@@ -67,7 +67,7 @@ entry:
 }
 
 ; int csinc5 (int a, int b) { return a ? b+1 : b-4095; }
-define dso_local i32 @csinc5(i32 %a, i32 %b) {
+define i32 @csinc5(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc5:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sub w8, w1, #4095
@@ -82,7 +82,7 @@ entry:
 }
 
 ; int csinc6 (int a, int b) { return a ? b+1 : b-4096; }
-define dso_local i32 @csinc6(i32 %a, i32 %b) {
+define i32 @csinc6(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc6:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    sub w8, w1, #1, lsl #12 // =4096
@@ -98,7 +98,7 @@ entry:
 
 ; prevent larger constants (the add laid after csinc)
 ; int csinc7 (int a, int b) { return a ? b+1 : b-4097; }
-define dso_local i32 @csinc7(i32 %a, i32 %b) {
+define i32 @csinc7(i32 %a, i32 %b) {
 ; CHECK-LABEL: csinc7:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cmp w0, #0
@@ -112,3 +112,33 @@ entry:
   %cond = add nsw i32 %cond.v, %b
   ret i32 %cond
 }
+
+; int csinc8 (int a, int b) { return a ? b-1 : b+1; }
+define i32 @csinc8(i32 %a, i32 %b) {
+; CHECK-LABEL: csinc8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub w8, w1, #1
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    csinc w0, w8, w1, ne
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i32 %a, 0
+  %cond.v = select i1 %tobool.not, i32 1, i32 -1
+  %cond = add nsw i32 %cond.v, %b
+  ret i32 %cond
+}
+
+; int csinc9 (int a, int b) { return a ? b+1 : b-1; }
+define i32 @csinc9(i32 %a, i32 %b) {
+; CHECK-LABEL: csinc9:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub w8, w1, #1
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    csinc w0, w8, w1, eq
+; CHECK-NEXT:    ret
+entry:
+  %tobool.not = icmp eq i32 %a, 0
+  %cond.v = select i1 %tobool.not, i32 -1, i32 1
+  %cond = add nsw i32 %cond.v, %b
+  ret i32 %cond
+}