[mlir][Linalg] Make a Linalg CodegenStrategy available.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 14 Oct 2020 09:02:47 +0000 (09:02 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 14 Oct 2020 11:11:26 +0000 (11:11 +0000)
This revision adds a programmable codegen strategy from linalg based on staged rewrite patterns. Testing is exercised on a simple linalg.matmul op.

Differential Revision: https://reviews.llvm.org/D89374

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h [new file with mode: 0644]
mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/codegen-strategy.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
new file mode 100644 (file)
index 0000000..25a98e3
--- /dev/null
@@ -0,0 +1,165 @@
+//===- CodegenStrategy.h - Linalg programmable codegen strategy -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
+#define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
+
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+namespace mlir {
+
+class FuncOp;
+
+namespace linalg {
+
+/// Abstract Transformation class applied in a sequence that also handles state
+/// through markers.
+struct Transformation {
+  virtual ~Transformation() = default;
+  virtual OwningRewritePatternList
+  buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0;
+  linalg::LinalgMarker marker;
+};
+
+/// Promotion transformation enqueues a particular stage-1 pattern for
+/// `Tile<LinalgOpType>`with the appropriate `options`.
+template <typename LinalgOpType>
+struct Tile : public Transformation {
+  explicit Tile(linalg::LinalgTilingOptions options) : options(options) {}
+
+  OwningRewritePatternList
+  buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+    OwningRewritePatternList tilingPatterns;
+    tilingPatterns.insert<linalg::LinalgTilingPattern<LinalgOpType>>(
+        context, options, m);
+    return tilingPatterns;
+  }
+
+private:
+  linalg::LinalgTilingOptions options;
+};
+
+/// Promotion transformation enqueues a particular stage-1 pattern for
+/// `Promote<LinalgOpType>`with the appropriate `options`.
+template <typename LinalgOpType>
+struct Promote : public Transformation {
+  explicit Promote(linalg::LinalgPromotionOptions options) : options(options) {}
+
+  OwningRewritePatternList
+  buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+    OwningRewritePatternList promotionPatterns;
+    promotionPatterns.insert<linalg::LinalgPromotionPattern<LinalgOpType>>(
+        context, options, m);
+    return promotionPatterns;
+  }
+
+private:
+  linalg::LinalgPromotionOptions options;
+};
+
+/// Vectorization transformation enqueues a particular stage-1 pattern for
+/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
+/// transfer rewrite forwarding patterns.
+template <typename LinalgOpType>
+struct Vectorize : public Transformation {
+  OwningRewritePatternList
+  buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) override {
+    OwningRewritePatternList vectorizationPatterns;
+    // FillOp may interfere with forwarding patterns atm, so we bump up the
+    // priority of LinalgCopyVTRForwardingPattern /
+    // LinalgCopyVTWForwardingPattern.
+    vectorizationPatterns
+        .insert<linalg::LinalgVectorizationPattern<LinalgOpType>>(context, m);
+    vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
+                                 linalg::LinalgCopyVTWForwardingPattern>(
+        context, /*benefit=*/2);
+    return vectorizationPatterns;
+  }
+};
+
+/// Codegen strategy controls how a Linalg op is progressively lowered.
+/// The application uses a 3-level staged patterns strategy which allows
+/// ordering transformations by using the Linalg `applyStagedPatterns` function,
+/// where:
+///   1. The first stage consists of the successive `tile`, `promote` and
+///   `vectorize` patterns, applied sequentially.
+///   2. The second stage consists of common local canonicalization patterns
+///   that are applied eagerly after each stage-1 pattern.
+///   3. the third stage consists of more global transformation, also applied
+///   eagerly, after all stage-2 patterns. Such more global transformations
+struct CodegenStrategy {
+  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
+  /// `options`.
+  template <typename LinalgOpType>
+  CodegenStrategy &tile(linalg::LinalgTilingOptions options) {
+    transformationSequence.emplace_back(new Tile<LinalgOpType>(options));
+    return *this;
+  }
+  /// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
+  /// with tiling `options`.
+  template <typename LinalgOpType>
+  CodegenStrategy &tileIf(bool b, linalg::LinalgTilingOptions options) {
+    return b ? tile<LinalgOpType>(options) : *this;
+  }
+  /// Append a pattern to add a level of promotion for `LinalgOpType` with
+  /// promotion `options`.
+  template <typename LinalgOpType>
+  CodegenStrategy &promote(linalg::LinalgPromotionOptions options) {
+    transformationSequence.emplace_back(new Promote<LinalgOpType>(options));
+    return *this;
+  }
+  /// Conditionally append a pattern to add a level of promotion for
+  /// `LinalgOpType` with promotion `options`.
+  template <typename LinalgOpType>
+  CodegenStrategy &promoteIf(bool b, linalg::LinalgPromotionOptions options) {
+    return b ? promote<LinalgOpType>(options) : *this;
+    return *this;
+  }
+  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
+  template <typename LinalgOpType>
+  CodegenStrategy &vectorize() {
+    transformationSequence.emplace_back(new Vectorize<LinalgOpType>());
+    return *this;
+  }
+  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
+  /// operation.
+  template <typename LinalgOpType>
+  CodegenStrategy &vectorizeIf(bool b) {
+    return b ? vectorize<LinalgOpType>() : *this;
+    return *this;
+  }
+  /// Configure the post staged-patterns late vector transformations.
+  CodegenStrategy &
+  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+    vectorTransformsOptions = options;
+    return *this;
+  }
+  /// Configure the post staged-patterns late vector.transfer to scf conversion.
+  CodegenStrategy &
+  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
+    vectorToSCFOptions = options;
+    return *this;
+  }
+
+  /// Apply the transformation patterns in sequence with cleanup transformations
+  /// interleaved.
+  void transform(FuncOp func) const;
+
+private:
+  LogicalResult postPatternTransforms(Operation *func) const;
+
+  vector::VectorTransformsOptions vectorTransformsOptions;
+  VectorTransferToSCFOptions vectorToSCFOptions;
+  SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
+};
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_
index b272455..d6bc8da 100644 (file)
@@ -10,7 +10,6 @@ add_mlir_conversion_library(MLIRVectorToSCF
   LINK_LIBS PUBLIC
   MLIREDSC
   MLIRAffineEDSC
-  MLIRLinalgUtils
   MLIRLLVMIR
   MLIRTransforms
   )
