[mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed to each op
authorRiver Riddle <riddleriver@gmail.com>
Mon, 14 Jun 2021 18:09:43 +0000 (11:09 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 14 Jun 2021 18:20:15 +0000 (11:20 -0700)
Interface patterns are unique in that they get added to every operation that also implements that interface, given that they aren't tied to individual operations. When the same interface pattern gets added to multiple operations (such as the current behavior with Linalg), an reference to each of these patterns is added to every op (meaning that an operation will now have N references to effectively the same pattern). This revision fixes this problematic behavior in Linalg, and can bring upwards of a 25% reduction in compile time in Linalg based workloads.

Differential Revision: https://reviews.llvm.org/D104160

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index bc9acee..07a378b 100644 (file)
@@ -35,6 +35,7 @@ def Linalg_Dialect : Dialect {
   let dependentDialects = [
     "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
   ];
+  let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
     /// Attribute name used to to memoize indexing maps for named ops.
index 2b70572..54dda4c 100644 (file)
@@ -178,7 +178,6 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
   }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
   let skipDefaultBuilders = 1;
 }
 
@@ -230,7 +229,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -427,7 +425,6 @@ def ConvOp : PoolingBase_Op<"conv", []> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 // Only support buffer semantics.
@@ -490,7 +487,6 @@ class SingleInputPoolingBase_Op<string mnemonic>
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -673,7 +669,6 @@ def GenericOp : GenericOpBase<"generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
index 6eef6b0..8114a22 100644 (file)
@@ -2787,11 +2787,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
 
-namespace {
-struct EraseDeadLinalgOp;
-struct FoldTensorCastOp;
-} // namespace
-
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
 
@@ -3374,25 +3369,29 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
 };
 } // namespace
 
-#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
-  void XXX::getCanonicalizationPatterns(RewritePatternSet &results,            \
-                                        MLIRContext *context) {                \
-    results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,        \
-                RemoveIdentityLinalgOps>(context);                             \
-  }                                                                            \
-                                                                               \
+#define LINALGOP_FOLDERS(XXX)                                                  \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                           SmallVectorImpl<OpFoldResult> &) {                   \
     return foldMemRefCast(*this);                                              \
   }
 
-CANONICALIZERS_AND_FOLDERS(ConvOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
-CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
-CANONICALIZERS_AND_FOLDERS(CopyOp)
-CANONICALIZERS_AND_FOLDERS(FillOp)
-CANONICALIZERS_AND_FOLDERS(GenericOp)
+LINALGOP_FOLDERS(ConvOp)
+LINALGOP_FOLDERS(PoolingMaxOp)
+LINALGOP_FOLDERS(PoolingMinOp)
+LINALGOP_FOLDERS(PoolingSumOp)
+LINALGOP_FOLDERS(CopyOp)
+LINALGOP_FOLDERS(FillOp)
+LINALGOP_FOLDERS(GenericOp)
 
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
+
+//===----------------------------------------------------------------------===//
+// LinalgDialect
+//===----------------------------------------------------------------------===//
+
+void LinalgDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
+              RemoveIdentityLinalgOps>(getContext());
+}
index f65a0fa..5e76361 100644 (file)
@@ -1405,6 +1405,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
+      patterns);
 }
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
index b46ac20..ab80be5 100644 (file)
@@ -414,6 +414,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
+  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
   CanonicalizationPatternList<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
index 9e86bf3..faa2835 100644 (file)
@@ -1959,7 +1959,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
       }];
       let hasFolder = 1;
-      let hasCanonicalizer = 1;
 
       let extraClassDeclaration = structuredOpsBaseDecls # [{{
         // Auto-generated.
@@ -2094,13 +2093,7 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
 
 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
                                              StringRef cppOpName) {
-  const char *canonicalizersAndFoldersFmt = R"FMT(
-    void {0}::getCanonicalizationPatterns(
-        RewritePatternSet &results,
-        MLIRContext *context) {{
-      results.add<EraseDeadLinalgOp>(context);
-      results.add<FoldTensorCastOp>(context);
-    }
+  const char *foldersFmt = R"FMT(
     LogicalResult {0}::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {{
       return foldMemRefCast(*this);
@@ -2112,7 +2105,7 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
       getGenericEffectsImpl(effects,
         getOperation()->getResults(), inputBuffers, outputBuffers);
     })FMT";
-  os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
+  os << llvm::formatv(foldersFmt, cppOpName);
 }
 
 // Prints methods for querying whether the current named op has attributes that
index 312724c..c907cae 100644 (file)
@@ -503,7 +503,6 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
       return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
     }];
     let hasFolder = 1;
-    let hasCanonicalizer = 1;
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
@@ -535,16 +534,10 @@ ArrayAttr {0}::iterator_types() {
 }
 )FMT";
 
-// Implementations of getCanonicalizationPatterns, fold and getEffects.
+// Implementations of fold and getEffects.
 // Parameters:
 // {0}: Class name
-const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
-void {0}::getCanonicalizationPatterns(
-    RewritePatternSet &results,
-    MLIRContext *context) {{
-  results.add<EraseDeadLinalgOp>(context);
-  results.add<FoldTensorCastOp>(context);
-}
+const char structuredOpFoldersFormat[] = R"FMT(
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{
   return foldMemRefCast(*this);
@@ -880,7 +873,7 @@ void {0}::regionBuilder(
   }
 
   // Canonicalizers and folders.
-  os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
+  os << llvm::formatv(structuredOpFoldersFormat, className);
 
   return success();
 }