--- /dev/null
+//===- mlir/PatternMatch.h - Base classes for pattern match -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_PATTERN_MATCH_H
+#define MLIR_PATTERN_MATCH_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+class Operation;
+class MLFuncBuilder;
+class SSAValue;
+
+//===----------------------------------------------------------------------===//
+// Definition of Pattern and related types.
+//===----------------------------------------------------------------------===//
+
+/// This class represents the benefit of a pattern match in a unitless scheme
+/// that ranges from 0 (very little benefit) to 65K. The most common unit to
+/// use here is the "number of operations matched" by the pattern.
+///
+/// This also has a sentinel representation that can be used for patterns that
+/// fail to match.
+///
+class PatternBenefit {
+ enum { ImpossibleToMatchSentinel = 65535 };
+
+public:
+ /*implicit*/ PatternBenefit(unsigned benefit);
+ PatternBenefit(const PatternBenefit &) = default;
+ PatternBenefit &operator=(const PatternBenefit &) = default;
+
+ static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
+
+ bool isImpossibleToMatch() const {
+ return representation == ImpossibleToMatchSentinel;
+ }
+
+ /// If the corresponding pattern can match, return its benefit. If the
+ // corresponding pattern isImpossibleToMatch() then this aborts.
+ unsigned short getBenefit() const;
+
+ inline bool operator==(const PatternBenefit& other);
+ inline bool operator!=(const PatternBenefit& other);
+
+private:
+ PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
+ unsigned short representation;
+};
+
+/// Pattern state is used by patterns that want to maintain state between their
+/// match and rewrite phases. Patterns can define a pattern-specific subclass
+/// of this.
+class PatternState {
+public:
+ virtual ~PatternState() {}
+
+protected:
+ // Must be subclassed.
+ PatternState() {}
+};
+
+/// This is the type returned by a pattern match. The first field indicates
+/// the benefit of the match, the second is a state token that can optionally
+/// be produced by a pattern match to maintain state between the match and
+/// rewrite phases.
+typedef std::pair<PatternBenefit, std::unique_ptr<PatternState>>
+ PatternMatchResult;
+
+class Pattern {
+public:
+ // Return the benefit (the inverse of "cost") of matching this pattern,
+ // if it is statically determinable. The result is a PatternBenefit if known,
+ // or 'None' if the cost is dynamically computed.
+ Optional<PatternBenefit> getStaticBenefit() const;
+
+ // Return the root node that this pattern matches. Patterns that can
+ // match multiple root types are instantiated once per root.
+ OperationName getRootKind() const;
+
+ //===--------------------------------------------------------------------===//
+ // Implementation hooks for patterns to implement.
+ //===--------------------------------------------------------------------===//
+
+ // Attempt to match against code rooted at the specified operation,
+ // which is the same operation code as getRootKind(). On success it
+ // returns the benefit of the match along with an (optional)
+ // pattern-specific state which is passed back into its rewrite
+ // function if this match is selected. On failure, this returns a
+ // sentinel indicating that it didn’t match.
+ virtual PatternMatchResult match(Operation *op) const = 0;
+
+ // Rewrite the IR rooted at the specified operation with the result of
+ // this pattern, generating any new operations with the specified
+ // builder. If an unexpected error is encountered (an internal
+ // compiler error), it is emitted through the normal MLIR diagnostic
+ // hooks and the IR is left in a valid state.
+ virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+ // TODO: Need a generic builder.
+ MLFuncBuilder &builder) const;
+
+ // Rewrite the IR rooted at the specified operation with the result of
+ // this pattern, generating any new operations with the specified
+ // builder. If an unexpected error is encountered (an internal
+ // compiler error), it is emitted through the normal MLIR diagnostic
+ // hooks and the IR is left in a valid state.
+ virtual void rewrite(Operation *op,
+ // TODO: Need a generic builder.
+ MLFuncBuilder &builder) const;
+
+ virtual ~Pattern() {}
+
+ //===--------------------------------------------------------------------===//
+ // Helper methods to simplify pattern implementations
+ //===--------------------------------------------------------------------===//
+
+ /// This method indicates that no match was found.
+ static PatternMatchResult matchFailure();
+
+ /// This method indicates that a match was found and has the specified cost.
+ PatternMatchResult
+ matchSuccess(PatternBenefit benefit,
+ std::unique_ptr<PatternState> state = {}) const;
+
+ /// This method indicates that a match was found for patterns that have a
+ /// known static benefit.
+ PatternMatchResult
+ matchSuccess(std::unique_ptr<PatternState> state = {}) const;
+
+ /// This method is used as the final replacement hook for patterns that match
+ /// a single result value. In addition to replacing and removing the
+ /// specified operation, clients can specify a list of other nodes that this
+ /// replacement may make (perhaps transitively) dead. If any of those ops are
+ /// dead, this will remove them as well.
+ void replaceSingleResultOp(Operation *op, SSAValue *newValue,
+ ArrayRef<SSAValue *> opsToRemoveIfDead = {}) const;
+
+protected:
+ /// Patterns must specify the root operation name they match against, and can
+ /// also optionally specify a static benefit of matching.
+ Pattern(OperationName rootKind,
+ Optional<PatternBenefit> staticBenefit = llvm::None);
+
+ Pattern(OperationName rootKind, unsigned staticBenefit);
+
+private:
+ const OperationName rootKind;
+ const Optional<PatternBenefit> staticBenefit;
+};
+
+//===----------------------------------------------------------------------===//
+// PatternMatcher class
+//===----------------------------------------------------------------------===//
+
+/// This class manages optimization an execution of a group of patterns, and
+/// provides an API for finding the best match against a given node.
+///
+class PatternMatcher {
+public:
+ /// Create a PatternMatch with the specified set of patterns. This takes
+ /// ownership of the patterns in question.
+ explicit PatternMatcher(ArrayRef<Pattern *> patterns)
+ : patterns(patterns.begin(), patterns.end()) {}
+
+ typedef std::pair<Pattern *, std::unique_ptr<PatternState>> MatchResult;
+
+ /// Find the highest benefit pattern available in the pattern set for the DAG
+ /// rooted at the specified node. This returns the pattern (and any state it
+ /// needs) if found, or null if there are no matches.
+ MatchResult findMatch(Operation *op);
+
+ ~PatternMatcher() { llvm::DeleteContainerPointers(patterns); }
+
+private:
+ PatternMatcher(const PatternMatcher &) = delete;
+ void operator=(const PatternMatcher &) = delete;
+
+ std::vector<Pattern *> patterns;
+};
+} // end namespace mlir
+
+#endif // MLIR_PATTERN_MATCH_H
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/PatternMatch.h"
+
#include <memory>
using namespace mlir;
-//===----------------------------------------------------------------------===//
-// Definition of Pattern and related types.
-//===----------------------------------------------------------------------===//
-
-// TODO(clattner): Move this out of this file when it is ready.
-
-/// This class represents the benefit of a pattern match in a unitless scheme
-/// that ranges from 0 (very little benefit) to 65K. The most common unit to
-/// use here is the "number of operations matched" by the pattern.
-///
-/// This also has a sentinel representation that can be used for patterns that
-/// fail to match.
-///
-class PatternBenefit {
- enum { ImpossibleToMatchSentinel = 65535 };
-
-public:
- /*implicit*/ PatternBenefit(unsigned benefit) : representation(benefit) {
- assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
- "This pattern match benefit is too large to represent");
- }
- PatternBenefit(const PatternBenefit &) = default;
- PatternBenefit &operator=(const PatternBenefit &) = default;
-
- static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
-
- bool isImpossibleToMatch() const {
- return representation == ImpossibleToMatchSentinel;
- }
-
- /// If the corresponding pattern can match, return its benefit. If the
- // corresponding pattern isImpossibleToMatch() then this aborts.
- unsigned short getBenefit() const {
- assert(representation != ImpossibleToMatchSentinel &&
- "Pattern doesn't match");
- return representation;
- }
-
-private:
- PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
- unsigned short representation;
-};
-
-static inline bool operator==(PatternBenefit lhs, PatternBenefit rhs) {
- if (lhs.isImpossibleToMatch())
- return rhs.isImpossibleToMatch();
- if (rhs.isImpossibleToMatch())
- return false;
- return lhs.getBenefit() == rhs.getBenefit();
-}
-
-static inline bool operator!=(PatternBenefit lhs, PatternBenefit rhs) {
- return !operator==(lhs, rhs);
-}
-
-/// Pattern state is used by patterns that want to maintain state between their
-/// match and rewrite phases. Patterns can define a pattern-specific subclass
-/// of this.
-class PatternState {
-public:
- virtual ~PatternState() {}
-
-protected:
- // Must be subclassed.
- PatternState() {}
-};
-
-/// This is the type returned by a pattern match. The first field indicates the
-/// benefit of the match, the second is a state token that can optionally be
-/// produced by a pattern match to maintain state between the match and rewrite
-/// phases.
-typedef std::pair<PatternBenefit, std::unique_ptr<PatternState>>
- PatternMatchResult;
-
-class Pattern {
-public:
- // Return the benefit (the inverse of "cost") of matching this pattern,
- // if it is statically determinable. The result is a PatternBenefit if known,
- // or 'None' if the cost is dynamically computed.
- Optional<PatternBenefit> getStaticBenefit() const { return staticBenefit; }
-
- // Return the root node that this pattern matches. Patterns that can
- // match multiple root types are instantiated once per root.
- OperationName getRootKind() const { return rootKind; }
-
- //===--------------------------------------------------------------------===//
- // Implementation hooks for patterns to implement.
- //===--------------------------------------------------------------------===//
-
- // Attempt to match against code rooted at the specified operation,
- // which is the same operation code as getRootKind(). On success it
- // returns the benefit of the match along with an (optional)
- // pattern-specific state which is passed back into its rewrite
- // function if this match is selected. On failure, this returns a
- // sentinel indicating that it didn’t match.
- virtual PatternMatchResult match(Operation *op) const = 0;
-
- // Rewrite the IR rooted at the specified operation with the result of
- // this pattern, generating any new operations with the specified
- // builder. If an unexpected error is encountered (an internal
- // compiler error), it is emitted through the normal MLIR diagnostic
- // hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
- // TODO: Need a generic builder.
- MLFuncBuilder &builder) const {
- rewrite(op, builder);
- }
-
- // Rewrite the IR rooted at the specified operation with the result of
- // this pattern, generating any new operations with the specified
- // builder. If an unexpected error is encountered (an internal
- // compiler error), it is emitted through the normal MLIR diagnostic
- // hooks and the IR is left in a valid state.
- virtual void rewrite(Operation *op,
- // TODO: Need a generic builder.
- MLFuncBuilder &builder) const {
- llvm_unreachable("need to implement one of the rewrite functions!");
- }
-
- virtual ~Pattern();
-
- //===--------------------------------------------------------------------===//
- // Helper methods to simplify pattern implementations
- //===--------------------------------------------------------------------===//
-
- /// This method indicates that no match was found.
- static PatternMatchResult matchFailure() {
- // TODO: Use a proper sentinel / discriminated union instad of -1 magic
- // number.
- return {-1, std::unique_ptr<PatternState>()};
- }
-
- /// This method indicates that a match was found and has the specified cost.
- PatternMatchResult
- matchSuccess(PatternBenefit benefit,
- std::unique_ptr<PatternState> state = {}) const {
- assert((!getStaticBenefit().hasValue() ||
- getStaticBenefit().getValue() == benefit) &&
- "This version of matchSuccess must be called with a benefit that "
- "matches the static benefit if set!");
-
- return {benefit, std::move(state)};
- }
-
- /// This method indicates that a match was found for patterns that have a
- /// known static benefit.
- PatternMatchResult
- matchSuccess(std::unique_ptr<PatternState> state = {}) const {
- auto benefit = getStaticBenefit();
- assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
- return matchSuccess(benefit.getValue(), std::move(state));
- }
-
- /// This method is used as the final replacement hook for patterns that match
- /// a single result value. In addition to replacing and removing the
- /// specified operation, clients can specify a list of other nodes that this
- /// replacement may make (perhaps transitively) dead. If any of those ops are
- /// dead, this will remove them as well.
- void
- replaceSingleResultOp(Operation *op, SSAValue *newValue,
- ArrayRef<SSAValue *> opsToRemoveIfDead = {}) const {
- assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
- op->getResult(0)->replaceAllUsesWith(newValue);
-
- // TODO: This shouldn't be statement specific.
- cast<OperationStmt>(op)->eraseFromBlock();
-
- // TODO: Process the opsToRemoveIfDead list once we have side-effect
- // information. Be careful about notifying clients that this is happening
- // so they can be removed from worklists etc (needs a callback of some
- // sort).
- }
-
-protected:
- /// Patterns must specify the root operation name they match against, and can
- /// also optionally specify a static benefit of matching.
- Pattern(OperationName rootKind,
- Optional<PatternBenefit> staticBenefit = llvm::None)
- : rootKind(rootKind), staticBenefit(staticBenefit) {}
- Pattern(OperationName rootKind, unsigned staticBenefit)
- : rootKind(rootKind), staticBenefit(staticBenefit) {}
-
-private:
- const OperationName rootKind;
- const Optional<PatternBenefit> staticBenefit;
-};
-
-Pattern::~Pattern() {}
-
-//===----------------------------------------------------------------------===//
-// PatternMatcher class
-//===----------------------------------------------------------------------===//
-
-/// This class manages optimization an execution of a group of patterns, and
-/// provides an API for finding the best match against a given node.
-///
-class PatternMatcher {
-public:
- /// Create a PatternMatch with the specified set of patterns. This takes
- /// ownership of the patterns in question.
- explicit PatternMatcher(ArrayRef<Pattern *> patterns)
- : patterns(patterns.begin(), patterns.end()) {}
-
- typedef std::pair<Pattern *, std::unique_ptr<PatternState>> MatchResult;
-
- /// Find the highest benefit pattern available in the pattern set for the DAG
- /// rooted at the specified node. This returns the pattern (and any state it
- /// needs) if found, or null if there are no matches.
- MatchResult findMatch(Operation *op);
-
- ~PatternMatcher() { llvm::DeleteContainerPointers(patterns); }
-
-private:
- PatternMatcher(const PatternMatcher &) = delete;
- void operator=(const PatternMatcher &) = delete;
-
- std::vector<Pattern *> patterns;
-};
-
-/// Find the highest benefit pattern available in the pattern set for the DAG
-/// rooted at the specified node. This returns the pattern if found, or null
-/// if there are no matches.
-auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
- // TODO: This is a completely trivial implementation, expand this in the
- // future.
-
- // Keep track of the best match, the benefit of it, and any matcher specific
- // state it is maintaining.
- MatchResult bestMatch = {nullptr, nullptr};
- Optional<PatternBenefit> bestBenefit;
-
- for (auto *pattern : patterns) {
- // Ignore patterns that are for the wrong root.
- if (pattern->getRootKind() != op->getName())
- continue;
-
- // If we know the static cost of the pattern is worse than what we've
- // already found then don't run it.
- auto staticBenefit = pattern->getStaticBenefit();
- if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
- staticBenefit.getValue().getBenefit() <
- bestBenefit.getValue().getBenefit())
- continue;
-
- // Check to see if this pattern matches this node.
- auto result = pattern->match(op);
- auto benefit = result.first;
-
- // If this pattern failed to match, ignore it.
- if (benefit.isImpossibleToMatch())
- continue;
-
- // If it matched but had lower benefit than our best match so far, then
- // ignore it.
- if (bestBenefit.hasValue() &&
- benefit.getBenefit() < bestBenefit.getValue().getBenefit())
- continue;
-
- // Okay we found a match that is better than our previous one, remember it.
- bestBenefit = benefit;
- bestMatch = {pattern, std::move(result.second)};
- }
-
- // If we found any match, return it.
- return bestMatch;
-}
-
//===----------------------------------------------------------------------===//
// Definition of a few patterns for canonicalizing operations.
//===----------------------------------------------------------------------===//
--- /dev/null
+//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/PatternMatch.h"
+
+using namespace mlir;
+
+PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
+ assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
+ "This pattern match benefit is too large to represent");
+}
+
+unsigned short PatternBenefit::getBenefit() const {
+ assert(representation != ImpossibleToMatchSentinel &&
+ "Pattern doesn't match");
+ return representation;
+}
+
+bool PatternBenefit::operator==(const PatternBenefit& other) {
+ if (isImpossibleToMatch())
+ return other.isImpossibleToMatch();
+ if (other.isImpossibleToMatch())
+ return false;
+ return getBenefit() == other.getBenefit();
+}
+
+bool PatternBenefit::operator!=(const PatternBenefit& other) {
+ return !(*this == other);
+}
+
+Pattern::Pattern(OperationName rootKind, Optional<PatternBenefit> staticBenefit)
+ : rootKind(rootKind), staticBenefit(staticBenefit) {}
+
+Pattern::Pattern(OperationName rootKind, unsigned staticBenefit)
+ : rootKind(rootKind), staticBenefit(staticBenefit) {}
+
+Optional<PatternBenefit> Pattern::getStaticBenefit() const {
+ return staticBenefit;
+}
+
+OperationName Pattern::getRootKind() const { return rootKind; }
+
+void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+ // TODO: Need a generic builder.
+ MLFuncBuilder &builder) const {
+ rewrite(op, builder);
+}
+
+void Pattern::rewrite(Operation *op,
+ // TODO: Need a generic builder.
+ MLFuncBuilder &builder) const {
+ llvm_unreachable("need to implement one of the rewrite functions!");
+}
+
+/// This method indicates that no match was found.
+PatternMatchResult Pattern::matchFailure() {
+ // TODO: Use a proper sentinel / discriminated union instad of -1 magic
+ // number.
+ return {-1, std::unique_ptr<PatternState>()};
+}
+
+/// This method indicates that a match was found and has the specified cost.
+PatternMatchResult
+Pattern::matchSuccess(PatternBenefit benefit,
+ std::unique_ptr<PatternState> state) const {
+ assert((!getStaticBenefit().hasValue() ||
+ getStaticBenefit().getValue() == benefit) &&
+ "This version of matchSuccess must be called with a benefit that "
+ "matches the static benefit if set!");
+
+ return {benefit, std::move(state)};
+}
+
+/// This method indicates that a match was found for patterns that have a
+/// known static benefit.
+PatternMatchResult
+Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
+ auto benefit = getStaticBenefit();
+ assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
+ return matchSuccess(benefit.getValue(), std::move(state));
+}
+
+/// This method is used as the final replacement hook for patterns that match
+/// a single result value. In addition to replacing and removing the
+/// specified operation, clients can specify a list of other nodes that this
+/// replacement may make (perhaps transitively) dead. If any of those ops are
+/// dead, this will remove them as well.
+void Pattern::replaceSingleResultOp(
+ Operation *op, SSAValue *newValue,
+ ArrayRef<SSAValue *> opsToRemoveIfDead) const {
+ assert(op->getNumResults() == 1 && "op isn't a SingleResultOp!");
+ op->getResult(0)->replaceAllUsesWith(newValue);
+
+ // TODO: This shouldn't be statement specific.
+ cast<OperationStmt>(op)->eraseFromBlock();
+
+ // TODO: Process the opsToRemoveIfDead list once we have side-effect
+ // information. Be careful about notifying clients that this is happening
+ // so they can be removed from worklists etc (needs a callback of some
+ // sort).
+}
+
+/// Find the highest benefit pattern available in the pattern set for the DAG
+/// rooted at the specified node. This returns the pattern if found, or null
+/// if there are no matches.
+auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
+ // TODO: This is a completely trivial implementation, expand this in the
+ // future.
+
+ // Keep track of the best match, the benefit of it, and any matcher specific
+ // state it is maintaining.
+ MatchResult bestMatch = {nullptr, nullptr};
+ Optional<PatternBenefit> bestBenefit;
+
+ for (auto *pattern : patterns) {
+ // Ignore patterns that are for the wrong root.
+ if (pattern->getRootKind() != op->getName())
+ continue;
+
+ // If we know the static cost of the pattern is worse than what we've
+ // already found then don't run it.
+ auto staticBenefit = pattern->getStaticBenefit();
+ if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
+ staticBenefit.getValue().getBenefit() <
+ bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Check to see if this pattern matches this node.
+ auto result = pattern->match(op);
+ auto benefit = result.first;
+
+ // If this pattern failed to match, ignore it.
+ if (benefit.isImpossibleToMatch())
+ continue;
+
+ // If it matched but had lower benefit than our best match so far, then
+ // ignore it.
+ if (bestBenefit.hasValue() &&
+ benefit.getBenefit() < bestBenefit.getValue().getBenefit())
+ continue;
+
+ // Okay we found a match that is better than our previous one, remember it.
+ bestBenefit = benefit;
+ bestMatch = {pattern, std::move(result.second)};
+ }
+
+ // If we found any match, return it.
+ return bestMatch;
+}