[mlir][Linalg][NFC] Improve debugging during vectorization
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 10:07:50 +0000 (02:07 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 1 Dec 2022 10:49:52 +0000 (02:49 -0800)
Make more systematic use of `notifyMatchFailure`.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index a643713..82b191c 100644 (file)
@@ -1242,22 +1242,19 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
 
   // TODO: support mask.
   if (xferOp.getMask())
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
 
   // Transfer into `view`.
   Value viewOrAlloc = xferOp.getSource();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
-    return failure();
-
-  LDBG(viewOrAlloc);
+    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
 
   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
   if (!subViewOp)
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "no subview found");
   Value subView = subViewOp.getResult();
-  LDBG("with subView " << subView);
 
   // Find the copy into `subView` without interleaved uses.
   memref::CopyOp copyOp;
@@ -1266,7 +1263,6 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       assert(newCopyOp.getTarget().getType().isa<MemRefType>());
       if (newCopyOp.getTarget() != subView)
         continue;
-      LDBG("copy candidate " << *newCopyOp);
       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
         continue;
       copyOp = newCopyOp;
@@ -1274,8 +1270,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     }
   }
   if (!copyOp)
-    return failure();
-  LDBG("with copy " << *copyOp);
+    return rewriter.notifyMatchFailure(xferOp, "no copy found");
 
   // Find the fill into `viewOrAlloc` without interleaved uses before the
   // copy.
@@ -1285,7 +1280,6 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       assert(newFillOp.output().getType().isa<MemRefType>());
       if (newFillOp.output() != viewOrAlloc)
         continue;
-      LDBG("fill candidate " << *newFillOp);
       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
         continue;
       maybeFillOp = newFillOp;
@@ -1294,9 +1288,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   }
   // Ensure padding matches.
   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
-    return failure();
-  if (maybeFillOp)
-    LDBG("with maybeFillOp " << *maybeFillOp);
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "padding value does not match fill");
 
   // `in` is the subview that memref.copy reads. Replace it.
   Value in = copyOp.getSource();
@@ -1325,18 +1318,18 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
   // TODO: support mask.
   if (xferOp.getMask())
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
 
   // Transfer into `viewOrAlloc`.
   Value viewOrAlloc = xferOp.getSource();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
 
   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
   if (!subViewOp)
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "no subview found");
   Value subView = subViewOp.getResult();
 
   // Find the copy from `subView` without interleaved uses.
@@ -1352,7 +1345,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     }
   }
   if (!copyOp)
-    return failure();
+    return rewriter.notifyMatchFailure(xferOp, "no copy found");
 
   // `out` is the subview copied into that we replace.
   assert(copyOp.getTarget().getType().isa<MemRefType>());
@@ -1488,7 +1481,8 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
     if (!valid)
-      return failure();
+      return IRRewriter(builder).notifyMatchFailure(op,
+                                                    "unvectorizable 1-D conv");
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
@@ -1670,7 +1664,8 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> depthwiseConv() {
     if (!valid)
-      return failure();
+      return IRRewriter(builder).notifyMatchFailure(
+          op, "unvectorizable depthwise conv");
 
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
@@ -1753,10 +1748,8 @@ struct Conv1DGenerator
     }
 
     // Its possible we failed to create the Fma
-    for (auto v : resVals) {
-      if (!v)
-        return failure();
-    }
+    if (!llvm::all_of(resVals, [](Value v) { return v; }))
+      return IRRewriter(builder).notifyMatchFailure(op, "failed to create FMA");
 
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
@@ -1824,14 +1817,15 @@ struct Conv1DGenerator
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
-      return failure();
+      return IRRewriter(builder).notifyMatchFailure(
+          op, "failed to match conv::Nwc 3-par 2-red");
 
     // No transposition needed.
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
       return conv(Conv1DOpOrder::Nwc);
-    return failure();
+    return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Nwc layout");
   }
 
   /// Entry point that transposes into the common form:
@@ -1840,14 +1834,15 @@ struct Conv1DGenerator
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, f, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
-      return failure();
+      return IRRewriter(builder).notifyMatchFailure(
+          op, "failed to match conv::Ncw 3-par 2-red");
 
     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                 /*rhsIndex*/ {f, c, kw},
                 /*resIndex*/ {n, f, w}}))
       return conv(Conv1DOpOrder::Ncw);
 
-    return failure();
+    return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Ncw layout");
   }
 
   /// Entry point that transposes into the common form:
@@ -1856,14 +1851,17 @@ struct Conv1DGenerator
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
-      return failure();
+      return IRRewriter(builder).notifyMatchFailure(
+          op, "failed to match depthwise::Nwc conv 3-par 1-red");
 
     // No transposition needed.
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
       return depthwiseConv();
-    return failure();
+
+    return IRRewriter(builder).notifyMatchFailure(
+        op, "not a depthwise::Nwc layout");
   }
 
 private: