[mlir][Linalg] Fix and properly test CodegenStrategy API
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 2 Feb 2021 12:16:51 +0000 (12:16 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 2 Feb 2021 13:01:12 +0000 (13:01 +0000)
Fix a bug that was introduced where calling the codegen strategy with actual concrete C++ Op types did not trigger the expected behavior.
Also introduce a test for the behavior that was missing.

Differential Revision: https://reviews.llvm.org/D95863

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/test/Dialect/Linalg/codegen-strategy.mlir
mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp

index d73a0f6..2e0796d 100644 (file)
@@ -38,9 +38,8 @@ template <template <typename> class PatternType, typename ConcreteOpType,
 void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
                     MLIRContext *context, StringRef opName,
                     linalg::LinalgTransformationFilter m) {
-  assert(opName.empty() ||
-         opName == ConcreteOpType::getOperationName() &&
-             "explicit name must match ConcreteOpType::getOperationName");
+  assert(opName == ConcreteOpType::getOperationName() &&
+         "explicit name must match ConcreteOpType::getOperationName");
   patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
 }
 
@@ -61,7 +60,8 @@ template <typename LinalgOpType>
 struct Tile : public Transformation {
   explicit Tile(linalg::LinalgTilingOptions options,
                 linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(""), options(options) {}
+      : Transformation(f), opName(LinalgOpType::getOperationName()),
+        options(options) {}
 
   Tile(StringRef name, linalg::LinalgTilingOptions options,
        linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -88,7 +88,8 @@ struct Promote : public Transformation {
   explicit Promote(
       linalg::LinalgPromotionOptions options,
       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(""), options(options) {}
+      : Transformation(f), opName(LinalgOpType::getOperationName()),
+        options(options) {}
 
   Promote(StringRef name, linalg::LinalgPromotionOptions options,
           linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -116,7 +117,8 @@ struct Vectorize : public Transformation {
   explicit Vectorize(
       linalg::LinalgVectorizationOptions options,
       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(""), options(options) {}
+      : Transformation(f), opName(LinalgOpType::getOperationName()),
+        options(options) {}
 
   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
             linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
index 60cf4fe..5affbdc 100644 (file)
@@ -1,3 +1,6 @@
+// Test that both anchor-op name and MatmulOp-based codegen strategy produce the same result.
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 
index 34ee46a..9c0aadd 100644 (file)
@@ -47,6 +47,12 @@ struct TestLinalgCodegenStrategy
 
   void runOnFunction() override;
 
+  template <typename OpType>
+  void runStrategy(LinalgTilingOptions tilingOptions,
+                   LinalgTilingOptions registerTilingOptions,
+                   vector::VectorContractLowering vectorContractLowering,
+                   vector::VectorTransferSplit vectorTransferSplit);
+
   ListOption<int64_t> tileSizes{*this, "tile-sizes",
                                 llvm::cl::MiscFlags::CommaSeparated,
                                 llvm::cl::desc("Specifies the tile sizes.")};
@@ -107,12 +113,67 @@ struct TestLinalgCodegenStrategy
 };
 } // end anonymous namespace
 
+template <>
+void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
+    LinalgTilingOptions tilingOptions,
+    LinalgTilingOptions registerTilingOptions,
+    vector::VectorContractLowering vectorContractLowering,
+    vector::VectorTransferSplit vectorTransferSplit) {
+  assert(!anchorOpName.empty());
+  CodegenStrategy strategy;
+  strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
+      .promoteIf<LinalgOp>(promote, anchorOpName,
+                           LinalgPromotionOptions()
+                               .setAlignment(16)
+                               .setUseFullTileBuffersByDefault(promoteFullTile))
+      .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
+                        registerTilingOptions)
+      .promoteIf<LinalgOp>(
+          registerPromote, anchorOpName,
+          LinalgPromotionOptions()
+              .setAlignment(16)
+              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+      .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
+      .setVectorTransformsOptions(
+          vector::VectorTransformsOptions()
+              .setVectorTransformsOptions(vectorContractLowering)
+              .setVectorTransferSplit(vectorTransferSplit))
+      .setVectorTransferToSCFOptions(
+          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+  strategy.transform(getFunction());
+}
+
+template <typename OpType>
+void TestLinalgCodegenStrategy::runStrategy(
+    LinalgTilingOptions tilingOptions,
+    LinalgTilingOptions registerTilingOptions,
+    vector::VectorContractLowering vectorContractLowering,
+    vector::VectorTransferSplit vectorTransferSplit) {
+  CodegenStrategy strategy;
+  strategy.tileIf<OpType>(!tileSizes.empty(), tilingOptions)
+      .template promoteIf<OpType>(
+          promote, LinalgPromotionOptions()
+                       .setAlignment(16)
+                       .setUseFullTileBuffersByDefault(promoteFullTile))
+      .template tileIf<OpType>(!registerTileSizes.empty(),
+                               registerTilingOptions)
+      .template promoteIf<OpType>(
+          registerPromote,
+          LinalgPromotionOptions()
+              .setAlignment(16)
+              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+      .template vectorizeIf<OpType>(vectorize)
+      .setVectorTransformsOptions(
+          vector::VectorTransformsOptions()
+              .setVectorTransformsOptions(vectorContractLowering)
+              .setVectorTransferSplit(vectorTransferSplit))
+      .setVectorTransferToSCFOptions(
+          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+  strategy.transform(getFunction());
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgCodegenStrategy::runOnFunction() {
-  linalg::LinalgTransformationFilter::FilterFunction filterOpName =
-      [&](Operation *op) -> LogicalResult {
-    return success(op->getName().getStringRef() == anchorOpName);
-  };
   LinalgTilingOptions tilingOptions;
   if (!tileSizes.empty())
     tilingOptions = tilingOptions.setTileSizes(tileSizes);
@@ -137,28 +198,14 @@ void TestLinalgCodegenStrategy::runOnFunction() {
           .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
           .Default(vector::VectorTransferSplit::None);
 
-  CodegenStrategy strategy;
-  strategy.tileIf<LinalgOp>(!tileSizes.empty(), anchorOpName, tilingOptions)
-      .promoteIf<LinalgOp>(promote, anchorOpName,
-                           LinalgPromotionOptions()
-                               .setAlignment(16)
-                               .setUseFullTileBuffersByDefault(promoteFullTile),
-                           filterOpName)
-      .tileIf<LinalgOp>(!registerTileSizes.empty(), anchorOpName,
-                        registerTilingOptions)
-      .promoteIf<LinalgOp>(
-          registerPromote, anchorOpName,
-          LinalgPromotionOptions()
-              .setAlignment(16)
-              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-      .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
-      .setVectorTransformsOptions(
-          vector::VectorTransformsOptions()
-              .setVectorTransformsOptions(vectorContractLowering)
-              .setVectorTransferSplit(vectorTransferSplit))
-      .setVectorTransferToSCFOptions(
-          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-  strategy.transform(getFunction());
+  // If no anchorOpNameis specified, just test that strategy applies properly to
+  // linalg::MatmulOp.
+  if (anchorOpName.empty())
+    runStrategy<linalg::MatmulOp>(tilingOptions, registerTilingOptions,
+                                  vectorContractLowering, vectorTransferSplit);
+  else
+    runStrategy<LinalgOp>(tilingOptions, registerTilingOptions,
+                          vectorContractLowering, vectorTransferSplit);
 }
 
 namespace mlir {