From 68bd355505463431c9c29a09d94ae866763c3522 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 6 Nov 2019 13:51:19 -0800 Subject: [PATCH] Adding an m_NonZero constant integer matcher. This is useful for making matching cases where a non-zero value is required more readable, such as the results of a constant comparison that are expected to be equal. PiperOrigin-RevId: 278932874 --- mlir/include/mlir/IR/Matchers.h | 20 +++++++++++++++++--- mlir/lib/Dialect/StandardOps/Ops.cpp | 33 +++++++++++++-------------------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index a464612..aba63d6 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -111,16 +111,24 @@ struct constant_int_op_binder { } }; -// The matcher that matches a given target constant scalar / vector splat / -// tensor splat integer value. +/// The matcher that matches a given target constant scalar / vector splat / +/// tensor splat integer value. template struct constant_int_value_matcher { bool match(Operation *op) { APInt value; - return constant_int_op_binder(&value).match(op) && TargetValue == value; } }; +/// The matcher that matches anything except the given target constant scalar / +/// vector splat / tensor splat integer value. +template struct constant_int_not_value_matcher { + bool match(Operation *op) { + APInt value; + return constant_int_op_binder(&value).match(op) && TargetNotValue != value; + } +}; + /// The matcher that matches a certain kind of op. template struct op_matcher { bool match(Operation *op) { return isa(op); } @@ -172,6 +180,12 @@ inline detail::constant_int_value_matcher<0> m_Zero() { return detail::constant_int_value_matcher<0>(); } +/// Matches a constant scalar / vector splat / tensor splat integer that is any +/// non-zero value. +inline detail::constant_int_not_value_matcher<0> m_NonZero() { + return detail::constant_int_not_value_matcher<0>(); +} + } // end namespace mlir #endif // MLIR_MATCHERS_H diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 5a452c5..161a6c4 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1070,27 +1070,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern { PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - // Check that the condition is a constant. - if (!matchPattern(condbr.getCondition(), m_Op())) - return matchFailure(); - - Block *foldedDest; - SmallVector branchArgs; - - // If the condition is known to evaluate to false we fold to a branch to the - // false destination. Otherwise, we fold to a branch to the true - // destination. - if (matchPattern(condbr.getCondition(), m_Zero())) { - foldedDest = condbr.getFalseDest(); - branchArgs.assign(condbr.false_operand_begin(), - condbr.false_operand_end()); - } else { - foldedDest = condbr.getTrueDest(); - branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); + if (matchPattern(condbr.getCondition(), m_NonZero())) { + // True branch taken. + rewriter.replaceOpWithNewOp( + condbr, condbr.getTrueDest(), + llvm::to_vector<4>(condbr.getTrueOperands())); + return matchSuccess(); + } else if (matchPattern(condbr.getCondition(), m_Zero())) { + // False branch taken. + rewriter.replaceOpWithNewOp( + condbr, condbr.getFalseDest(), + llvm::to_vector<4>(condbr.getFalseOperands())); + return matchSuccess(); } - - rewriter.replaceOpWithNewOp(condbr, foldedDest, branchArgs); - return matchSuccess(); + return matchFailure(); } }; } // end anonymous namespace. -- 2.7.4