From 622aac6a0ad74126fab2fa5c2043bc0424c7eace Mon Sep 17 00:00:00 2001 From: Phoenix Meadowlark Date: Mon, 27 Apr 2020 19:59:16 +0000 Subject: [PATCH] Add a folder for division by one. - Adds a folder for integer division by one with the `divi_signed` and `divi_unsigned` ops. - Creates tests for scalar and tensor versions of these ops. - Modifies the test in `parallel-loop-collapsing.mlir` so that it doesn't assume division by one will be in the output. Differential Revision: https://reviews.llvm.org/D78518 --- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 20 ++++++++++++ mlir/test/Transforms/canonicalize.mlir | 38 ++++++++++++++++++++++ mlir/test/Transforms/parallel-loop-collapsing.mlir | 6 ++-- 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index b46abd6..0de8f42 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2009,6 +2009,16 @@ OpFoldResult SignedDivIOp::fold(ArrayRef operands) { } return a.sdiv_ov(b, overflowOrDiv0); }); + + // Fold out division by one. Assumes all tensors of all ones are splats. + if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getValue() == 1) + return lhs(); + } else if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getSplatValue().getValue() == 1) + return lhs(); + } + return overflowOrDiv0 ? Attribute() : result; } @@ -2537,6 +2547,16 @@ OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { } return a.udiv(b); }); + + // Fold out division by one. Assumes all tensors of all ones are splats. + if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getValue() == 1) + return lhs(); + } else if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getSplatValue().getValue() == 1) + return lhs(); + } + return div0 ? Attribute() : result; } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 6528d10..6900198 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -884,3 +884,41 @@ func @remove_dead_else(%M : memref<100 x i32>) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: affine.load + +// ----- + +// CHECK-LABEL: func @divi_signed_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @divi_signed_by_one(%arg0: i32) -> (i32) { + %c1 = constant 1 : i32 + %res = divi_signed %arg0, %c1 : i32 + // CHECK: return %[[ARG]] + return %res : i32 +} + +// CHECK-LABEL: func @divi_unsigned_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @divi_unsigned_by_one(%arg0: i32) -> (i32) { + %c1 = constant 1 : i32 + %res = divi_unsigned %arg0, %c1 : i32 + // CHECK: return %[[ARG]] + return %res : i32 +} + +// CHECK-LABEL: func @tensor_divi_signed_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @tensor_divi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> { + %c1 = constant dense<1> : tensor<4x5xi32> + %res = divi_signed %arg0, %c1 : tensor<4x5xi32> + // CHECK: return %[[ARG]] + return %res : tensor<4x5xi32> +} + +// CHECK-LABEL: func @tensor_divi_unsigned_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> { + %c1 = constant dense<1> : tensor<4x5xi32> + %res = divi_unsigned %arg0, %c1 : tensor<4x5xi32> + // CHECK: return %[[ARG]] + return %res : tensor<4x5xi32> +} diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir index 6fcb78c..55c851d 100644 --- a/mlir/test/Transforms/parallel-loop-collapsing.mlir +++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir @@ -37,11 +37,9 @@ func @parallel_many_dims() { // CHECK: [[C2:%.*]] = constant 2 : index // CHECK: loop.parallel ([[NEW_I0:%.*]], [[NEW_I1:%.*]], [[NEW_I2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C2]], [[C1]], [[C1]]) step ([[C1]], [[C1]], [[C1]]) { // CHECK: [[I0:%.*]] = remi_signed [[NEW_I0]], [[C2]] : index -// CHECK: [[I3_COUNT:%.*]] = divi_signed [[NEW_I0]], [[C1]] : index -// CHECK: [[I4_COUNT:%.*]] = divi_signed [[NEW_I1]], [[C1]] : index -// CHECK: [[VAL_16:%.*]] = muli [[I4_COUNT]], [[C13]] : index +// CHECK: [[VAL_16:%.*]] = muli [[NEW_I1]], [[C13]] : index // CHECK: [[I4:%.*]] = addi [[VAL_16]], [[C12]] : index -// CHECK: [[VAL_18:%.*]] = muli [[I3_COUNT]], [[C10]] : index +// CHECK: [[VAL_18:%.*]] = muli [[NEW_I0]], [[C10]] : index // CHECK: [[I3:%.*]] = addi [[VAL_18]], [[C9]] : index // CHECK: [[VAL_20:%.*]] = muli [[NEW_I2]], [[C7]] : index // CHECK: [[I2:%.*]] = addi [[VAL_20]], [[C6]] : index -- 2.7.4