Remove MLPatternLoweringPass and rewrite LowerVectorTransfers to use RewritePatte...
authorRiver Riddle <riverriddle@google.com>
Tue, 2 Apr 2019 03:43:13 +0000 (20:43 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 2 Apr 2019 20:39:17 +0000 (13:39 -0700)
--

PiperOrigin-RevId: 241455472

mlir/include/mlir/Transforms/MLPatternLoweringPass.h [deleted file]
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir

diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
deleted file mode 100644 (file)
index c43b551..0000000
+++ /dev/null
@@ -1,142 +0,0 @@
-//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// Defines a generic class to implement lowering passes on ML functions as a
-// list of pattern rewriters.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
-#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include <type_traits>
-
-namespace mlir {
-
-/// Specialization of the pattern rewriter to ML functions.
-class MLFuncLoweringRewriter : public PatternRewriter {
-public:
-  explicit MLFuncLoweringRewriter(FuncBuilder *builder)
-      : PatternRewriter(builder->getContext()), builder(builder) {}
-
-  FuncBuilder *getBuilder() { return builder; }
-
-  Operation *createOperation(const OperationState &state) override {
-    auto *result = builder->createOperation(state);
-    return result;
-  }
-
-private:
-  FuncBuilder *builder;
-};
-
-/// Base class for the Function-wise lowering state.  A pointer to the same
-/// instance of the subclass will be passed to all `rewrite` calls on operations
-/// that belong to the same Function.
-class MLFuncGlobalLoweringState {
-public:
-  virtual ~MLFuncGlobalLoweringState() {}
-
-protected:
-  // Must be subclassed.
-  MLFuncGlobalLoweringState() {}
-};
-
-/// Base class for Function lowering patterns.
-class MLLoweringPattern : public Pattern {
-public:
-  /// Subclasses must override this function to implement rewriting.  It will be
-  /// called on all operations found by `match` (declared in Pattern, subclasses
-  /// must override).  It will be passed the function-wise state, common to all
-  /// matches, and the state returned by the `match` call, if any.  The subclass
-  /// must use `rewriter` to modify the function.
-  virtual void rewriteOpInst(Operation *op,
-                             MLFuncGlobalLoweringState *funcWiseState,
-                             std::unique_ptr<PatternState> opState,
-                             MLFuncLoweringRewriter *rewriter) const = 0;
-
-protected:
-  // Must be subclassed.
-  MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context)
-      : Pattern(opName, benefit, context) {}
-};
-
-namespace detail {
-/// Owning list of ML lowering patterns.
-using OwningMLLoweringPatternList =
-    std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
-
-template <typename Pattern, typename... Patterns> struct ListAdder {
-  static void addPatternsToList(OwningMLLoweringPatternList *list,
-                                MLIRContext *context) {
-    static_assert(std::is_base_of<MLLoweringPattern, Pattern>::value,
-                  "can only add subclasses of MLLoweringPattern");
-    list->emplace_back(new Pattern(context));
-    ListAdder<Patterns...>::addPatternsToList(list, context);
-  }
-};
-
-template <typename Pattern> struct ListAdder<Pattern> {
-  static void addPatternsToList(OwningMLLoweringPatternList *list,
-                                MLIRContext *context) {
-    list->emplace_back(new Pattern(context));
-  }
-};
-} // namespace detail
-
-/// Generic lowering for ML patterns.  The lowering details are defined as
-/// a sequence of pattern matchers.  The following constraints on matchers
-/// apply:
-/// - only one (match root) operation can be removed;
-/// - the code produced by rewriters is final, it is not pattern-matched;
-/// - the matchers are applied in their order of appearance in the list;
-/// - if the match is found, the operation is rewritten immediately and the
-///   next _original_ operation is considered.
-/// In other words, for each operation, apply the first matching rewriter in the
-/// list and advance to the (lexically) next operation. This is similar to
-/// greedy worklist-based pattern rewriter, except that this operates on ML
-/// functions using an ML builder and does not maintain the work list.  Note
-/// that, as of the time of writing, worklist-based rewriter did not support
-/// removing multiple operations either.
-template <typename... Patterns>
-void applyMLPatternsGreedily(
-    Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
-  detail::OwningMLLoweringPatternList patterns;
-  detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
-
-  FuncBuilder builder(f);
-  MLFuncLoweringRewriter rewriter(&builder);
-
-  llvm::SmallVector<Operation *, 16> ops;
-  f->walk([&ops](Operation *op) { ops.push_back(op); });
-
-  for (Operation *op : ops) {
-    for (const auto &pattern : patterns) {
-      builder.setInsertionPoint(op);
-      if (auto matchResult = pattern->match(op)) {
-        pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult),
-                               &rewriter);
-        break;
-      }
-    }
-  }
-}
-} // end namespace mlir
-
-#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
index 244ec58..54e8b8f 100644 (file)
@@ -39,7 +39,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/StandardOps/Ops.h"
 #include "mlir/Support/Functional.h"
-#include "mlir/Transforms/MLPatternLoweringPass.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/VectorOps/VectorOps.h"
 
