[MLIR] Canonicalize/fold select %x, 1, 0 to extui
authorWilliam S. Moses <gh@wsmoses.com>
Mon, 3 Jan 2022 04:49:29 +0000 (23:49 -0500)
committerWilliam S. Moses <gh@wsmoses.com>
Mon, 3 Jan 2022 06:26:28 +0000 (01:26 -0500)
Two canonicalizations for select %x, 1, 0
  If the return type is i1, return simply the condition %x, otherwise extui %x to the return type.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116517

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index de45339..a1047a5 100644 (file)
@@ -840,9 +840,43 @@ struct SelectToNot : public OpRewritePattern<SelectOp> {
   }
 };
 
+//  select %arg, %c1, %c0 => extui %arg
+struct SelectToExtUI : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const override {
+    // Cannot extui i1 to i1, or i1 to f32
+    if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
+      return failure();
+
+    // select %x, c1, %c0 => extui %arg
+    if (matchPattern(op.getTrueValue(), m_One()))
+      if (matchPattern(op.getFalseValue(), m_Zero())) {
+        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
+                                                    op.getCondition());
+        return success();
+      }
+
+    // select %x, c0, %c1 => extui (xor %arg, true)
+    if (matchPattern(op.getTrueValue(), m_Zero()))
+      if (matchPattern(op.getFalseValue(), m_One())) {
+        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
+            op, op.getType(),
+            rewriter.create<arith::XOrIOp>(
+                op.getLoc(), op.getCondition(),
+                rewriter.create<arith::ConstantIntOp>(
+                    op.getLoc(), 1, op.getCondition().getType())));
+        return success();
+      }
+
+    return failure();
+  }
+};
+
 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SelectToNot>(context);
+  results.insert<SelectToNot, SelectToExtUI>(context);
 }
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
@@ -861,6 +895,12 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(condition, m_Zero()))
     return falseVal;
 
+  // select %x, true, false => %x
+  if (getType().isInteger(1))
+    if (matchPattern(getTrueValue(), m_One()))
+      if (matchPattern(getFalseValue(), m_Zero()))
+        return condition;
+
   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
     auto pred = cmp.getPredicate();
     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
index 875d9f7..b44f78f 100644 (file)
@@ -29,6 +29,41 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
 
 // -----
 
+// CHECK-LABEL: @select_extui
+//       CHECK:   %[[res:.+]] = arith.extui %arg0 : i1 to i64
+//       CHECK:   return %[[res]]
+func @select_extui(%arg0: i1) -> i64 {
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %res = select %arg0, %c1_i64, %c0_i64 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @select_extui2
+// CHECK-DAG:  %true = arith.constant true
+// CHECK-DAG:  %[[xor:.+]] = arith.xori %arg0, %true : i1
+// CHECK-DAG:  %[[res:.+]] = arith.extui %[[xor]] : i1 to i64
+//       CHECK:   return %[[res]]
+func @select_extui2(%arg0: i1) -> i64 {
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %res = select %arg0, %c0_i64, %c1_i64 : i64
+  return %res : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_extui_i1
+//  CHECK-NEXT:   return %arg0
+func @select_extui_i1(%arg0: i1) -> i1 {
+  %c0_i1 = arith.constant false
+  %c1_i1 = arith.constant true
+  %res = select %arg0, %c1_i1, %c0_i1 : i1
+  return %res : i1
+}
+
+// -----
+
 // CHECK-LABEL: @branchCondProp
 //       CHECK:       %[[trueval:.+]] = arith.constant true
 //       CHECK:       %[[falseval:.+]] = arith.constant false