From: River Riddle Date: Tue, 2 Apr 2019 03:43:13 +0000 (-0700) Subject: Remove MLPatternLoweringPass and rewrite LowerVectorTransfers to use RewritePatte... X-Git-Tag: llvmorg-11-init~1466^2~2053 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=084669e005692b5a4dbf7f325c0205ec89eb6431;p=platform%2Fupstream%2Fllvm.git Remove MLPatternLoweringPass and rewrite LowerVectorTransfers to use RewritePattern instead. -- PiperOrigin-RevId: 241455472 --- diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h deleted file mode 100644 index c43b551..0000000 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ /dev/null @@ -1,142 +0,0 @@ -//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// Defines a generic class to implement lowering passes on ML functions as a -// list of pattern rewriters. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H -#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { - -/// Specialization of the pattern rewriter to ML functions. -class MLFuncLoweringRewriter : public PatternRewriter { -public: - explicit MLFuncLoweringRewriter(FuncBuilder *builder) - : PatternRewriter(builder->getContext()), builder(builder) {} - - FuncBuilder *getBuilder() { return builder; } - - Operation *createOperation(const OperationState &state) override { - auto *result = builder->createOperation(state); - return result; - } - -private: - FuncBuilder *builder; -}; - -/// Base class for the Function-wise lowering state. A pointer to the same -/// instance of the subclass will be passed to all `rewrite` calls on operations -/// that belong to the same Function. -class MLFuncGlobalLoweringState { -public: - virtual ~MLFuncGlobalLoweringState() {} - -protected: - // Must be subclassed. - MLFuncGlobalLoweringState() {} -}; - -/// Base class for Function lowering patterns. -class MLLoweringPattern : public Pattern { -public: - /// Subclasses must override this function to implement rewriting. It will be - /// called on all operations found by `match` (declared in Pattern, subclasses - /// must override). It will be passed the function-wise state, common to all - /// matches, and the state returned by the `match` call, if any. The subclass - /// must use `rewriter` to modify the function. - virtual void rewriteOpInst(Operation *op, - MLFuncGlobalLoweringState *funcWiseState, - std::unique_ptr opState, - MLFuncLoweringRewriter *rewriter) const = 0; - -protected: - // Must be subclassed. - MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context) - : Pattern(opName, benefit, context) {} -}; - -namespace detail { -/// Owning list of ML lowering patterns. -using OwningMLLoweringPatternList = - std::vector>; - -template struct ListAdder { - static void addPatternsToList(OwningMLLoweringPatternList *list, - MLIRContext *context) { - static_assert(std::is_base_of::value, - "can only add subclasses of MLLoweringPattern"); - list->emplace_back(new Pattern(context)); - ListAdder::addPatternsToList(list, context); - } -}; - -template struct ListAdder { - static void addPatternsToList(OwningMLLoweringPatternList *list, - MLIRContext *context) { - list->emplace_back(new Pattern(context)); - } -}; -} // namespace detail - -/// Generic lowering for ML patterns. The lowering details are defined as -/// a sequence of pattern matchers. The following constraints on matchers -/// apply: -/// - only one (match root) operation can be removed; -/// - the code produced by rewriters is final, it is not pattern-matched; -/// - the matchers are applied in their order of appearance in the list; -/// - if the match is found, the operation is rewritten immediately and the -/// next _original_ operation is considered. -/// In other words, for each operation, apply the first matching rewriter in the -/// list and advance to the (lexically) next operation. This is similar to -/// greedy worklist-based pattern rewriter, except that this operates on ML -/// functions using an ML builder and does not maintain the work list. Note -/// that, as of the time of writing, worklist-based rewriter did not support -/// removing multiple operations either. -template -void applyMLPatternsGreedily( - Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) { - detail::OwningMLLoweringPatternList patterns; - detail::ListAdder::addPatternsToList(&patterns, f->getContext()); - - FuncBuilder builder(f); - MLFuncLoweringRewriter rewriter(&builder); - - llvm::SmallVector ops; - f->walk([&ops](Operation *op) { ops.push_back(op); }); - - for (Operation *op : ops) { - for (const auto &pattern : patterns) { - builder.setInsertionPoint(op); - if (auto matchResult = pattern->match(op)) { - pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult), - &rewriter); - break; - } - } - } -} -} // end namespace mlir - -#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 244ec58..54e8b8f 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -39,7 +39,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Support/Functional.h" -#include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" #include "mlir/VectorOps/VectorOps.h" @@ -99,30 +98,27 @@ namespace { /// 4. local memory deallocation. /// Minor variations occur depending on whether a VectorTransferReadOp or /// a VectorTransferWriteOp is rewritten. -template class VectorTransferRewriter { -public: - VectorTransferRewriter(VectorTransferOpTy transfer, - MLFuncLoweringRewriter *rewriter, - MLFuncGlobalLoweringState *state); +template +struct VectorTransferRewriter : public RewritePattern { + explicit VectorTransferRewriter(MLIRContext *context) + : RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {} /// Used for staging the transfer in a local scalar buffer. - MemRefType tmpMemRefType() { + MemRefType tmpMemRefType(VectorTransferOpTy transfer) const { auto vectorType = transfer.getVectorType(); return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 0); } + /// View of tmpMemRefType as one vector, used in vector load/store to tmp /// buffer. - MemRefType vectorMemRefType() { + MemRefType vectorMemRefType(VectorTransferOpTy transfer) const { return MemRefType::get({1}, transfer.getVectorType(), {}, 0); } - /// Performs the rewrite. - void rewrite(); -private: - VectorTransferOpTy transfer; - MLFuncLoweringRewriter *rewriter; - MLFuncGlobalLoweringState *state; + /// Performs the rewrite. + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace @@ -213,12 +209,6 @@ clip(VectorTransferOpTy transfer, edsc::MemRefView &view, return clippedScalarAccessExprs; } -template -VectorTransferRewriter::VectorTransferRewriter( - VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter, - MLFuncGlobalLoweringState *state) - : transfer(transfer), rewriter(rewriter), state(state){}; - /// Lowers VectorTransferReadOp into a combination of: /// 1. local memory allocation; /// 2. perfect loop nest over: @@ -260,13 +250,20 @@ VectorTransferRewriter::VectorTransferRewriter( /// /// TODO(ntv): implement alternatives to clipping. /// TODO(ntv): support non-data-parallel operations. -template <> void VectorTransferRewriter::rewrite() { + +/// Performs the rewrite. +template <> +PatternMatchResult +VectorTransferRewriter::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { using namespace mlir::edsc; using namespace mlir::edsc::op; using namespace mlir::edsc::intrinsics; + VectorTransferReadOp transfer = op->cast(); + // 1. Setup all the captures. - ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc()); + ScopedContext scope(FuncBuilder(op), transfer.getLoc()); IndexedValue remote(transfer.getMemRef()); MemRefView view(transfer.getMemRef()); VectorView vectorView(transfer.getVector()); @@ -281,9 +278,9 @@ template <> void VectorTransferRewriter::rewrite() { auto steps = vectorView.getSteps(); // 2. Emit alloc-copy-load-dealloc. - ValueHandle tmp = alloc(tmpMemRefType()); + ValueHandle tmp = alloc(tmpMemRefType(transfer)); IndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp, vectorMemRefType()); + ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); LoopNestBuilder(pivs, lbs, ubs, steps)({ // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). local(ivs) = remote(clip(transfer, view, ivs)), @@ -292,8 +289,8 @@ template <> void VectorTransferRewriter::rewrite() { (dealloc(tmp)); // vexing parse // 3. Propagate. - transfer.replaceAllUsesWith(vectorValue.getValue()); - transfer.erase(); + rewriter.replaceOp(op, vectorValue.getValue()); + return matchSuccess(); } /// Lowers VectorTransferWriteOp into a combination of: @@ -314,13 +311,18 @@ template <> void VectorTransferRewriter::rewrite() { /// /// TODO(ntv): implement alternatives to clipping. /// TODO(ntv): support non-data-parallel operations. -template <> void VectorTransferRewriter::rewrite() { +template <> +PatternMatchResult +VectorTransferRewriter::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { using namespace mlir::edsc; using namespace mlir::edsc::op; using namespace mlir::edsc::intrinsics; + VectorTransferWriteOp transfer = op->cast(); + // 1. Setup all the captures. - ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc()); + ScopedContext scope(FuncBuilder(op), transfer.getLoc()); IndexedValue remote(transfer.getMemRef()); MemRefView view(transfer.getMemRef()); ValueHandle vectorValue(transfer.getVector()); @@ -336,9 +338,9 @@ template <> void VectorTransferRewriter::rewrite() { auto steps = vectorView.getSteps(); // 2. Emit alloc-store-copy-dealloc. - ValueHandle tmp = alloc(tmpMemRefType()); + ValueHandle tmp = alloc(tmpMemRefType(transfer)); IndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp, vectorMemRefType()); + ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); store(vectorValue, vec, {constant_index(0)}); LoopNestBuilder(pivs, lbs, ubs, steps)({ // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). @@ -346,36 +348,23 @@ template <> void VectorTransferRewriter::rewrite() { }); (dealloc(tmp)); // vexing parse... - transfer.erase(); + rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } namespace { -template -class VectorTransferExpander : public MLLoweringPattern { -public: - explicit VectorTransferExpander(MLIRContext *context) - : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - - PatternMatchResult match(Operation *op) const override { - if (m_Op().match(op)) - return matchSuccess(); - return matchFailure(); - } - void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState, - std::unique_ptr opState, - MLFuncLoweringRewriter *rewriter) const override { - VectorTransferRewriter( - op->dyn_cast(), rewriter, funcWiseState) - .rewrite(); - } -}; - struct LowerVectorTransfersPass : public FunctionPass { void runOnFunction() { - auto &f = getFunction(); - applyMLPatternsGreedily, - VectorTransferExpander>(&f); + OwningRewritePatternList patterns; + auto *context = &getContext(); + patterns.push_back( + llvm::make_unique>( + context)); + patterns.push_back( + llvm::make_unique>( + context)); + applyPatternsGreedily(getFunction(), std::move(patterns)); } }; @@ -388,5 +377,3 @@ FunctionPassBase *mlir::createLowerVectorTransfersPass() { static PassRegistration pass("lower-vector-transfers", "Materializes vector transfer ops to a " "proper abstraction for the hardware"); - -#undef DEBUG_TYPE diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index 05ebe8f..4805a86 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -131,8 +131,8 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-LABEL:func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { func @materialize_write(%M: index, %N: index, %O: index, %P: index) { - // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: %cst = constant splat, 1.000000e+00> : vector<5x4x3xf32> + // CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %arg0 step 3 { // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %arg1 step 4 { // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %arg2 {