index c0d283d..8a76766 100644 (file)
@@ -16,7 +16,6 @@
 
 #include "../PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@@ -41,6 +40,28 @@ using namespace mlir::edsc::intrinsics;
 using vector::TransferReadOp;
 using vector::TransferWriteOp;
 
+// Return a list of Values that correspond to multiple AffineApplyOp, one for
+// each result of `map`. Each `expr` in `map` is canonicalized and folded
+// greedily according to its operands.
+// TODO: factor out in a common location that both linalg and vector can use.
+static SmallVector<Value, 4>
+applyMapToValues(OpBuilder &b, Location loc, AffineMap map, ValueRange values) {
+  SmallVector<Value, 4> res;
+  res.reserve(map.getNumResults());
+  unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
+  // For each `expr` in `map`, applies the `expr` to the values extracted from
+  // ranges. If the resulting application can be folded into a Value, the
+  // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+  for (auto expr : map.getResults()) {
+    AffineMap map = AffineMap::get(numDims, numSym, expr);
+    SmallVector<Value, 4> operands(values.begin(), values.end());
+    fullyComposeAffineMapAndOperands(&map, &operands);
+    canonicalizeMapAndOperands(&map, &operands);
+    res.push_back(b.createOrFold<AffineApplyOp>(loc, map, operands));
+  }
+  return res;
+}
+
 namespace {
 /// Helper class captures the common information needed to lower N>1-D vector
 /// transfer operations (read and write).
@@ -193,7 +214,8 @@ static Value onTheFlyFoldSLT(Value v, Value ub) {
 
 ///   1. Compute the indexings `majorIvs + majorOffsets` and save them in
 ///      `majorIvsPlusOffsets`.
-///   2. Return a value of i1 that determines whether the first `majorIvs.rank()`
+///   2. Return a value of i1 that determines whether the first
+///   `majorIvs.rank()`
 ///      dimensions `majorIvs + majorOffsets` are all within `memrefBounds`.
 static Value
 emitInBoundsCondition(PatternRewriter &rewriter,
@@ -205,8 +227,8 @@ emitInBoundsCondition(PatternRewriter &rewriter,
   majorIvsPlusOffsets.reserve(majorIvs.size());
   unsigned idx = 0;
   SmallVector<Value, 4> bounds =
-      linalg::applyMapToValues(rewriter, xferOp.getLoc(),
-                               xferOp.permutation_map(), memrefBounds.getUbs());
+      applyMapToValues(rewriter, xferOp.getLoc(), xferOp.permutation_map(),
+                       memrefBounds.getUbs());
   for (auto it : llvm::zip(majorIvs, majorOffsets, bounds)) {
     Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
     using namespace mlir::edsc::op;
@@ -450,8 +472,8 @@ static void emitWithBoundsChecks(
     function_ref<void(ArrayRef<Value>)> outOfBoundsFun = nullptr) {
   // Permute the incoming indices according to the permutation map.
   SmallVector<Value, 4> indices =
-      linalg::applyMapToValues(rewriter, transfer.getLoc(),
-                               transfer.permutation_map(), transfer.indices());
+      applyMapToValues(rewriter, transfer.getLoc(), transfer.permutation_map(),
+                       transfer.indices());
 
   // Generate a bounds check if necessary.
   SmallVector<Value, 4> majorIvsPlusOffsets;
index 2b13717..3a44da7 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  CodegenStrategy.cpp
   DropUnitDims.cpp
   Fusion.cpp
   FusionOnTensors.cpp
@@ -31,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRPass
   MLIRStandard
   MLIRStandardToLLVM
+  MLIRTransforms
   MLIRTransformUtils
   MLIRVector
   )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
new file mode 100644 (file)
index 0000000..d27985e
--- /dev/null
@@ -0,0 +1,95 @@
+//===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic and helpers to expose Linalg transforms as
+// composable rewrite patterns through a programmable CodegenStrategy object.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
+
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-codegen-strategy"
+
+void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
+  MLIRContext *context = func.getContext();
+  // Emplace patterns one at a time while also maintaining a simple chained
+  // state transition.
+  unsigned stepCount = 0;
+  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+  auto zeroState = Identifier::get(std::to_string(stepCount), context);
+  auto currentState = zeroState;
+  for (const std::unique_ptr<Transformation> &t : transformationSequence) {
+    auto nextState = Identifier::get(std::to_string(++stepCount), context);
+    auto marker = (currentState == zeroState)
+                      ? linalg::LinalgMarker({}, nextState)
+                      : linalg::LinalgMarker(currentState, nextState);
+    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
+    currentState = nextState;
+  }
+
+  OwningRewritePatternList stage2Patterns =
+      linalg::getLinalgTilingCanonicalizationPatterns(context);
+  stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+
+  auto stage3Transforms = [](Operation *op) {
+    // Some of these may be too aggressive as a stage 3 that is applied on each
+    // stage 1 application and may have to be split out to post staged patterns
+    // application (in which case they could just be passes, TBD).
+    PassManager pm(op->getContext());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    if (failed(pm.run(op->getParentOfType<ModuleOp>())))
+      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
+    promoteSingleIterationLoops(cast<FuncOp>(op));
+    hoistViewAllocOps(cast<FuncOp>(op));
+    hoistRedundantVectorTransfers(cast<FuncOp>(op));
+    return success();
+  };
+  linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
+                              stage3Transforms);
+
+  //===--------------------------------------------------------------------===//
+  // Post staged patterns transforms
+  //===--------------------------------------------------------------------===//
+
+  ModuleOp module = func.getParentOfType<ModuleOp>();
+
+  // Programmatic splitting of slow/fast path vector transfers.
+  OwningRewritePatternList patterns;
+  patterns.insert<vector::VectorTransferFullPartialRewriter>(
+      context, vectorTransformsOptions);
+  applyPatternsAndFoldGreedily(module, patterns);
+
+  // Programmatic controlled lowering of vector.contract only.
+  OwningRewritePatternList vectorContractLoweringPatterns;
+  vectorContractLoweringPatterns
+      .insert<ContractionOpToOuterProductOpLowering,
+              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+          vectorTransformsOptions, context);
+  applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
+
+  // Programmatic controlled lowering of vector.transfer only.
+  OwningRewritePatternList vectorToLoopsPatterns;
+  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+                                        vectorToSCFOptions);
+  applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
+
+  // Ensure we drop the marker in the end.
+  module.walk([](LinalgOp op) {
+    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+  });
+}
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
new file mode 100644 (file)
index 0000000..49ef2e6
--- /dev/null
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+
+// CHECK-LABEL: func @matmul(
+// OUTER-LABEL: func @matmul(
+func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+  linalg.matmul
+   ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+   outs(%C: memref<1584x1584xf32>)
+
+  //      CHECK: vector.matrix_multiply
+  // CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
+  // CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
+
+  // OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
+  return
+}
+
index 6aaedf1..effa7e2 100644 (file)
@@ -16,6 +16,7 @@ add_mlir_library(MLIRTestTransforms
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
   TestInlining.cpp
+  TestLinalgCodegenStrategy.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
diff --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
new file mode 100644 (file)
index 0000000..57500bc
--- /dev/null
@@ -0,0 +1,150 @@
+//===- TestLinalgCodegenStrategy.cpp - Test Linalg codegen strategy -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing the Linalg codegen strategy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct TestLinalgCodegenStrategy
+    : public PassWrapper<TestLinalgCodegenStrategy, FunctionPass> {
+  TestLinalgCodegenStrategy() = default;
+  TestLinalgCodegenStrategy(const TestLinalgCodegenStrategy &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // clang-format off
+    registry.insert<AffineDialect,
+                    gpu::GPUDialect,
+                    linalg::LinalgDialect,
+                    scf::SCFDialect,
+                    StandardOpsDialect,
+                    vector::VectorDialect>();
+    // clang-format on
+  }
+
+  void runOnFunction() override;
+
+  ListOption<int64_t> tileSizes{*this, "tile-sizes",
+                                llvm::cl::MiscFlags::CommaSeparated,
+                                llvm::cl::desc("Specifies the tile sizes.")};
+  Option<bool> promote{
+      *this, "promote",
+      llvm::cl::desc("Promote the tile into a small aligned memory buffer."),
+      llvm::cl::init(false)};
+  Option<bool> promoteFullTile{
+      *this, "promote-full-tile-pad",
+      llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
+      llvm::cl::init(false)};
+  ListOption<int64_t> registerTileSizes{
+      *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc(
+          "Specifies the size of the register tile that will be used "
+          " to vectorize")};
+  Option<bool> registerPromote{
+      *this, "register-promote",
+      llvm::cl::desc(
+          "Promote the register tile into a small aligned memory buffer."),
+      llvm::cl::init(false)};
+  Option<bool> registerPromoteFullTile{
+      *this, "register-promote-full-tile-pad",
+      llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
+      llvm::cl::init(false)};
+  Option<bool> vectorize{
+      *this, "vectorize",
+      llvm::cl::desc("Rewrite the linalg op as a vector operation."),
+      llvm::cl::init(false)};
+  Option<std::string> splitVectorTransfersTo{
+      *this, "split-transfers",
+      llvm::cl::desc(
+          "Split vector transfers between slow (masked) and fast "
+          "(unmasked) variants. Possible options are:\n"
+          "\tnone: keep unsplit vector.transfer and pay the full price\n"
+          "\tlinalg-copy: use linalg.fill + linalg.copy for the slow path\n"
+          "\tvector-transfers: use extra small unmasked vector.transfer for"
+          " the slow path\n"),
+      llvm::cl::init("none")};
+  Option<std::string> vectorizeContractionTo{
+      *this, "vectorize-contraction-to",
+      llvm::cl::desc("the type of vector op to use for linalg contractions"),
+      llvm::cl::init("outerproduct")};
+  Option<bool> unrollVectorTransfers{
+      *this, "unroll-vector-transfers",
+      llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
+      llvm::cl::init(false)};
+};
+} // end anonymous namespace
+
+/// Apply transformations specified as patterns.
+void TestLinalgCodegenStrategy::runOnFunction() {
+  LinalgTilingOptions tilingOptions;
+  if (!tileSizes.empty())
+    tilingOptions = tilingOptions.setTileSizes(tileSizes);
+
+  LinalgTilingOptions registerTilingOptions;
+  if (!registerTileSizes.empty())
+    registerTilingOptions =
+        registerTilingOptions.setTileSizes(registerTileSizes);
+
+  vector::VectorContractLowering vectorContractLowering =
+      llvm::StringSwitch<vector::VectorContractLowering>(
+          vectorizeContractionTo.getValue())
+          .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
+          .Case("dot", vector::VectorContractLowering::Dot)
+          .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
+          .Default(vector::VectorContractLowering::OuterProduct);
+  vector::VectorTransferSplit vectorTransferSplit =
+      llvm::StringSwitch<vector::VectorTransferSplit>(
+          splitVectorTransfersTo.getValue())
+          .Case("none", vector::VectorTransferSplit::None)
+          .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
+          .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
+          .Default(vector::VectorTransferSplit::None);
+
+  CodegenStrategy strategy;
+  strategy.tileIf<MatmulOp>(!tileSizes.empty(), tilingOptions)
+      .promoteIf<MatmulOp>(promote,
+                           LinalgPromotionOptions()
+                               .setAlignment(16)
+                               .setUseFullTileBuffersByDefault(promoteFullTile))
+      .tileIf<MatmulOp>(!registerTileSizes.empty(), registerTilingOptions)
+      .promoteIf<MatmulOp>(registerPromote, LinalgPromotionOptions()
+                                                .setAlignment(16)
+                                                .setUseFullTileBuffersByDefault(
+                                                    registerPromoteFullTile))
+      .vectorizeIf<MatmulOp>(vectorize)
+      .setVectorTransformsOptions(
+          vector::VectorTransformsOptions()
+              .setVectorTransformsOptions(vectorContractLowering)
+              .setVectorTransferSplit(vectorTransferSplit))
+      .setVectorTransferToSCFOptions(
+          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+
+  strategy.transform(getFunction());
+}
+
+namespace mlir {
+void registerTestLinalgCodegenStrategy() {
+  PassRegistration<TestLinalgCodegenStrategy> testLinalgCodegenStrategyPass(
+      "test-linalg-codegen-strategy", "Test Linalg Codegen Strategy.");
+}
+} // namespace mlir
index 5b03565..ef55e4f 100644 (file)
@@ -58,6 +58,7 @@ void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
+void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
@@ -116,6 +117,7 @@ void registerTestPasses() {
   registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
+  registerTestLinalgCodegenStrategy();
   registerTestLinalgFusionTransforms();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();