Search through associative operators for BMI patterns (BLSI, BLSR, BLSMSK)
authorNoah Goldstein <goldstein.w.n@gmail.com>
Mon, 6 Feb 2023 18:04:34 +0000 (12:04 -0600)
committerNoah Goldstein <goldstein.w.n@gmail.com>
Mon, 6 Feb 2023 20:09:17 +0000 (14:09 -0600)
(a & (-b)) & b is often lowered as:
    %sub  = sub i32     0, %b
    %and0 = and i32  %sub, %a
    %and1 = and i32 %and0, %b

Which won't get detected by the BLSI pattern as b & -b are never in
the same SDNode.

This patch will do a small search through associative operators and try
and place BMI patterns in the same node so they will hit the pattern.

Reviewed By: pengfei

Differential Revision: https://reviews.llvm.org/D141179

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bmi-out-of-order.ll

index 76d3a40..ea6a67a 100644 (file)
@@ -49205,6 +49205,70 @@ static SDValue combineScalarAndWithMaskSetcc(SDNode *N, SelectionDAG &DAG,
   return DAG.getZExtOrTrunc(DAG.getBitcast(IntVT, Concat), dl, VT);
 }
 
+static SDValue getBMIMatchingOp(unsigned Opc, SelectionDAG &DAG,
+                                SDValue OpMustEq, SDValue Op, unsigned Depth) {
+  // We don't want to go crazy with the recursion here. This isn't a super
+  // important optimization.
+  static constexpr unsigned kMaxDepth = 2;
+
+  // Only do this re-ordering if op has one use.
+  if (!Op.hasOneUse())
+    return SDValue();
+
+  SDLoc DL(Op);
+  // If we hit another assosiative op, recurse further.
+  if (Op.getOpcode() == Opc) {
+    // Done recursing.
+    if (Depth++ >= kMaxDepth)
+      return SDValue();
+
+    for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx)
+      if (SDValue R =
+              getBMIMatchingOp(Opc, DAG, OpMustEq, Op.getOperand(OpIdx), Depth))
+        return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), R,
+                           Op.getOperand(1 - OpIdx));
+
+  } else if (Op.getOpcode() == ISD::SUB) {
+    if (Opc == ISD::AND) {
+      // BLSI: (and x, (sub 0, x))
+      if (isNullConstant(Op.getOperand(0)) && Op.getOperand(1) == OpMustEq)
+        return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
+    }
+    // Opc must be ISD::AND or ISD::XOR
+    // BLSR: (and x, (sub x, 1))
+    // BLSMSK: (xor x, (sub x, 1))
+    if (isOneConstant(Op.getOperand(1)) && Op.getOperand(0) == OpMustEq)
+      return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
+
+  } else if (Op.getOpcode() == ISD::ADD) {
+    // Opc must be ISD::AND or ISD::XOR
+    // BLSR: (and x, (add x, -1))
+    // BLSMSK: (xor x, (add x, -1))
+    if (isAllOnesConstant(Op.getOperand(1)) && Op.getOperand(0) == OpMustEq)
+      return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
+  }
+  return SDValue();
+}
+
+static SDValue combineBMILogicOp(SDNode *N, SelectionDAG &DAG,
+                                 const X86Subtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+  // Make sure this node is a candidate for BMI instructions.
+  if (!Subtarget.hasBMI() || !VT.isScalarInteger() ||
+      (VT != MVT::i32 && VT != MVT::i64))
+    return SDValue();
+
+  assert(N->getOpcode() == ISD::AND || N->getOpcode() == ISD::XOR);
+
+  // Try and match LHS and RHS.
+  for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx)
+    if (SDValue OpMatch =
+            getBMIMatchingOp(N->getOpcode(), DAG, N->getOperand(OpIdx),
+                             N->getOperand(1 - OpIdx), 0))
+      return OpMatch;
+  return SDValue();
+}
+
 static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
@@ -49426,6 +49490,9 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue R = combineBMILogicOp(N, DAG, Subtarget))
+    return R;
+
   return SDValue();
 }
 
@@ -52400,6 +52467,9 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue R = combineBMILogicOp(N, DAG, Subtarget))
+    return R;
+
   return combineFneg(N, DAG, DCI, Subtarget);
 }
 
