//===----------------------------------------------------------------------===//
OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+ // divui (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result =
return a.udiv(b);
});
- // Fold out division by one. Assumes all tensors of all ones are splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return div0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+ // divsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
return a.sdiv_ov(b, overflowOrDiv0);
});
- // Fold out division by one. Assumes all tensors of all ones are splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+ // ceildivui (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
bool overflowOrDiv0 = false;
auto result =
constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
APInt one(a.getBitWidth(), 1, true);
return quotient.uadd_ov(one, overflowOrDiv0);
});
- // Fold out ceil division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+ // ceildivsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
return zero.ssub_ov(div, overflowOrDiv0);
});
- // Fold out ceil division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+ // floordivsi (x, 1) -> x.
+ if (matchPattern(getRhs(), m_One()))
+ return getLhs();
+
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result =
return zero.ssub_ov(ceil, overflowOrDiv0);
});
- // Fold out floor division by one. Assumes all tensors of all ones are
- // splats.
- if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
- if (rhs.getValue() == 1)
- return getLhs();
- } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
- if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
- return getLhs();
- }
-
return overflowOrDiv0 ? Attribute() : result;
}