From 9b32886e7e705bb28aab57682e612375075a0ad7 Mon Sep 17 00:00:00 2001 From: jacquesguan Date: Fri, 22 Apr 2022 08:13:14 +0000 Subject: [PATCH] [mlir][Arithmetic] Use common constant fold function in RemSI and RemUI to cover splat. This patch replaces current fold function with the common constant fold funtion in order to cover the situation of constant splat. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D124236 --- mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp | 58 ++++++++++++------------ mlir/test/Dialect/Arithmetic/canonicalize.mlir | 46 +++++++++++++++++++ 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index 5a10440..8f26e66 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -444,23 +444,22 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::RemUIOp::fold(ArrayRef operands) { - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + // remui (x, 1) -> 0. + if (matchPattern(getRhs(), m_One())) + return Builder(getContext()).getZeroAttr(getType()); - // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; + // Don't fold if it would require a division by zero. + bool div0 = false; + auto result = + constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + if (div0 || b.isNullValue()) { + div0 = true; + return a; + } + return a.urem(b); + }); - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); + return div0 ? Attribute() : result; } //===----------------------------------------------------------------------===// @@ -468,23 +467,22 @@ OpFoldResult arith::RemUIOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// OpFoldResult arith::RemSIOp::fold(ArrayRef operands) { - auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; - auto rhsValue = rhs.getValue(); - - // x % 1 = 0 - if (rhsValue.isOneValue()) - return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + // remsi (x, 1) -> 0. + if (matchPattern(getRhs(), m_One())) + return Builder(getContext()).getZeroAttr(getType()); - // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; + // Don't fold if it would require a division by zero. + bool div0 = false; + auto result = + constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + if (div0 || b.isNullValue()) { + div0 = true; + return a; + } + return a.srem(b); + }); - auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; - return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); + return div0 ? Attribute() : result; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index 5ea88b3..1f6b473 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -1319,3 +1319,49 @@ func.func @test_negf() -> (f32) { %0 = arith.negf %c : f32 return %0: f32 } + +// ----- + +// CHECK-LABEL: @test_remui( +// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32> +// CHECK: return %[[res]] +func @test_remui() -> (vector<4xi32>) { + %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32> + %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> + %0 = arith.remui %v1, %v2 : vector<4xi32> + return %0 : vector<4xi32> +} + +// // ----- + +// CHECK-LABEL: @test_remui_1( +// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32> +// CHECK: return %[[res]] +func @test_remui_1(%arg : vector<4xi32>) -> (vector<4xi32>) { + %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32> + %0 = arith.remui %arg, %v : vector<4xi32> + return %0 : vector<4xi32> +} + +// ----- + +// CHECK-LABEL: @test_remsi( +// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32> +// CHECK: return %[[res]] +func @test_remsi() -> (vector<4xi32>) { + %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32> + %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> + %0 = arith.remsi %v1, %v2 : vector<4xi32> + return %0 : vector<4xi32> +} + +// // ----- + +// CHECK-LABEL: @test_remsi_1( +// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32> +// CHECK: return %[[res]] +func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) { + %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32> + %0 = arith.remsi %arg, %v : vector<4xi32> + return %0 : vector<4xi32> +} -- 2.7.4