From 6c295a932d26681f07037d7289df405e36350dd4 Mon Sep 17 00:00:00 2001 From: jacquesguan Date: Thu, 29 Dec 2022 10:51:05 +0800 Subject: [PATCH] [mlir][Arith] Fold integer shift op with zero. This revision folds arith.shrui, arith.shrsi and arith.shli with zero rhs to lhs. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D140749 --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 +++++++++ mlir/test/Dialect/Arith/canonicalize.mlir | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 844905c..f812b3c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2221,6 +2221,9 @@ LogicalResult arith::SelectOp::verify() { //===----------------------------------------------------------------------===// OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { + // shli(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2236,6 +2239,9 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { + // shrui(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2251,6 +2257,9 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::ShRSIOp::fold(ArrayRef operands) { + // shrsi(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index d181ae9..02cbaa2 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2107,3 +2107,30 @@ func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 { %hi = arith.trunci %sh: i64 to i32 return %hi : i32 } + +// CHECK-LABEL: @foldShli0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShli0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shli %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrui0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrui0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrui %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrsi0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrsi0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrsi %x, %c0 : i64 + return %r : i64 +} -- 2.7.4