From 4dbd94b5435e1e1e23984d023637dbb77fa89cbd Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 17 Dec 2018 14:11:31 -0800 Subject: [PATCH] Refactor LowerVectorTransfersPass using pattern rewriters This introduces a generic lowering pass for ML functions. The pass is parameterized by template arguments defining individual pattern rewriters. Concrete lowering passes define individual pattern rewriters and inherit from the generic class that takes care of allocating rewriters, traversing ML functions and performing the actual rewrite. While this is similar to the greedy pattern rewriter available in Transform/Utils, it requires adjustments due to the ML/CFG duality. In particular, ML function rewriters must be able to create statements, not only operations, and need access to an MLFuncBuilder. When we move to using the unified function type, the ML-specific rewriting will become unnecessary. Use LowerVectorTransfers as a testbed for the generic pass. PiperOrigin-RevId: 225887424 --- .../mlir/Transforms/MLPatternLoweringPass.h | 166 ++++++++++++++++++ mlir/lib/Transforms/LowerVectorTransfers.cpp | 127 +++++++------- 2 files changed, 232 insertions(+), 61 deletions(-) create mode 100644 mlir/include/mlir/Transforms/MLPatternLoweringPass.h diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h new file mode 100644 index 000000000000..4d3b120b9b02 --- /dev/null +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -0,0 +1,166 @@ +//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines a generic class to implement lowering passes on ML functions as a +// list of pattern rewriters. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H +#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass.h" +#include + +namespace mlir { + +/// Specialization of the pattern rewriter to ML functions. +class MLFuncLoweringRewriter : public PatternRewriter { +public: + explicit MLFuncLoweringRewriter(MLFuncBuilder *builder) + : PatternRewriter(builder->getContext()), builder(builder) {} + + MLFuncBuilder *getBuilder() { return builder; } + + Operation *createOperation(const OperationState &state) override { + auto *result = builder->createOperation(state); + return result; + } + +private: + MLFuncBuilder *builder; +}; + +/// Base class for the MLFunction-wise lowering state. A pointer to the same +/// instance of the subclass will be passed to all `rewrite` calls on operations +/// that belong to the same MLFunction. +class MLFuncGlobalLoweringState { +public: + virtual ~MLFuncGlobalLoweringState() {} + +protected: + // Must be subclassed. + MLFuncGlobalLoweringState() {} +}; + +/// Base class for MLFunction lowering patterns. +class MLLoweringPattern : public Pattern { +public: + /// Subclasses must override this function to implement rewriting. It will be + /// called on all operations found by `match` (declared in Pattern, subclasses + /// must override). It will be passed the function-wise state, common to all + /// matches, and the state returned by the `match` call, if any. The subclass + /// must use `rewriter` to modify the function. + virtual void rewriteOpStmt(Operation *op, + MLFuncGlobalLoweringState *funcWiseState, + std::unique_ptr opState, + MLFuncLoweringRewriter *rewriter) const = 0; + +protected: + // Must be subclassed. + MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context) + : Pattern(opName, benefit, context) {} +}; + +namespace detail { +/// Owning list of ML lowering patterns. +using OwningMLLoweringPatternList = + std::vector>; +} // namespace detail + +/// Generic lowering pass for ML functions. The lowering details are defined as +/// a sequence of pattern matchers. The following constraints on matchers +/// apply: +/// - only one (match root) operation can be removed; +/// - the code produced by rewriters is final, it is not pattern-matched; +/// - the matchers are applied in their order of appearance in the list; +/// - if the match is found, the operation is rewritten immediately and the +/// next _original_ operation is considered. +/// In other words, for each operation, the pass applies the first matching +/// rewriter in the list and advances to the (lexically) next operation. +/// Non-operation statements (ForStmt and IfStmt) are ignored. +/// This is similar to greedy worklist-based pattern rewriter, except that this +/// operates on ML functions using an ML builder and does not maintain the work +/// list. Note that, as of the time of writing, worklist-based rewriter did not +/// support removing multiple operations either. +template +class MLPatternLoweringPass : public FunctionPass { +public: + explicit MLPatternLoweringPass(void *ID) : FunctionPass(ID) {} + + virtual std::unique_ptr + makeFuncWiseState(MLFunction *f) const { + return nullptr; + } + + PassResult runOnMLFunction(MLFunction *f) override; +}; + +///////////////////////////////////////////////////////////////////// +// MLPatternLoweringPass template implementations +///////////////////////////////////////////////////////////////////// + +namespace detail { +template struct ListAdder { + static void addPatternsToList(OwningMLLoweringPatternList *list, + MLIRContext *context) { + static_assert(std::is_base_of::value, + "can only add subclasses of MLLoweringPattern"); + list->emplace_back(new Pattern(context)); + ListAdder::addPatternsToList(list, context); + } +}; + +template struct ListAdder { + static void addPatternsToList(OwningMLLoweringPatternList *list, + MLIRContext *context) { + list->emplace_back(new Pattern(context)); + } +}; +} // namespace detail + +template +PassResult MLPatternLoweringPass::runOnMLFunction(MLFunction *f) { + detail::OwningMLLoweringPatternList patterns; + detail::ListAdder::addPatternsToList(&patterns, f->getContext()); + auto funcWiseState = makeFuncWiseState(f); + + MLFuncBuilder builder(f); + MLFuncLoweringRewriter rewriter(&builder); + + llvm::SmallVector ops; + f->walk([&ops](OperationStmt *stmt) { ops.push_back(stmt); }); + + for (OperationStmt *stmt : ops) { + for (const auto &pattern : patterns) { + rewriter.getBuilder()->setInsertionPoint(stmt); + auto matchResult = pattern->match(stmt); + if (matchResult) { + pattern->rewriteOpStmt(stmt, funcWiseState.get(), + std::move(*matchResult), &rewriter); + break; + } + } + } + + return PassResult::Success; +} + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 462ab95d004c..df30a7794614 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -30,7 +30,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLValue.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SSAValue.h" #include "mlir/IR/Types.h" #include "mlir/Pass.h" @@ -38,6 +40,7 @@ #include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetVector.h" @@ -59,29 +62,6 @@ using namespace mlir; #define DEBUG_TYPE "lower-vector-transfers" -namespace { - -struct LowerVectorTransfersPass : public FunctionPass { - LowerVectorTransfersPass() - : FunctionPass(&LowerVectorTransfersPass::passID) {} - - PassResult runOnMLFunction(MLFunction *f) override; - - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - MLFunctionMatcherContext mlContext; - - static char passID; -}; - -struct LowerVectorTransfersState { - // Top of the function constant zero index. - SSAValue *zero; -}; - -} // end anonymous namespace - -char LowerVectorTransfersPass::passID = 0; - /// Creates the SSAValue for the sum of `a` and `b` without building a /// full-fledged AffineMap for all indices. /// @@ -98,6 +78,13 @@ static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) { ->getResult(0); } +namespace { +struct LowerVectorTransfersState : public MLFuncGlobalLoweringState { + // Top of the function constant zero index. + SSAValue *zero; +}; +} // namespace + /// Performs simple lowering into a combination of: /// 1. local memory allocation, /// 2. vector_load/vector_store from/to local buffer @@ -108,8 +95,9 @@ static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) { // argument being one of the two types. Extract the common behavior into helper // functions and detemplatizing it. template -static void lowerAsLoops(VectorTransferOpTy *transfer, - const LowerVectorTransfersState &state) { +static void rewriteAsLoops(VectorTransferOpTy *transfer, + MLFuncLoweringRewriter *rewriter, + LowerVectorTransfersState *state) { static_assert( std::is_same::value || std::is_same::value, @@ -122,7 +110,14 @@ static void lowerAsLoops(VectorTransferOpTy *transfer, // vectorMemRefType is a view of tmpMemRefType as one vector. auto vectorMemRefType = MemRefType::get({1}, vectorType, {}, 0); - MLFuncBuilder b(cast(transfer->getOperation())); + // Get the ML function builder. + // We need access to the MLFunction builder stored internally in the + // MLFunctionLoweringRewriter general rewriting API does not provide + // ML-specific functions (ForStmt and StmtBlock manipulation). While we could + // forward them or define a whole rewriting chain based on MLFunctionBuilder + // instead of Builer, the code for it would be duplicate boilerplate. As we + // go towards unifying ML and CFG functions, this separation will disappear. + MLFuncBuilder &b = *rewriter->getBuilder(); // 1. First allocate the local buffer in fast memory. // TODO(ntv): CL memory space. @@ -136,7 +131,7 @@ static void lowerAsLoops(VectorTransferOpTy *transfer, // case of GPUs. if (std::is_same::value) { b.create(vecView->getLoc(), transfer->getVector(), - vecView->getResult(), ArrayRef{state.zero}); + vecView->getResult(), ArrayRef{state->zero}); } // 3. Emit the loop-nest. @@ -191,12 +186,13 @@ static void lowerAsLoops(VectorTransferOpTy *transfer, // 5. Read the vector from local storage in case of a vector_transfer_read. // TODO(ntv): This vector_load operation should be further lowered in the // case of GPUs. + llvm::SmallVector newResults = {}; if (std::is_same::value) { b.setInsertionPoint(cast(transfer->getOperation())); auto *vector = b.create(transfer->getLoc(), vecView->getResult(), - ArrayRef{state.zero}) + ArrayRef{state->zero}) ->getResult(); - transfer->getVector()->replaceAllUsesWith(vector); + newResults.push_back(vector); } // 6. Free the local buffer. @@ -204,46 +200,55 @@ static void lowerAsLoops(VectorTransferOpTy *transfer, b.create(transfer->getLoc(), tmpScalarAlloc); // 7. It is now safe to erase the statement. - transfer->erase(); + rewriter->replaceOp(transfer->getOperation(), newResults); } -PassResult LowerVectorTransfersPass::runOnMLFunction(MLFunction *f) { - LowerVectorTransfersState state; - { - MLFuncBuilder b(f); - b.setInsertionPointToStart(f); - state.zero = b.create(b.getUnknownLoc(), 0); +namespace { +template +class VectorTransferExpander : public MLLoweringPattern { +public: + explicit VectorTransferExpander(MLIRContext *context) + : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} + + PatternMatchResult match(Operation *op) const override { + if (m_Op().match(op)) + return matchSuccess(); + return matchFailure(); } - using matcher::Op; - LLVM_DEBUG(dbgs() << "\nLowerVectorTransfersPass on MLFunction\n"); - LLVM_DEBUG(f->print(dbgs())); - - // Avoid any read/write ordering considerations: do it in 2 steps. - // 1. vector_transfer_reads; - auto filterReads = [](const Statement &stmt) { - const auto &opStmt = cast(stmt); - return opStmt.isa(); - }; - for (auto m : Op(filterReads).match(f)) { - auto read = cast(m.first)->cast(); - // TODO(ntv): Drop &* once lowerAsLoops is detemplatized. - lowerAsLoops(&*read, state); + void rewriteOpStmt(Operation *op, MLFuncGlobalLoweringState *funcWiseState, + std::unique_ptr opState, + MLFuncLoweringRewriter *rewriter) const override { + rewriteAsLoops(&*op->dyn_cast(), rewriter, + static_cast(funcWiseState)); } +}; +} // namespace + +namespace { - // 2. vector_transfer_writes; - auto filterWrites = [](const Statement &stmt) { - const auto &opStmt = cast(stmt); - return opStmt.isa(); - }; - for (auto m : Op(filterWrites).match(f)) { - auto write = cast(m.first)->cast(); - // TODO(ntv): Drop &* once lowerAsLoops is detemplatized. - lowerAsLoops(&*write, state); +struct LowerVectorTransfersPass + : public MLPatternLoweringPass< + VectorTransferExpander, + VectorTransferExpander> { + LowerVectorTransfersPass() + : MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {} + + std::unique_ptr + makeFuncWiseState(MLFunction *f) const override { + auto state = llvm::make_unique(); + auto builder = MLFuncBuilder(f); + builder.setInsertionPointToStart(f); + state->zero = builder.create(builder.getUnknownLoc(), 0); + return state; } - return PassResult::Success; -} + static char passID; +}; + +} // end anonymous namespace + +char LowerVectorTransfersPass::passID = 0; FunctionPass *mlir::createLowerVectorTransfersPass() { return new LowerVectorTransfersPass(); -- 2.34.1