[mlir] Avoid folding `index.remu` and `index.rems` for 0 rhs
authorrikhuijzer <rikhuijzer@pm.me>
Wed, 31 May 2023 17:45:05 +0000 (10:45 -0700)
committerJeff Niu <jeff@modular.com>
Wed, 31 May 2023 17:45:26 +0000 (10:45 -0700)
As discussed in https://github.com/llvm/llvm-project/issues/59714#issuecomment-1369518768, the folder for the remainder operations should be resillient when the rhs is 0.
The file `IndexOps.cpp` was already checking for multiple divisions by zero, so I tried to stick to the code style from those checks.

Fixes #59714.

As a side note, is it correct that remainder operations are never optimized away? I would expect that the following code

```
func.func @remu_test() -> index {
  %c3 = index.constant 2
  %c0 = index.constant 1
  %0 = index.remu %c3, %c0
  return %0 : index
}
```
would be optimized to
```
func.func @remu_test() -> index {
  return index.constant 0 : index
}
```
when called with `mlir-opt --convert-scf-to-openmp temp.mlir`, but maybe I'm misunderstanding something.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D151476

mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir

index b6ccb77..3218933 100644 (file)
@@ -263,7 +263,12 @@ OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
       adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
+      [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
+        // Don't fold division by zero.
+        if (rhs.isZero())
+          return std::nullopt;
+        return lhs.srem(rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +278,12 @@ OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
       adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
+      [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
+        // Don't fold division by zero.
+        if (rhs.isZero())
+          return std::nullopt;
+        return lhs.urem(rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//
index c9b3079..a9b060b 100644 (file)
@@ -198,6 +198,24 @@ func.func @floordivs_nofold() -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: @rems_zerodiv_nofold
+func.func @rems_zerodiv_nofold() -> index {
+  %lhs = index.constant 2
+  %rhs = index.constant 0
+  // CHECK: index.rems
+  %0 = index.rems %lhs, %rhs
+  return %0 : index
+}
+
+// CHECK-LABEL: @remu_zerodiv_nofold
+func.func @remu_zerodiv_nofold() -> index {
+  %lhs = index.constant 2
+  %rhs = index.constant 0
+  // CHECK: index.remu
+  %0 = index.remu %lhs, %rhs
+  return %0 : index
+}
+
 // CHECK-LABEL: @rems
 func.func @rems() -> index {
   %lhs = index.constant -5