From a02af37560ff5aa22dcef5735ef25eaf58eaaf64 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 7 Jan 2022 17:26:38 -0500 Subject: [PATCH] [MLIR] Generalize select to arithmetic canonicalization Given a select whose result is an i1, we can eliminate the conditional in the select completely by adding a few arithmetic operations. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D116839 --- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 28 +++++++++++++++------------- mlir/test/Dialect/Standard/canonicalize.mlir | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index a1047a5..ed6698b 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -813,29 +813,31 @@ static LogicalResult verify(ReturnOp op) { // SelectOp //===----------------------------------------------------------------------===// -// Transforms a select to a not, where relevant. +// Transforms a select of a boolean to arithmetic operations // -// select %arg, %false, %true +// select %arg, %x, %y : i1 // // becomes // -// xor %arg, %true -struct SelectToNot : public OpRewritePattern { +// and(%arg, %x) or and(!%arg, %y) +struct SelectI1Simplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SelectOp op, PatternRewriter &rewriter) const override { - if (!matchPattern(op.getTrueValue(), m_Zero())) - return failure(); - - if (!matchPattern(op.getFalseValue(), m_One())) - return failure(); - if (!op.getType().isInteger(1)) return failure(); - rewriter.replaceOpWithNewOp(op, op.getCondition(), - op.getFalseValue()); + Value falseConstant = + rewriter.create(op.getLoc(), true, 1); + Value notCondition = rewriter.create( + op.getLoc(), op.getCondition(), falseConstant); + + Value trueVal = rewriter.create( + op.getLoc(), op.getCondition(), op.getTrueValue()); + Value falseVal = rewriter.create(op.getLoc(), notCondition, + op.getFalseValue()); + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); return success(); } }; @@ -876,7 +878,7 @@ struct SelectToExtUI : public OpRewritePattern { void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } OpFoldResult SelectOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index b44f78f..67d95ce 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -88,10 +88,23 @@ func @branchCondProp(%arg0: i1) { // CHECK-LABEL: @selToNot // CHECK: %[[trueval:.+]] = arith.constant true -// CHECK: %{{.+}} = arith.xori %arg0, %[[trueval]] : i1 +// CHECK: %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1 +// CHECK: return %[[res]] func @selToNot(%arg0: i1) -> i1 { %true = arith.constant true %false = arith.constant false %res = select %arg0, %false, %true : i1 return %res : i1 } + +// CHECK-LABEL: @selToArith +// CHECK-NEXT: %[[trueval:.+]] = arith.constant true +// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1 +// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1 +// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1 +// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1 +// CHECK: return %[[res]] +func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 { + %res = select %arg0, %arg1, %arg2 : i1 + return %res : i1 +} -- 2.7.4