[InstCombine] Disable unsafe select transform behind a flag
authorNikita Popov <nikita.ppv@gmail.com>
Sun, 27 Dec 2020 17:33:15 +0000 (18:33 +0100)
committerNikita Popov <nikita.ppv@gmail.com>
Mon, 28 Dec 2020 21:43:52 +0000 (22:43 +0100)
This disables the poison-unsafe select -> and/or transform behind
a flag (we continue to perform the fold by default). This is intended
to simplify evaluation and testing while we teach various passes
to directly recognize the select pattern.

This only disables the main select -> and/or transform. A number of
related ones are instead changed to canonicalize to the a ? b : false
and a ? true : b forms which represent and/or respectively. This
requires a bit of care to avoid infinite loops, as we do not want
!a ? b : false to be converted into a ? false : b.

The basic idea here is the same as D93065, but keeps the change
behind a flag for now.

Differential Revision: https://reviews.llvm.org/D93840

llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
llvm/test/Transforms/InstCombine/select-and-or.ll [new file with mode: 0644]

index aee6e37..a5aed72 100644 (file)
@@ -213,6 +213,17 @@ public:
                                                                            Pred,
                                                                    Constant *C);
 
+  static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
+    // a ? b : false and a ? true : b are the canonical form of logical and/or.
+    // This includes !a ? b : false and !a ? true : b. Absorbing the not into
+    // the select by swapping operands would break recognition of this pattern
+    // in other analyses, so don't do that.
+    return match(&SI, PatternMatch::m_LogicalAnd(PatternMatch::m_Value(),
+                                                 PatternMatch::m_Value())) ||
+           match(&SI, PatternMatch::m_LogicalOr(PatternMatch::m_Value(),
+                                                PatternMatch::m_Value()));
+  }
+
   /// Return true if the specified value is free to invert (apply ~ to).
   /// This happens in cases where the ~ can be eliminated.  If WillInvertAllUses
   /// is true, work under the assumption that the caller intends to remove all
@@ -267,6 +278,8 @@ public:
       case Instruction::Select:
         if (U.getOperandNo() != 0) // Only if the value is used as select cond.
           return false;
+        if (shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(I)))
+          return false;
         break;
       case Instruction::Br:
         assert(U.getOperandNo() == 0 && "Must be branching on that value.");
index 0756676..5dcea0f 100644 (file)
@@ -47,6 +47,11 @@ using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
+/// FIXME: Enabled by default until the pattern is supported well.
+static cl::opt<bool> EnableUnsafeSelectTransform(
+    "instcombine-unsafe-select-transform", cl::init(true),
+    cl::desc("Enable poison-unsafe select to and/or transform"));
+
 static Value *createMinMax(InstCombiner::BuilderTy &Builder,
                            SelectPatternFlavor SPF, Value *A, Value *B) {
   CmpInst::Predicate Pred = getMinMaxPred(SPF);
@@ -2567,38 +2572,43 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
 
   if (SelType->isIntOrIntVectorTy(1) &&
       TrueVal->getType() == CondVal->getType()) {
-    if (match(TrueVal, m_One())) {
+    if (EnableUnsafeSelectTransform && match(TrueVal, m_One())) {
       // Change: A = select B, true, C --> A = or B, C
       return BinaryOperator::CreateOr(CondVal, FalseVal);
     }
-    if (match(TrueVal, m_Zero())) {
-      // Change: A = select B, false, C --> A = and !B, C
-      Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
-      return BinaryOperator::CreateAnd(NotCond, FalseVal);
-    }
-    if (match(FalseVal, m_Zero())) {
+    if (EnableUnsafeSelectTransform && match(FalseVal, m_Zero())) {
       // Change: A = select B, C, false --> A = and B, C
       return BinaryOperator::CreateAnd(CondVal, TrueVal);
     }
+
+    // select a, false, b -> select !a, b, false
+    if (match(TrueVal, m_Zero())) {
+      Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
+      return SelectInst::Create(NotCond, FalseVal,
+                                ConstantInt::getFalse(SelType));
+    }
+    // select a, b, true -> select !a, true, b
     if (match(FalseVal, m_One())) {
-      // Change: A = select B, C, true --> A = or !B, C
       Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
-      return BinaryOperator::CreateOr(NotCond, TrueVal);
+      return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType),
+                                TrueVal);
     }
 
-    // select a, a, b  -> a | b
-    // select a, b, a  -> a & b
+    // select a, a, b -> select a, true, b
     if (CondVal == TrueVal)
-      return BinaryOperator::CreateOr(CondVal, FalseVal);
+      return replaceOperand(SI, 1, ConstantInt::getTrue(SelType));
+    // select a, b, a -> select a, b, false
     if (CondVal == FalseVal)
-      return BinaryOperator::CreateAnd(CondVal, TrueVal);
+      return replaceOperand(SI, 2, ConstantInt::getFalse(SelType));
 
-    // select a, ~a, b -> (~a) & b
-    // select a, b, ~a -> (~a) | b
+    // select a, !a, b -> select !a, b, false
     if (match(TrueVal, m_Not(m_Specific(CondVal))))
-      return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+      return SelectInst::Create(TrueVal, FalseVal,
+                                ConstantInt::getFalse(SelType));
+    // select a, b, !a -> select !a, true, b
     if (match(FalseVal, m_Not(m_Specific(CondVal))))
-      return BinaryOperator::CreateOr(TrueVal, FalseVal);
+      return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType),
+                                TrueVal);
   }
 
   // Selecting between two integer or vector splat integer constants?
@@ -2942,7 +2952,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   }
 
   Value *NotCond;
-  if (match(CondVal, m_Not(m_Value(NotCond)))) {
+  if (match(CondVal, m_Not(m_Value(NotCond))) &&
+      !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) {
     replaceOperand(SI, 0, NotCond);
     SI.swapValues();
     SI.swapProfMetadata();
diff --git a/llvm/test/Transforms/InstCombine/select-and-or.ll b/llvm/test/Transforms/InstCombine/select-and-or.ll
new file mode 100644 (file)
index 0000000..5fab7cd
--- /dev/null
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -instcombine -instcombine-unsafe-select-transform=0 < %s | FileCheck %s
+
+; Should not be converted to "and", which has different poison semantics.
+define i1 @logical_and(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_and(
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 %b, i1 false
+  ret i1 %res
+}
+
+; Should not be converted to "or", which has different poison semantics.
+define i1 @logical_or(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_or(
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 true, i1 %b
+  ret i1 %res
+}
+; Canonicalize to logical and form, even if that requires adding a "not".
+define i1 @logical_and_not(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_and_not(
+; CHECK-NEXT:    [[NOT_A:%.*]] = xor i1 [[A:%.*]], true
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[NOT_A]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 false, i1 %b
+  ret i1 %res
+}
+
+; Canonicalize to logical or form, even if that requires adding a "not".
+define i1 @logical_or_not(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_or_not(
+; CHECK-NEXT:    [[NOT_A:%.*]] = xor i1 [[A:%.*]], true
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[NOT_A]], i1 true, i1 [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 %b, i1 true
+  ret i1 %res
+}
+
+; These are variants where condition or !condition is used to represent true
+; or false in one of the select arms. It should be canonicalized to the
+; constants.
+
+define i1 @logical_and_cond_reuse(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_and_cond_reuse(
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 %b, i1 %a
+  ret i1 %res
+}
+
+define i1 @logical_or_cond_reuse(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_or_cond_reuse(
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %res = select i1 %a, i1 %a, i1 %b
+  ret i1 %res
+}
+
+define i1 @logical_and_not_cond_reuse(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_and_not_cond_reuse(
+; CHECK-NEXT:    [[A_NOT:%.*]] = xor i1 [[A:%.*]], true
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %a.not = xor i1 %a, true
+  %res = select i1 %a, i1 %b, i1 %a.not
+  ret i1 %res
+}
+
+define i1 @logical_or_not_cond_reuse(i1 %a, i1 %b) {
+; CHECK-LABEL: @logical_or_not_cond_reuse(
+; CHECK-NEXT:    [[A_NOT:%.*]] = xor i1 [[A:%.*]], true
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[A_NOT]], i1 [[B:%.*]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %a.not = xor i1 %a, true
+  %res = select i1 %a, i1 %a.not, i1 %b
+  ret i1 %res
+}