Revert "[MLIR][Arith] Remove unused assertions"
authorBen Shi <powerman1st@163.com>
Thu, 29 Dec 2022 00:54:01 +0000 (08:54 +0800)
committerBen Shi <powerman1st@163.com>
Thu, 29 Dec 2022 00:54:01 +0000 (08:54 +0800)
This reverts commit 50e6c306b1cb03fe398aebc41d1bef5b6c9d9bb0.

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

index 852fd0d..f6446ea 100644 (file)
@@ -887,6 +887,8 @@ OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "maxf takes two operands");
+
   // maxf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -905,6 +907,8 @@ OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
   // maxsi(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -931,6 +935,8 @@ OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
   // maxui(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -955,6 +961,8 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "minf takes two operands");
+
   // minf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -973,6 +981,8 @@ OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
   // minsi(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -999,6 +1009,8 @@ OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
   // minui(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
@@ -1233,6 +1245,8 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
   // trunci(zexti(a)) -> a
   // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
@@ -1279,6 +1293,8 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
 /// can be represented without precision loss or rounding.
 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
   auto constOperand = operands.front();
   if (!constOperand || !constOperand.isa<FloatAttr>())
     return {};
@@ -1519,6 +1535,8 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "bitcast op expects 1 operand");
+
   auto resType = getType();
   auto operand = operands[0];
   if (!operand)
@@ -1635,6 +1653,8 @@ static std::optional<int64_t> getIntegerWidth(Type t) {
 }
 
 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpi takes two operands");
+
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
@@ -1754,6 +1774,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
 }
 
 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "cmpf takes two operands");
+
   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();