[mlir][arith] Clean up ExpandOps pass
authorMogball <jeffniu22@gmail.com>
Mon, 20 Dec 2021 21:58:39 +0000 (21:58 +0000)
committerMogball <jeffniu22@gmail.com>
Mon, 20 Dec 2021 21:59:11 +0000 (21:59 +0000)
mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp

index d9ab927..d06c304 100644 (file)
 
 using namespace mlir;
 
+/// Create an integer or index constant.
+static Value createConst(Location loc, Type type, int value,
+                         PatternRewriter &rewriter) {
+  return rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(type, value));
+}
+
 namespace {
 
 /// Expands CeilDivUIOp (n, m) into
@@ -26,17 +33,14 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
     Location loc = op.getLoc();
     Value a = op.getLhs();
     Value b = op.getRhs();
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(a.getType(), 0));
+    Value zero = createConst(loc, a.getType(), 0, rewriter);
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(a.getType(), 1));
+    Value one = createConst(loc, a.getType(), 1, rewriter);
     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
-    Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compare, zero, plusOne);
     return success();
   }
 };
@@ -49,16 +53,12 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
-    Type type = signedCeilDivIOp.getType();
-    Value a = signedCeilDivIOp.getLhs();
-    Value b = signedCeilDivIOp.getRhs();
-    Value plusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 1));
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 0));
-    Value minusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, -1));
+    Type type = op.getType();
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+    Value plusOne = createConst(loc, type, 1, rewriter);
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value minusOne = createConst(loc, type, -1, rewriter);
     // Compute x = (b>0) ? -1 : 1.
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
@@ -90,9 +90,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
     Value compareRes =
         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
     // Perform substitution and return success.
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, posRes, negRes);
     return success();
   }
 };
@@ -105,16 +104,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
-    Type type = signedFloorDivIOp.getType();
-    Value a = signedFloorDivIOp.getLhs();
-    Value b = signedFloorDivIOp.getRhs();
-    Value plusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 1));
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 0));
-    Value minusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, -1));
+    Type type = op.getType();
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+    Value plusOne = createConst(loc, type, 1, rewriter);
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value minusOne = createConst(loc, type, -1, rewriter);
     // Compute x = (b<0) ? 1 : -1.
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
@@ -144,9 +139,8 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
     Value compareRes =
         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
     // Perform substitution and return success.
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, negRes, posRes);
     return success();
   }
 };