// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//
+// Helper method to match any typed zero.
+static bool isZeroValue(Value val) {
+ return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
+}
+
// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
if (auto enc = getSparseTensorEncoding(op->get().getType())) {
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
Value copy = alloc.getCopy();
if (isZero)
- return copy && (matchPattern(copy, m_Zero()) ||
- matchPattern(copy, m_AnyZeroFloat()));
+ return copy && isZeroValue(copy);
return !copy;
}
return false;
if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
if (arg.getOwner()->getParentOp() == op) {
OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
- return matchPattern(t->get(), m_Zero()) ||
- matchPattern(t->get(), m_AnyZeroFloat());
+ return isZeroValue(t->get());
}
- } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
- return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
}
- return false;
+ return isZeroValue(yieldOp.getOperand(0));
}
//===---------------------------------------------------------------------===//