using namespace mlir;
+/// Create an integer or index constant.
+static Value createConst(Location loc, Type type, int value,
+ PatternRewriter &rewriter) {
+ return rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(type, value));
+}
+
namespace {
/// Expands CeilDivUIOp (n, m) into
Location loc = op.getLoc();
Value a = op.getLhs();
Value b = op.getRhs();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(a.getType(), 0));
+ Value zero = createConst(loc, a.getType(), 0, rewriter);
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
- Value one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(a.getType(), 1));
+ Value one = createConst(loc, a.getType(), 1, rewriter);
Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
- Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
- rewriter.replaceOp(op, {res});
+ rewriter.replaceOpWithNewOp<SelectOp>(op, compare, zero, plusOne);
return success();
}
};
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
- Type type = signedCeilDivIOp.getType();
- Value a = signedCeilDivIOp.getLhs();
- Value b = signedCeilDivIOp.getRhs();
- Value plusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, 1));
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, 0));
- Value minusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, -1));
+ Type type = op.getType();
+ Value a = op.getLhs();
+ Value b = op.getRhs();
+ Value plusOne = createConst(loc, type, 1, rewriter);
+ Value zero = createConst(loc, type, 0, rewriter);
+ Value minusOne = createConst(loc, type, -1, rewriter);
// Compute x = (b>0) ? -1 : 1.
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
Value compareRes =
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
- Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
// Perform substitution and return success.
- rewriter.replaceOp(op, {res});
+ rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, posRes, negRes);
return success();
}
};
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
- Type type = signedFloorDivIOp.getType();
- Value a = signedFloorDivIOp.getLhs();
- Value b = signedFloorDivIOp.getRhs();
- Value plusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, 1));
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, 0));
- Value minusOne = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(type, -1));
+ Type type = op.getType();
+ Value a = op.getLhs();
+ Value b = op.getRhs();
+ Value plusOne = createConst(loc, type, 1, rewriter);
+ Value zero = createConst(loc, type, 0, rewriter);
+ Value minusOne = createConst(loc, type, -1, rewriter);
// Compute x = (b<0) ? 1 : -1.
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
Value compareRes =
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
- Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
// Perform substitution and return success.
- rewriter.replaceOp(op, {res});
+ rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, negRes, posRes);
return success();
}
};