[mlir][Linalg] Introduce canonicalization to remove dead LinalgOps
authorNicolas Vasilache <ntv@google.com>
Thu, 6 Aug 2020 09:13:33 +0000 (05:13 -0400)
committerNicolas Vasilache <ntv@google.com>
Thu, 6 Aug 2020 10:08:46 +0000 (06:08 -0400)
When any of the memrefs in a structured linalg op has a zero dimension, it becomes dead.
This is consistent with the fact that linalg ops deduce their loop bounds from their operands.

Note however that this is not the case for the `tensor<0xelt_type>` which is a special convention
that must be lowered away into either `memref<elt_type>` or just `elt_type` before this
canonicalization can kick in.

Differential Revision: https://reviews.llvm.org/D85413

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index dad6f45..26406cc 100644 (file)
@@ -153,6 +153,7 @@ def CopyOp : LinalgStructured_Op<"copy", [
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -178,6 +179,7 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -358,6 +360,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 class SingleInputPoolingBase_Op<string mnemonic>
@@ -417,6 +420,7 @@ class SingleInputPoolingBase_Op<string mnemonic>
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -658,6 +662,7 @@ def GenericOp : GenericOpBase<"generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
@@ -795,6 +800,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -817,6 +823,7 @@ class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
   let printer = [{ return ::printNamedStructuredOp(p, *this); }];
   let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 // This file is auto-generated from a tc specification.
index 03bd71f..a8d98af 100644 (file)
@@ -1153,38 +1153,6 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-LogicalResult ConvOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult CopyOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult FillOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
-                                     SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1299,58 +1267,64 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
   return verifyGenericOp<NamedStructuredOpType>(op);
 }
 
+struct EraseDeadLinalgOp : public RewritePattern {
+  EraseDeadLinalgOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+      // Linalg "inputs" may be either tensor or memref type.
+      // tensor<0xelt_type> is a convention that may not always mean
+      // "0 iterations". Only erase in cases we see memref<...x0x...>.
+      auto mt = v.getType().dyn_cast<MemRefType>();
+      if (!mt)
+        continue;
+      if (llvm::is_contained(mt.getShape(), 0)) {
+        rewriter.eraseOp(linalgOp);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
+  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
+                                        MLIRContext *context) {                \
+    results.insert<EraseDeadLinalgOp>();                                       \
+  }                                                                            \
+                                                                               \
+  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
+                          SmallVectorImpl<OpFoldResult> &) {                   \
+    return foldMemRefCast(*this);                                              \
+  }
+
+CANONICALIZERS_AND_FOLDERS(ConvOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMaxOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMinOp);
+CANONICALIZERS_AND_FOLDERS(PoolingSumOp);
+CANONICALIZERS_AND_FOLDERS(CopyOp);
+CANONICALIZERS_AND_FOLDERS(FillOp);
+CANONICALIZERS_AND_FOLDERS(GenericOp);
+CANONICALIZERS_AND_FOLDERS(IndexedGenericOp);
+
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
 
 // TODO: Determine whether we can generate the folders and verifiers.
-LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
-                                  SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult DotOp::fold(ArrayRef<Attribute>,
-                          SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
-                            SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
+CANONICALIZERS_AND_FOLDERS(BatchMatmulOp);
+CANONICALIZERS_AND_FOLDERS(DotOp);
+CANONICALIZERS_AND_FOLDERS(MatmulOp);
+CANONICALIZERS_AND_FOLDERS(MatvecOp);
+CANONICALIZERS_AND_FOLDERS(ConvWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCWOp);
+CANONICALIZERS_AND_FOLDERS(ConvHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvDHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp);
index 70b00cf..f878672 100644 (file)
@@ -732,19 +732,16 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
+  // Size 0 corner case is useful for canonicalizations.
+  if (llvm::is_contained(sizes, 0))
+    return getAffineConstantExpr(0, context);
+
+  auto maps = AffineMap::inferFromExprList(exprs);
+  assert(!maps.empty() && "Expected one non-empty map");
+  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
+
   AffineExpr expr;
   bool dynamicPoisonBit = false;
-  unsigned numDims = 0;
-  unsigned nSymbols = 0;
-  // Compute the number of symbols and dimensions of the passed exprs.
-  for (AffineExpr expr : exprs) {
-    expr.walk([&numDims, &nSymbols](AffineExpr d) {
-      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
-        numDims = std::max(numDims, dim.getPosition() + 1);
-      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
-        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
-    });
-  }
   int64_t runningSize = 1;
   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
     int64_t size = std::get<1>(en);
index 9cb7df0..005bd1c 100644 (file)
@@ -172,3 +172,34 @@ func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
 // CHECK-LABEL: @no_fold_memref_reshape
 //       CHECK:   linalg.reshape
 //       CHECK:   linalg.reshape
+
+// -----
+
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // memref<0x32> is expected to be dce'ed
+  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
+
+  // tensor<0xf32> cannot be dce'ed
+  %1 = linalg.generic #trait %arg1 {
+  ^bb(%0: f32) :
+    linalg.yield %0 : f32
+  } : tensor<0xf32> -> tensor<0xf32>
+
+  return %1: tensor<0xf32>
+}
+// CHECK-LABEL: @dce_zero_memref
+//   CHECK-NOT:   linalg.copy
+//  CHECK-NEXT:   linalg.generic
+