[MLIR] Generalize select to arithmetic canonicalization
authorWilliam S. Moses <gh@wsmoses.com>
Fri, 7 Jan 2022 22:26:38 +0000 (17:26 -0500)
committerWilliam S. Moses <gh@wsmoses.com>
Mon, 10 Jan 2022 16:50:17 +0000 (11:50 -0500)
Given a select whose result is an i1, we can eliminate the conditional in the select completely by adding a few arithmetic operations.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D116839

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index a1047a5..ed6698b 100644 (file)
@@ -813,29 +813,31 @@ static LogicalResult verify(ReturnOp op) {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-// Transforms a select to a not, where relevant.
+// Transforms a select of a boolean to arithmetic operations
 //
-//  select %arg, %false, %true
+//  select %arg, %x, %y : i1
 //
 //  becomes
 //
-//  xor %arg, %true
-struct SelectToNot : public OpRewritePattern<SelectOp> {
+//  and(%arg, %x) or and(!%arg, %y)
+struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
   using OpRewritePattern<SelectOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SelectOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!matchPattern(op.getTrueValue(), m_Zero()))
-      return failure();
-
-    if (!matchPattern(op.getFalseValue(), m_One()))
-      return failure();
-
     if (!op.getType().isInteger(1))
       return failure();
 
-    rewriter.replaceOpWithNewOp<arith::XOrIOp>(op, op.getCondition(),
-                                               op.getFalseValue());
+    Value falseConstant =
+        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
+    Value notCondition = rewriter.create<arith::XOrIOp>(
+        op.getLoc(), op.getCondition(), falseConstant);
+
+    Value trueVal = rewriter.create<arith::AndIOp>(
+        op.getLoc(), op.getCondition(), op.getTrueValue());
+    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
+                                                    op.getFalseValue());
+    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
     return success();
   }
 };
@@ -876,7 +878,7 @@ struct SelectToExtUI : public OpRewritePattern<SelectOp> {
 
 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SelectToNot, SelectToExtUI>(context);
+  results.insert<SelectI1Simplify, SelectToExtUI>(context);
 }
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
index b44f78f..67d95ce 100644 (file)
@@ -88,10 +88,23 @@ func @branchCondProp(%arg0: i1) {
 
 // CHECK-LABEL: @selToNot
 //       CHECK:       %[[trueval:.+]] = arith.constant true
-//       CHECK:       %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK:       %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK:   return %[[res]]
 func @selToNot(%arg0: i1) -> i1 {
   %true = arith.constant true
   %false = arith.constant false
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// CHECK-LABEL: @selToArith
+//       CHECK-NEXT:       %[[trueval:.+]] = arith.constant true
+//       CHECK-NEXT:       %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK-NEXT:       %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
+//       CHECK-NEXT:       %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
+//       CHECK-NEXT:       %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
+//       CHECK:   return %[[res]]
+func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
+  %res = select %arg0, %arg1, %arg2 : i1
+  return %res : i1
+}