[mlir][sparse] minor zero test refactoring in rewriting
authorAart Bik <ajcbik@google.com>
Tue, 6 Sep 2022 23:09:17 +0000 (16:09 -0700)
committerAart Bik <ajcbik@google.com>
Wed, 7 Sep 2022 17:07:11 +0000 (10:07 -0700)
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133382

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

index 9adface..a8c0883 100644 (file)
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
+// Helper method to match any typed zero.
+static bool isZeroValue(Value val) {
+  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
+}
+
 // Helper to detect a sparse tensor type operand.
 static bool isSparseTensor(OpOperand *op) {
   if (auto enc = getSparseTensorEncoding(op->get().getType())) {
@@ -47,8 +52,7 @@ static bool isAlloc(OpOperand *op, bool isZero) {
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
     Value copy = alloc.getCopy();
     if (isZero)
-      return copy && (matchPattern(copy, m_Zero()) ||
-                      matchPattern(copy, m_AnyZeroFloat()));
+      return copy && isZeroValue(copy);
     return !copy;
   }
   return false;
@@ -100,13 +104,10 @@ static bool isZeroYield(GenericOp op) {
   if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
-      return matchPattern(t->get(), m_Zero()) ||
-             matchPattern(t->get(), m_AnyZeroFloat());
+      return isZeroValue(t->get());
     }
-  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
-    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
   }
-  return false;
+  return isZeroValue(yieldOp.getOperand(0));
 }
 
 //===---------------------------------------------------------------------===//