@@ -99,30 +98,27 @@ namespace {
 ///   4. local memory deallocation.
 /// Minor variations occur depending on whether a VectorTransferReadOp or
 /// a VectorTransferWriteOp is rewritten.
-template <typename VectorTransferOpTy> class VectorTransferRewriter {
-public:
-  VectorTransferRewriter(VectorTransferOpTy transfer,
-                         MLFuncLoweringRewriter *rewriter,
-                         MLFuncGlobalLoweringState *state);
+template <typename VectorTransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+  explicit VectorTransferRewriter(MLIRContext *context)
+      : RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
 
   /// Used for staging the transfer in a local scalar buffer.
-  MemRefType tmpMemRefType() {
+  MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
     auto vectorType = transfer.getVectorType();
     return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
                            {}, 0);
   }
+
   /// View of tmpMemRefType as one vector, used in vector load/store to tmp
   /// buffer.
-  MemRefType vectorMemRefType() {
+  MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
     return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
   }
-  /// Performs the rewrite.
-  void rewrite();
 
-private:
-  VectorTransferOpTy transfer;
-  MLFuncLoweringRewriter *rewriter;
-  MLFuncGlobalLoweringState *state;
+  /// Performs the rewrite.
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override;
 };
 } // end anonymous namespace
 
@@ -213,12 +209,6 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view,
   return clippedScalarAccessExprs;
 }
 
-template <typename VectorTransferOpTy>
-VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
-    VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter,
-    MLFuncGlobalLoweringState *state)
-    : transfer(transfer), rewriter(rewriter), state(state){};
-
 /// Lowers VectorTransferReadOp into a combination of:
 ///   1. local memory allocation;
 ///   2. perfect loop nest over:
@@ -260,13 +250,20 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
 ///
 /// TODO(ntv): implement alternatives to clipping.
 /// TODO(ntv): support non-data-parallel operations.
-template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
+
+/// Performs the rewrite.
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
   using namespace mlir::edsc;
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
+  VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
+
   // 1. Setup all the captures.
-  ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
+  ScopedContext scope(FuncBuilder(op), transfer.getLoc());
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   VectorView vectorView(transfer.getVector());
@@ -281,9 +278,9 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
   auto steps = vectorView.getSteps();
 
   // 2. Emit alloc-copy-load-dealloc.
-  ValueHandle tmp = alloc(tmpMemRefType());
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
   IndexedValue local(tmp);
-  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
+  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
   LoopNestBuilder(pivs, lbs, ubs, steps)({
       // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
       local(ivs) = remote(clip(transfer, view, ivs)),
@@ -292,8 +289,8 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
   (dealloc(tmp)); // vexing parse
 
   // 3. Propagate.
-  transfer.replaceAllUsesWith(vectorValue.getValue());
-  transfer.erase();
+  rewriter.replaceOp(op, vectorValue.getValue());
+  return matchSuccess();
 }
 
 /// Lowers VectorTransferWriteOp into a combination of:
@@ -314,13 +311,18 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
 ///
 /// TODO(ntv): implement alternatives to clipping.
 /// TODO(ntv): support non-data-parallel operations.
-template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
   using namespace mlir::edsc;
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
+  VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
+
   // 1. Setup all the captures.
-  ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
+  ScopedContext scope(FuncBuilder(op), transfer.getLoc());
   IndexedValue remote(transfer.getMemRef());
   MemRefView view(transfer.getMemRef());
   ValueHandle vectorValue(transfer.getVector());
@@ -336,9 +338,9 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
   auto steps = vectorView.getSteps();
 
   // 2. Emit alloc-store-copy-dealloc.
-  ValueHandle tmp = alloc(tmpMemRefType());
+  ValueHandle tmp = alloc(tmpMemRefType(transfer));
   IndexedValue local(tmp);
-  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
+  ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
   store(vectorValue, vec, {constant_index(0)});
   LoopNestBuilder(pivs, lbs, ubs, steps)({
       // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
@@ -346,36 +348,23 @@ template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
   });
   (dealloc(tmp)); // vexing parse...
 
-  transfer.erase();
+  rewriter.replaceOp(op, llvm::None);
+  return matchSuccess();
 }
 
 namespace {
-template <typename VectorTransferOpTy>
-class VectorTransferExpander : public MLLoweringPattern {
-public:
-  explicit VectorTransferExpander(MLIRContext *context)
-      : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
-
-  PatternMatchResult match(Operation *op) const override {
-    if (m_Op<VectorTransferOpTy>().match(op))
-      return matchSuccess();
-    return matchFailure();
-  }
-  void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
-                     std::unique_ptr<PatternState> opState,
-                     MLFuncLoweringRewriter *rewriter) const override {
-    VectorTransferRewriter<VectorTransferOpTy>(
-        op->dyn_cast<VectorTransferOpTy>(), rewriter, funcWiseState)
-        .rewrite();
-  }
-};
-
 struct LowerVectorTransfersPass
     : public FunctionPass<LowerVectorTransfersPass> {
   void runOnFunction() {
-    auto &f = getFunction();
-    applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
-                            VectorTransferExpander<VectorTransferWriteOp>>(&f);
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    patterns.push_back(
+        llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
+            context));
+    patterns.push_back(
+        llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
+            context));
+    applyPatternsGreedily(getFunction(), std::move(patterns));
   }
 };
 
@@ -388,5 +377,3 @@ FunctionPassBase *mlir::createLowerVectorTransfersPass() {
 static PassRegistration<LowerVectorTransfersPass>
     pass("lower-vector-transfers", "Materializes vector transfer ops to a "
                                    "proper abstraction for the hardware");
-
-#undef DEBUG_TYPE
index 05ebe8f..4805a86 100644 (file)
@@ -131,8 +131,8 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
 
 // CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
 func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
-  // CHECK-NEXT:  %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
   // CHECK-NEXT:  %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
+  // CHECK-NEXT:  %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
   // CHECK-NEXT:  affine.for %[[I0:.*]] = 0 to %arg0 step 3 {
   // CHECK-NEXT:    affine.for %[[I1:.*]] = 0 to %arg1 step 4 {
   // CHECK-NEXT:      affine.for %[[I2:.*]] = 0 to %arg2 {