[RISCV] Combine comparison and logic ops
authorIlya Andreev <ilya.andreev@syntacore.com>
Tue, 13 Sep 2022 13:01:56 +0000 (09:01 -0400)
committerSergey Kachkov <sergey.kachkov@syntacore.com>
Fri, 23 Dec 2022 14:10:21 +0000 (17:10 +0300)
Two comparison operations and a logical operation are combined into selection using MIN or MAX and comparison operation.
For optimization to be applied conditions have to be satisfied:
  1. In comparison operations has to be the one common operand.
  2. Supports only signed and unsigned integers.
  3. Comparison has to be the same with respect to common operand.
  4. There are no more users of comparison except logic operation.
  5. Every combination of comparison and AND, OR are supported.

It will convert
  %l0 = %a < %c
  %l1 = %b < %c
  %res = %l0 or %l1
into
  %sel = min(%a, %b)
  %res = %sel < %c

It supports several comparison operations (<, <=, >, >=), signed, unsigned values and different order of operands if they do not violate conditions.

Differential Revision: https://reviews.llvm.org/D134277

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/zbb-cmp-combine.ll

index 764802f..3041e46 100644 (file)
@@ -8465,6 +8465,226 @@ static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Helper class contains information about comparison operation.
+// The first two operands of this operation are compared values and the
+// last one is the operation.
+// Compared values are stored in Ops.
+// Comparison operation is stored in CCode.
+class CmpOpInfo {
+  static unsigned constexpr Size = 2u;
+
+  // Type for storing operands of compare operation.
+  using OpsArray = std::array<SDValue, Size>;
+  OpsArray Ops;
+
+  using const_iterator = OpsArray::const_iterator;
+  const_iterator begin() const { return Ops.begin(); }
+  const_iterator end() const { return Ops.end(); }
+
+  ISD::CondCode CCode;
+
+  unsigned CommonPos{Size};
+  unsigned DifferPos{Size};
+
+  // Sets CommonPos and DifferPos based on incoming position
+  // of common operand CPos.
+  void setPositions(const_iterator CPos) {
+    assert(CPos != Ops.end() && "Common operand has to be in OpsArray.\n");
+    CommonPos = CPos == Ops.begin() ? 0 : 1;
+    DifferPos = 1 - CommonPos;
+    assert((DifferPos == 0 || DifferPos == 1) &&
+           "Positions can be only 0 or 1.");
+  }
+
+  // Private constructor of comparison info based on comparison operator.
+  // It is private because CmpOpInfo only reasonable relative to other
+  // comparison operator. Therefore, infos about comparison operation
+  // have to be collected simultaneously via CmpOpInfo::getInfoAbout().
+  CmpOpInfo(const SDValue &CmpOp)
+      : Ops{CmpOp.getOperand(0), CmpOp.getOperand(1)},
+        CCode{cast<CondCodeSDNode>(CmpOp.getOperand(2))->get()} {}
+
+  // Finds common operand of Op1 and Op2 and finishes filling CmpOpInfos.
+  // Returns true if common operand is found. Otherwise - false.
+  static bool establishCorrespondence(CmpOpInfo &Op1, CmpOpInfo &Op2) {
+    const auto CommonOpIt1 =
+        std::find_first_of(Op1.begin(), Op1.end(), Op2.begin(), Op2.end());
+    if (CommonOpIt1 == Op1.end())
+      return false;
+
+    const auto CommonOpIt2 = std::find(Op2.begin(), Op2.end(), *CommonOpIt1);
+    assert(CommonOpIt2 != Op2.end() &&
+           "Cannot find common operand in the second comparison operation.");
+
+    Op1.setPositions(CommonOpIt1);
+    Op2.setPositions(CommonOpIt2);
+
+    return true;
+  }
+
+public:
+  CmpOpInfo(const CmpOpInfo &) = default;
+  CmpOpInfo(CmpOpInfo &&) = default;
+
+  SDValue const &operator[](unsigned Pos) const {
+    assert(Pos < Size && "Out of range\n");
+    return Ops[Pos];
+  }
+
+  // Creates infos about comparison operations CmpOp0 and CmpOp1.
+  // If there is no common operand returns None. Otherwise, returns
+  // correspondence info about comparison operations.
+  static std::optional<std::pair<CmpOpInfo, CmpOpInfo>>
+  getInfoAbout(SDValue const &CmpOp0, SDValue const &CmpOp1) {
+    CmpOpInfo Op0{CmpOp0};
+    CmpOpInfo Op1{CmpOp1};
+    if (!establishCorrespondence(Op0, Op1))
+      return std::nullopt;
+    return std::make_pair(Op0, Op1);
+  }
+
+  // Returns position of common operand.
+  unsigned getCPos() const { return CommonPos; }
+
+  // Returns position of differ operand.
+  unsigned getDPos() const { return DifferPos; }
+
+  // Returns common operand.
+  SDValue const &getCOp() const { return operator[](CommonPos); }
+
+  // Returns differ operand.
+  SDValue const &getDOp() const { return operator[](DifferPos); }
+
+  // Returns consition code of comparison operation.
+  ISD::CondCode getCondCode() const { return CCode; }
+};
+
+// Verifies conditions to apply an optimization.
+// Returns Reference comparison code and three operands A, B, C.
+// Conditions for optimization:
+//   One operand of the compasions has to be common.
+//   This operand is written to C.
+//   Two others operands are differend. They are written to A and B.
+//   Comparisons has to be similar with respect to common operand C.
+//     e.g. A < C; C > B are similar
+//      but A < C; B > C are not.
+//   Reference comparison code is the comparison code if
+//   common operand is right placed.
+//     e.g. C > A will be swapped to A < C.
+static std::optional<std::tuple<ISD::CondCode, SDValue, SDValue, SDValue>>
+verifyCompareConds(SDNode *N, SelectionDAG &DAG) {
+  LLVM_DEBUG(
+      dbgs() << "Checking conditions for comparison operation combining.\n";);
+
+  SDValue V0 = N->getOperand(0);
+  SDValue V1 = N->getOperand(1);
+  assert(V0.getValueType() == V1.getValueType() &&
+         "Operations must have the same value type.");
+
+  // Condition 1. Operations have to be used only in logic operation.
+  if (!V0.hasOneUse() || !V1.hasOneUse())
+    return std::nullopt;
+
+  // Condition 2. Operands have to be comparison operations.
+  if (V0.getOpcode() != ISD::SETCC || V1.getOpcode() != ISD::SETCC)
+    return std::nullopt;
+
+  // Condition 3.1. Operations only with integers.
+  if (!V0.getOperand(0).getValueType().isInteger())
+    return std::nullopt;
+
+  const auto ComparisonInfo = CmpOpInfo::getInfoAbout(V0, V1);
+  // Condition 3.2. Common operand has to be in comparison.
+  if (!ComparisonInfo)
+    return std::nullopt;
+
+  const auto [Op0, Op1] = ComparisonInfo.value();
+
+  LLVM_DEBUG(dbgs() << "Shared operands are on positions: " << Op0.getCPos()
+                    << " and " << Op1.getCPos() << '\n';);
+  // If common operand at the first position then swap operation to convert to
+  // strict pattern. Common operand has to be right hand side.
+  ISD::CondCode RefCond = Op0.getCondCode();
+  ISD::CondCode AssistCode = Op1.getCondCode();
+  if (!Op0.getCPos())
+    RefCond = ISD::getSetCCSwappedOperands(RefCond);
+  if (!Op1.getCPos())
+    AssistCode = ISD::getSetCCSwappedOperands(AssistCode);
+  LLVM_DEBUG(dbgs() << "Reference condition is: " << RefCond << '\n';);
+  // If there are different comparison operations then do not perform an
+  // optimization. a < c; c < b -> will be changed to b > c.
+  if (RefCond != AssistCode)
+    return std::nullopt;
+
+  // Conditions can be only similar to Less or Greater. (>, >=, <, <=)
+  // Applying this mask to the operation will determine Less and Greater
+  // operations.
+  const unsigned CmpMask = 0b110;
+  const unsigned MaskedOpcode = CmpMask & RefCond;
+  // If masking gave 0b110, then this is an operation NE, O or TRUE.
+  if (MaskedOpcode == CmpMask)
+    return std::nullopt;
+  // If masking gave 00000, then this is an operation E, O or FALSE.
+  if (MaskedOpcode == 0)
+    return std::nullopt;
+  // Everything else is similar to Less or Greater.
+
+  SDValue A = Op0.getDOp();
+  SDValue B = Op1.getDOp();
+  SDValue C = Op0.getCOp();
+
+  LLVM_DEBUG(
+      dbgs() << "The conditions for combining comparisons are satisfied.\n";);
+  return std::make_tuple(RefCond, A, B, C);
+}
+
+static ISD::NodeType getSelectionCode(bool IsUnsigned, bool IsAnd,
+                                      bool IsGreaterOp) {
+  // Codes of selection operation. The first index selects signed or unsigned,
+  // the second index selects MIN/MAX.
+  static constexpr ISD::NodeType SelectionCodes[2][2] = {
+      {ISD::SMIN, ISD::SMAX}, {ISD::UMIN, ISD::UMAX}};
+  const bool ChooseSelCode = IsAnd ^ IsGreaterOp;
+  return SelectionCodes[IsUnsigned][ChooseSelCode];
+}
+
+// Combines two comparison operation and logic operation to one selection
+// operation(min, max) and logic operation. Returns new constructed Node if
+// conditions for optimization are satisfied.
+static SDValue combineCmpOp(SDNode *N, SelectionDAG &DAG,
+                            const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZbb())
+    return SDValue();
+
+  const unsigned BitOpcode = N->getOpcode();
+  assert((BitOpcode == ISD::AND || BitOpcode == ISD::OR) &&
+         "This optimization can be used only with AND/OR operations");
+
+  const auto Props = verifyCompareConds(N, DAG);
+  // If conditions are invalidated then do not perform an optimization.
+  if (!Props)
+    return SDValue();
+
+  const auto [RefOpcode, A, B, C] = Props.value();
+  const EVT CmpOpVT = A.getValueType();
+
+  const bool IsGreaterOp = RefOpcode & 0b10;
+  const bool IsUnsigned = ISD::isUnsignedIntSetCC(RefOpcode);
+  assert((IsUnsigned || ISD::isSignedIntSetCC(RefOpcode)) &&
+         "Operation neither with signed or unsigned integers.");
+
+  const bool IsAnd = BitOpcode == ISD::AND;
+  const ISD::NodeType PickCode =
+      getSelectionCode(IsUnsigned, IsAnd, IsGreaterOp);
+
+  SDLoc DL(N);
+  SDValue Pick = DAG.getNode(PickCode, DL, CmpOpVT, A, B);
+  SDValue Cmp =
+      DAG.getSetCC(DL, N->getOperand(0).getValueType(), Pick, C, RefOpcode);
+
+  return Cmp;
+}
+
 static SDValue performANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const RISCVSubtarget &Subtarget) {
@@ -8489,6 +8709,9 @@ static SDValue performANDCombine(SDNode *N,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
   }
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
     return V;
 
@@ -8505,6 +8728,9 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                 const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
+  if (SDValue V = combineCmpOp(N, DAG, Subtarget))
+    return V;
+
   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
     return V;
 
index 4d02c53..b94c50d 100644 (file)
@@ -12,9 +12,8 @@
 define i1 @ulo(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ult i64 %b, %c
@@ -25,9 +24,8 @@ define i1 @ulo(i64 %c, i64 %a, i64 %b) {
 define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ult i64 %b, %c
@@ -38,9 +36,8 @@ define i1 @ulo_swap1(i64 %c, i64 %a, i64 %b) {
 define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ugt i64 %c, %b
@@ -51,9 +48,8 @@ define i1 @ulo_swap2(i64 %c, i64 %a, i64 %b) {
 define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ulo_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ugt i64 %c, %b
@@ -65,9 +61,8 @@ define i1 @ulo_swap12(i64 %c, i64 %a, i64 %b) {
 define i1 @ula(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ult i64 %b, %c
@@ -78,9 +73,8 @@ define i1 @ula(i64 %c, i64 %a, i64 %b) {
 define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ult i64 %b, %c
@@ -91,9 +85,8 @@ define i1 @ula_swap1(i64 %c, i64 %a, i64 %b) {
 define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %a, %c
   %l1 = icmp ugt i64 %c, %b
@@ -104,9 +97,8 @@ define i1 @ula_swap2(i64 %c, i64 %a, i64 %b) {
 define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ula_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %c, %a
   %l1 = icmp ugt i64 %c, %b
@@ -119,9 +111,8 @@ define i1 @ula_swap12(i64 %c, i64 %a, i64 %b) {
 define i1 @ugo(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ugt i64 %b, %c
@@ -132,9 +123,8 @@ define i1 @ugo(i64 %c, i64 %a, i64 %b) {
 define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %c, %a
   %l1 = icmp ugt i64 %b, %c
@@ -145,9 +135,8 @@ define i1 @ugo_swap1(i64 %c, i64 %a, i64 %b) {
 define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ult i64 %c, %b
@@ -158,9 +147,8 @@ define i1 @ugo_swap2(i64 %c, i64 %a, i64 %b) {
 define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugo_swap12:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    maxu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ult i64 %c, %a
   %l1 = icmp ult i64 %c, %b
@@ -173,9 +161,8 @@ define i1 @ugo_swap12(i64 %c, i64 %a, i64 %b) {
 define i1 @ugea(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: ugea:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    sltu a0, a2, a0
-; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a1, a0
 ; CHECK-NEXT:    xori a0, a0, 1
 ; CHECK-NEXT:    ret
   %l0 = icmp uge i64 %a, %c
@@ -189,9 +176,8 @@ define i1 @ugea(i64 %c, i64 %a, i64 %b) {
 define i1 @uga(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: uga:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    sltu a1, a0, a1
-; CHECK-NEXT:    sltu a0, a0, a2
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    minu a1, a1, a2
+; CHECK-NEXT:    sltu a0, a0, a1
 ; CHECK-NEXT:    ret
   %l0 = icmp ugt i64 %a, %c
   %l1 = icmp ugt i64 %b, %c
@@ -204,9 +190,8 @@ define i1 @uga(i64 %c, i64 %a, i64 %b) {
 define i1 @sla(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: sla:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    slt a1, a1, a0
-; CHECK-NEXT:    slt a0, a2, a0
-; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    max a1, a1, a2
+; CHECK-NEXT:    slt a0, a1, a0
 ; CHECK-NEXT:    ret
   %l0 = icmp slt i64 %a, %c
   %l1 = icmp slt i64 %b, %c
@@ -214,6 +199,7 @@ define i1 @sla(i64 %c, i64 %a, i64 %b) {
   ret i1 %res
 }
 
+; Negative test
 ; Float check.
 define i1 @flo(float %c, float %a, float %b) {
 ; CHECK-RV64I-LABEL: flo:
@@ -259,6 +245,7 @@ define i1 @flo(float %c, float %a, float %b) {
   ret i1 %res
 }
 
+; Negative test
 ; Double check.
 define i1 @dlo(double %c, double %a, double %b) {
 ; CHECK-LABEL: dlo:
@@ -296,6 +283,7 @@ define i1 @dlo(double %c, double %a, double %b) {
   ret i1 %res
 }
 
+; Negative test
 ; More than one user
 define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: multi_user:
@@ -313,6 +301,7 @@ define i1 @multi_user(i64 %c, i64 %a, i64 %b) {
   ret i1 %out
 }
 
+; Negative test
 ; No same comparations
 define i1 @no_same_ops(i64 %c, i64 %a, i64 %b) {
 ; CHECK-LABEL: no_same_ops: