From 54fafd17a728f3dd33b3cf999b6dfd3cd1d49f12 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 6 Aug 2020 05:13:33 -0400 Subject: [PATCH] [mlir][Linalg] Introduce canonicalization to remove dead LinalgOps When any of the memrefs in a structured linalg op has a zero dimension, it becomes dead. This is consistent with the fact that linalg ops deduce their loop bounds from their operands. Note however that this is not the case for the `tensor<0xelt_type>` which is a special convention that must be lowered away into either `memref` or just `elt_type` before this canonicalization can kick in. Differential Revision: https://reviews.llvm.org/D85413 --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 7 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 142 +++++++++------------ mlir/lib/IR/StandardTypes.cpp | 19 ++- mlir/test/Dialect/Linalg/canonicalize.mlir | 31 +++++ 4 files changed, 104 insertions(+), 95 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index dad6f45..26406cc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -153,6 +153,7 @@ def CopyOp : LinalgStructured_Op<"copy", [ let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { @@ -178,6 +179,7 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } /// A base class for pooling operation such as conv. The arguments must contain @@ -358,6 +360,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } class SingleInputPoolingBase_Op @@ -417,6 +420,7 @@ class SingleInputPoolingBase_Op let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> { @@ -658,6 +662,7 @@ def GenericOp : GenericOpBase<"generic"> { let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } /// GenericOp with Indexing (i.e. multi-for style in which the region is passed @@ -795,6 +800,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -817,6 +823,7 @@ class LinalgNamedStructured_Op props> let printer = [{ return ::printNamedStructuredOp(p, *this); }]; let verifier = [{ return ::verifyNamedStructuredOp(*this); }]; let hasFolder = 1; + let hasCanonicalizer = 1; } // This file is auto-generated from a tc specification. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 03bd71f..a8d98af 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1153,38 +1153,6 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) { // TODO: Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of OpInterfaces // where a Linalg "named" op **isa** LinalgOp. -LogicalResult ConvOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult PoolingMaxOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult PoolingMinOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult PoolingSumOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult CopyOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult FillOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult GenericOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult IndexedGenericOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} OpFoldResult ReshapeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); @@ -1299,58 +1267,64 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { return verifyGenericOp(op); } +struct EraseDeadLinalgOp : public RewritePattern { + EraseDeadLinalgOp(PatternBenefit benefit = 1) + : RewritePattern(benefit, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + for (Value v : linalgOp.getInputsAndOutputBuffers()) { + // Linalg "inputs" may be either tensor or memref type. + // tensor<0xelt_type> is a convention that may not always mean + // "0 iterations". Only erase in cases we see memref<...x0x...>. + auto mt = v.getType().dyn_cast(); + if (!mt) + continue; + if (llvm::is_contained(mt.getShape(), 0)) { + rewriter.eraseOp(linalgOp); + return success(); + } + } + return failure(); + } +}; + +#define CANONICALIZERS_AND_FOLDERS(XXX) \ + void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \ + MLIRContext *context) { \ + results.insert(); \ + } \ + \ + LogicalResult XXX::fold(ArrayRef, \ + SmallVectorImpl &) { \ + return foldMemRefCast(*this); \ + } + +CANONICALIZERS_AND_FOLDERS(ConvOp); +CANONICALIZERS_AND_FOLDERS(PoolingMaxOp); +CANONICALIZERS_AND_FOLDERS(PoolingMinOp); +CANONICALIZERS_AND_FOLDERS(PoolingSumOp); +CANONICALIZERS_AND_FOLDERS(CopyOp); +CANONICALIZERS_AND_FOLDERS(FillOp); +CANONICALIZERS_AND_FOLDERS(GenericOp); +CANONICALIZERS_AND_FOLDERS(IndexedGenericOp); + #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" // TODO: Determine whether we can generate the folders and verifiers. -LogicalResult BatchMatmulOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult DotOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult MatmulOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult MatvecOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNWCOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNCWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvHWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNHWCOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNCHWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvDHWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNDHWCOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -LogicalResult ConvNCDHWOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} +CANONICALIZERS_AND_FOLDERS(BatchMatmulOp); +CANONICALIZERS_AND_FOLDERS(DotOp); +CANONICALIZERS_AND_FOLDERS(MatmulOp); +CANONICALIZERS_AND_FOLDERS(MatvecOp); +CANONICALIZERS_AND_FOLDERS(ConvWOp); +CANONICALIZERS_AND_FOLDERS(ConvNWCOp); +CANONICALIZERS_AND_FOLDERS(ConvNCWOp); +CANONICALIZERS_AND_FOLDERS(ConvHWOp); +CANONICALIZERS_AND_FOLDERS(ConvNHWCOp); +CANONICALIZERS_AND_FOLDERS(ConvNCHWOp); +CANONICALIZERS_AND_FOLDERS(ConvDHWOp); +CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp); +CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 70b00cf..f878672 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -732,19 +732,16 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, ArrayRef exprs, MLIRContext *context) { + // Size 0 corner case is useful for canonicalizations. + if (llvm::is_contained(sizes, 0)) + return getAffineConstantExpr(0, context); + + auto maps = AffineMap::inferFromExprList(exprs); + assert(!maps.empty() && "Expected one non-empty map"); + unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); + AffineExpr expr; bool dynamicPoisonBit = false; - unsigned numDims = 0; - unsigned nSymbols = 0; - // Compute the number of symbols and dimensions of the passed exprs. - for (AffineExpr expr : exprs) { - expr.walk([&numDims, &nSymbols](AffineExpr d) { - if (AffineDimExpr dim = d.dyn_cast()) - numDims = std::max(numDims, dim.getPosition() + 1); - else if (AffineSymbolExpr symbol = d.dyn_cast()) - nSymbols = std::max(nSymbols, symbol.getPosition() + 1); - }); - } int64_t runningSize = 1; for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { int64_t size = std::get<1>(en); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cb7df0..005bd1c 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -172,3 +172,34 @@ func @no_fold_memref_reshape(%arg0 : memref) -> memref // CHECK-LABEL: @no_fold_memref_reshape // CHECK: linalg.reshape // CHECK: linalg.reshape + +// ----- + +#accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)> +] + +#trait = { + args_in = 1, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel"] +} + +func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + // memref<0x32> is expected to be dce'ed + linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32> + + // tensor<0xf32> cannot be dce'ed + %1 = linalg.generic #trait %arg1 { + ^bb(%0: f32) : + linalg.yield %0 : f32 + } : tensor<0xf32> -> tensor<0xf32> + + return %1: tensor<0xf32> +} +// CHECK-LABEL: @dce_zero_memref +// CHECK-NOT: linalg.copy +// CHECK-NEXT: linalg.generic + -- 2.7.4