+++ /dev/null
-//===- MLPatternLoweringPass.h - Generic ML lowering pass -------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// Defines a generic class to implement lowering passes on ML functions as a
-// list of pattern rewriters.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
-#define MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include <type_traits>
-
-namespace mlir {
-
-/// Specialization of the pattern rewriter to ML functions.
-class MLFuncLoweringRewriter : public PatternRewriter {
-public:
- explicit MLFuncLoweringRewriter(FuncBuilder *builder)
- : PatternRewriter(builder->getContext()), builder(builder) {}
-
- FuncBuilder *getBuilder() { return builder; }
-
- Operation *createOperation(const OperationState &state) override {
- auto *result = builder->createOperation(state);
- return result;
- }
-
-private:
- FuncBuilder *builder;
-};
-
-/// Base class for the Function-wise lowering state. A pointer to the same
-/// instance of the subclass will be passed to all `rewrite` calls on operations
-/// that belong to the same Function.
-class MLFuncGlobalLoweringState {
-public:
- virtual ~MLFuncGlobalLoweringState() {}
-
-protected:
- // Must be subclassed.
- MLFuncGlobalLoweringState() {}
-};
-
-/// Base class for Function lowering patterns.
-class MLLoweringPattern : public Pattern {
-public:
- /// Subclasses must override this function to implement rewriting. It will be
- /// called on all operations found by `match` (declared in Pattern, subclasses
- /// must override). It will be passed the function-wise state, common to all
- /// matches, and the state returned by the `match` call, if any. The subclass
- /// must use `rewriter` to modify the function.
- virtual void rewriteOpInst(Operation *op,
- MLFuncGlobalLoweringState *funcWiseState,
- std::unique_ptr<PatternState> opState,
- MLFuncLoweringRewriter *rewriter) const = 0;
-
-protected:
- // Must be subclassed.
- MLLoweringPattern(StringRef opName, int64_t benefit, MLIRContext *context)
- : Pattern(opName, benefit, context) {}
-};
-
-namespace detail {
-/// Owning list of ML lowering patterns.
-using OwningMLLoweringPatternList =
- std::vector<std::unique_ptr<mlir::MLLoweringPattern>>;
-
-template <typename Pattern, typename... Patterns> struct ListAdder {
- static void addPatternsToList(OwningMLLoweringPatternList *list,
- MLIRContext *context) {
- static_assert(std::is_base_of<MLLoweringPattern, Pattern>::value,
- "can only add subclasses of MLLoweringPattern");
- list->emplace_back(new Pattern(context));
- ListAdder<Patterns...>::addPatternsToList(list, context);
- }
-};
-
-template <typename Pattern> struct ListAdder<Pattern> {
- static void addPatternsToList(OwningMLLoweringPatternList *list,
- MLIRContext *context) {
- list->emplace_back(new Pattern(context));
- }
-};
-} // namespace detail
-
-/// Generic lowering for ML patterns. The lowering details are defined as
-/// a sequence of pattern matchers. The following constraints on matchers
-/// apply:
-/// - only one (match root) operation can be removed;
-/// - the code produced by rewriters is final, it is not pattern-matched;
-/// - the matchers are applied in their order of appearance in the list;
-/// - if the match is found, the operation is rewritten immediately and the
-/// next _original_ operation is considered.
-/// In other words, for each operation, apply the first matching rewriter in the
-/// list and advance to the (lexically) next operation. This is similar to
-/// greedy worklist-based pattern rewriter, except that this operates on ML
-/// functions using an ML builder and does not maintain the work list. Note
-/// that, as of the time of writing, worklist-based rewriter did not support
-/// removing multiple operations either.
-template <typename... Patterns>
-void applyMLPatternsGreedily(
- Function *f, MLFuncGlobalLoweringState *funcWiseState = nullptr) {
- detail::OwningMLLoweringPatternList patterns;
- detail::ListAdder<Patterns...>::addPatternsToList(&patterns, f->getContext());
-
- FuncBuilder builder(f);
- MLFuncLoweringRewriter rewriter(&builder);
-
- llvm::SmallVector<Operation *, 16> ops;
- f->walk([&ops](Operation *op) { ops.push_back(op); });
-
- for (Operation *op : ops) {
- for (const auto &pattern : patterns) {
- builder.setInsertionPoint(op);
- if (auto matchResult = pattern->match(op)) {
- pattern->rewriteOpInst(op, funcWiseState, std::move(*matchResult),
- &rewriter);
- break;
- }
- }
- }
-}
-} // end namespace mlir
-
-#endif // MLIR_TRANSFORMS_MLPATTERNLOWERINGPASS_H
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/Functional.h"
-#include "mlir/Transforms/MLPatternLoweringPass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/VectorOps/VectorOps.h"
/// 4. local memory deallocation.
/// Minor variations occur depending on whether a VectorTransferReadOp or
/// a VectorTransferWriteOp is rewritten.
-template <typename VectorTransferOpTy> class VectorTransferRewriter {
-public:
- VectorTransferRewriter(VectorTransferOpTy transfer,
- MLFuncLoweringRewriter *rewriter,
- MLFuncGlobalLoweringState *state);
+template <typename VectorTransferOpTy>
+struct VectorTransferRewriter : public RewritePattern {
+ explicit VectorTransferRewriter(MLIRContext *context)
+ : RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
/// Used for staging the transfer in a local scalar buffer.
- MemRefType tmpMemRefType() {
+ MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
auto vectorType = transfer.getVectorType();
return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
{}, 0);
}
+
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
/// buffer.
- MemRefType vectorMemRefType() {
+ MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
}
- /// Performs the rewrite.
- void rewrite();
-private:
- VectorTransferOpTy transfer;
- MLFuncLoweringRewriter *rewriter;
- MLFuncGlobalLoweringState *state;
+ /// Performs the rewrite.
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
return clippedScalarAccessExprs;
}
-template <typename VectorTransferOpTy>
-VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
- VectorTransferOpTy transfer, MLFuncLoweringRewriter *rewriter,
- MLFuncGlobalLoweringState *state)
- : transfer(transfer), rewriter(rewriter), state(state){};
-
/// Lowers VectorTransferReadOp into a combination of:
/// 1. local memory allocation;
/// 2. perfect loop nest over:
///
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
-template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
+
+/// Performs the rewrite.
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
+ VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
+
// 1. Setup all the captures.
- ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
+ ScopedContext scope(FuncBuilder(op), transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
VectorView vectorView(transfer.getVector());
auto steps = vectorView.getSteps();
// 2. Emit alloc-copy-load-dealloc.
- ValueHandle tmp = alloc(tmpMemRefType());
+ ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
- ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
+ ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
LoopNestBuilder(pivs, lbs, ubs, steps)({
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, view, ivs)),
(dealloc(tmp)); // vexing parse
// 3. Propagate.
- transfer.replaceAllUsesWith(vectorValue.getValue());
- transfer.erase();
+ rewriter.replaceOp(op, vectorValue.getValue());
+ return matchSuccess();
}
/// Lowers VectorTransferWriteOp into a combination of:
///
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
-template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
+template <>
+PatternMatchResult
+VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
+ VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
+
// 1. Setup all the captures.
- ScopedContext scope(FuncBuilder(transfer.getOperation()), transfer.getLoc());
+ ScopedContext scope(FuncBuilder(op), transfer.getLoc());
IndexedValue remote(transfer.getMemRef());
MemRefView view(transfer.getMemRef());
ValueHandle vectorValue(transfer.getVector());
auto steps = vectorView.getSteps();
// 2. Emit alloc-store-copy-dealloc.
- ValueHandle tmp = alloc(tmpMemRefType());
+ ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
- ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
+ ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
store(vectorValue, vec, {constant_index(0)});
LoopNestBuilder(pivs, lbs, ubs, steps)({
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
});
(dealloc(tmp)); // vexing parse...
- transfer.erase();
+ rewriter.replaceOp(op, llvm::None);
+ return matchSuccess();
}
namespace {
-template <typename VectorTransferOpTy>
-class VectorTransferExpander : public MLLoweringPattern {
-public:
- explicit VectorTransferExpander(MLIRContext *context)
- : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- if (m_Op<VectorTransferOpTy>().match(op))
- return matchSuccess();
- return matchFailure();
- }
- void rewriteOpInst(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
- std::unique_ptr<PatternState> opState,
- MLFuncLoweringRewriter *rewriter) const override {
- VectorTransferRewriter<VectorTransferOpTy>(
- op->dyn_cast<VectorTransferOpTy>(), rewriter, funcWiseState)
- .rewrite();
- }
-};
-
struct LowerVectorTransfersPass
: public FunctionPass<LowerVectorTransfersPass> {
void runOnFunction() {
- auto &f = getFunction();
- applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
- VectorTransferExpander<VectorTransferWriteOp>>(&f);
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(
+ llvm::make_unique<VectorTransferRewriter<VectorTransferReadOp>>(
+ context));
+ patterns.push_back(
+ llvm::make_unique<VectorTransferRewriter<VectorTransferWriteOp>>(
+ context));
+ applyPatternsGreedily(getFunction(), std::move(patterns));
}
};
static PassRegistration<LowerVectorTransfersPass>
pass("lower-vector-transfers", "Materializes vector transfer ops to a "
"proper abstraction for the hardware");
-
-#undef DEBUG_TYPE