index 93dd432..1464f45 100644 (file)
@@ -44,8 +44,7 @@ define i64 @blsmask_through1(i64 %a, i64 %b) nounwind {
 ;
 ; X64-LABEL: blsmask_through1:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    xorq %rsi, %rdi
-; X64-NEXT:    leaq -1(%rsi), %rax
+; X64-NEXT:    blsmskq %rsi, %rax
 ; X64-NEXT:    xorq %rdi, %rax
 ; X64-NEXT:    retq
 entry:
@@ -58,20 +57,16 @@ entry:
 define i32 @blsmask_through2(i32 %a, i32 %b, i32 %c) nounwind {
 ; X86-LABEL: blsmask_through2:
 ; X86:       # %bb.0: # %entry
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
-; X86-NEXT:    leal -1(%ecx), %eax
+; X86-NEXT:    blsmskl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    xorl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    xorl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    xorl %ecx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: blsmask_through2:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    # kill: def $esi killed $esi def $rsi
-; X64-NEXT:    leal -1(%rsi), %eax
+; X64-NEXT:    blsmskl %esi, %eax
 ; X64-NEXT:    xorl %edx, %edi
 ; X64-NEXT:    xorl %edi, %eax
-; X64-NEXT:    xorl %esi, %eax
 ; X64-NEXT:    retq
 entry:
   %sub = add nsw i32 %b, -1
@@ -238,9 +233,7 @@ define i64 @blsi_through1(i64 %a, i64 %b) nounwind {
 ;
 ; X64-LABEL: blsi_through1:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    movq %rsi, %rax
-; X64-NEXT:    andq %rsi, %rdi
-; X64-NEXT:    negq %rax
+; X64-NEXT:    blsiq %rsi, %rax
 ; X64-NEXT:    andq %rdi, %rax
 ; X64-NEXT:    retq
 entry:
@@ -253,21 +246,16 @@ entry:
 define i32 @blsi_through2(i32 %a, i32 %b, i32 %c) nounwind {
 ; X86-LABEL: blsi_through2:
 ; X86:       # %bb.0: # %entry
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
-; X86-NEXT:    movl %ecx, %eax
-; X86-NEXT:    negl %eax
+; X86-NEXT:    blsil {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    andl %ecx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: blsi_through2:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    movl %esi, %eax
-; X64-NEXT:    negl %eax
+; X64-NEXT:    blsil %esi, %eax
 ; X64-NEXT:    andl %edx, %edi
 ; X64-NEXT:    andl %edi, %eax
-; X64-NEXT:    andl %esi, %eax
 ; X64-NEXT:    retq
 entry:
   %sub = sub i32 0, %b
@@ -298,11 +286,9 @@ define i64 @blsi_through3(i64 %a, i64 %b, i64 %c) nounwind {
 ;
 ; X64-LABEL: blsi_through3:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    movq %rsi, %rax
-; X64-NEXT:    negq %rax
+; X64-NEXT:    blsiq %rsi, %rax
 ; X64-NEXT:    andq %rdx, %rdi
 ; X64-NEXT:    andq %rdi, %rax
-; X64-NEXT:    andq %rsi, %rax
 ; X64-NEXT:    retq
 entry:
   %sub = sub i64 0, %b
@@ -432,8 +418,7 @@ define i64 @blsr_through1(i64 %a, i64 %b) nounwind {
 ;
 ; X64-LABEL: blsr_through1:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    andq %rsi, %rdi
-; X64-NEXT:    leaq -1(%rsi), %rax
+; X64-NEXT:    blsrq %rsi, %rax
 ; X64-NEXT:    andq %rdi, %rax
 ; X64-NEXT:    retq
 entry:
@@ -446,20 +431,16 @@ entry:
 define i32 @blsr_through2(i32 %a, i32 %b, i32 %c) nounwind {
 ; X86-LABEL: blsr_through2:
 ; X86:       # %bb.0: # %entry
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
-; X86-NEXT:    leal -1(%ecx), %eax
+; X86-NEXT:    blsrl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    andl %ecx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: blsr_through2:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    # kill: def $esi killed $esi def $rsi
-; X64-NEXT:    leal -1(%rsi), %eax
+; X64-NEXT:    blsrl %esi, %eax
 ; X64-NEXT:    andl %edx, %edi
 ; X64-NEXT:    andl %edi, %eax
-; X64-NEXT:    andl %esi, %eax
 ; X64-NEXT:    retq
 entry:
   %sub = add nsw i32 %b, -1