- Simplify PatternMatch to *require* static benefits at pattern construction
authorChris Lattner <clattner@google.com>
Sun, 11 Nov 2018 23:56:49 +0000 (15:56 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:54:38 +0000 (13:54 -0700)
  time.  The "Fast and Flexible Instruction Selection With Constraints" paper
  from CC2018 makes a credible argument that dynamic costs aren't actually
  necessary/important, and we are not using them.

- Check in my "MLIR Generic DAG Rewriter Infrastructure" design doc into the
  source tree.

PiperOrigin-RevId: 221017546

mlir/g3doc/GenericDAGRewriter.md [new file with mode: 0644]
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/StandardOps/StandardOps.cpp

diff --git a/mlir/g3doc/GenericDAGRewriter.md b/mlir/g3doc/GenericDAGRewriter.md
new file mode 100644 (file)
index 0000000..dbeea21
--- /dev/null
@@ -0,0 +1,415 @@
+# MLIR Generic DAG Rewriter Infrastructure
+
+## Introduction and Motivation
+
+The goal of a compiler IR is to represent code - at various levels of
+abstraction which pose different sets of tradeoffs in terms of representational
+capabilities and ease of transformation. However, the ability to represent code
+is not itself very useful - you also need to be able to implement those
+transformations.
+
+There are many different sorts of compiler transformations, but this document
+focuses on a particularly important class of transformation that comes up
+repeatedly at scale, and is important for the immediate goals of MLIR: that of
+pattern matching on a set of operations and replacing with another set. This is
+the key algorithm required to implement the "op fission" algorithm used by the
+tf2xla bridge, pattern matching rewrites from TF ops to TF/Lite, peephole
+optimizations like "eliminate identity nodes" or "replace x+0 with x", as well
+as a useful abstraction to implement optimization algorithms for MLIR graphs at
+all levels.
+
+A particular strength of MLIR (and a major difference vs other compiler
+infrastructures like LLVM, GCC, XLA, TensorFlow, etc) is that it uses a single
+compiler IR to represent code at multiple levels of abstraction: an MLIR
+operation can be a "TensorFlow operation", an "XLA HLO", a "TF Lite
+FlatBufferModel op", a TPU LLO instruction, an LLVM IR instruction (transitively
+including X86, Lanai, CUDA, and other target specific instructions), or anything
+else that the MLIR type system can reasonably express. Because MLIR spans such a
+wide range of different problems, a single infrastructure for performing
+graph-to-graph rewrites can help solve many diverse domain challenges, including
+TensorFlow graph level down to the machine code level.
+
+[Static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form)
+(SSA) representations like MLIR make it easy to access the operands and "users"
+of an operation. As such, a natural abstraction for these graph-to-graph
+rewrites is that of DAG pattern matching: clients define DAG tile patterns, and
+each pattern includes a result DAG to produce and the cost of the result (or,
+inversely, the benefit of doing the replacement). A common infrastructure
+efficiently finds and perform the rewrites.
+
+While this concept is simple, the details are more nuanced. This proposal
+defines and explores a set of abstractions that we feel can solve a wide range
+of different problems, and can be applied to many different sorts of problems
+that MLIR is - and is expected to - face over time. We do this by separating the
+pattern definition and matching algorithm from the "driver" of the computation
+loop, and make space for the patterns to be defined declaratively in the future.
+
+## Related Work
+
+There is a huge amount of related work to consider, given that pretty much every
+compiler in existence has to solve this problem many times over. Here are a few
+graph rewrite systems we have used, along with the pros and cons of this related
+work. One unifying problem with all of these is that these systems are only
+trying to solve one particular and usually narrow problem: our proposal would
+like to solve many of these problems with a single infrastructure. Of these, the
+most similar design to our proposal is the LLVM DAG-to-DAG instruction selection
+algorithm at the end.
+
+### Constant folding
+
+A degenerate but pervasive case of DAG-to-DAG pattern matching is constant
+folding: given an operation whose operands contain constants can often be folded
+to a result constant value.
+
+MLIR already has constant folding routines which provide a simpler API than a
+general DAG-to-DAG pattern matcher, and we expect it to remain because the
+simpler contract makes it applicable in some cases that a generic matcher would
+not. For example, a DAG-rewrite can remove arbitrary nodes in the current
+function, which could invalidate iterators. Constant folding as an API does not
+remove any nodes, it just provides a (list of) constant values and allows the
+clients to update their data structures as necessary.
+
+### AST-Level Pattern Matchers
+
+The literature is full of source-to-source translators which transform
+identities in order to improve performance (e.g. transforming `X*0` into `0`).
+One large example that I'm aware of is the GCC `fold` function, which performs
+[many optimizations](https://github.com/gcc-mirror/gcc/blob/master/gcc/fold-const.c)
+on ASTs. Clang has
+[similar routines](http://releases.llvm.org/3.5.0/tools/clang/docs/InternalsManual.html#constant-folding-in-the-clang-ast)
+for simple constant folding of expressions (as required by the C++ standard) but
+doesn't perform general optimizations on its ASTs.
+
+The primary downside of tree optimizers are that you can't see across operations
+that have multiple uses. It is
+[well known in literature](https://llvm.org/pubs/2008-06-LCTES-ISelUsingSSAGraphs.pdf)
+that DAG pattern matching is more powerful than tree pattern matching, but OTOH,
+DAG pattern matching can lead to duplication of computation which needs to be
+checked for.
+
+### "Combiners" and other peephole optimizers
+
+Compilers end up with a lot of peephole optimizers for various things, e.g. the
+GCC
+["combine" routines](https://github.com/gcc-mirror/gcc/blob/master/gcc/combine.c)
+(which try to merge two machine instructions into a single one), the LLVM
+[Inst Combine](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/)
+[pass](https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions),
+LLVM's
+[DAG Combiner](https://github.com/llvm-mirror/llvm/blob/master/lib/CodeGen/SelectionDAG/DAGCombiner.cpp),
+the Swift compiler's
+[SIL Combiner](https://github.com/apple/swift/tree/master/lib/SILOptimizer/SILCombiner),
+etc. These generally match one or more operations and produce zero or more
+operations as a result. The LLVM
+[Legalization](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/)
+infrastructure has a different outer loop but otherwise works the same way.
+
+These passes have a lot of diversity, but also having a unifying structure: they
+mostly have a worklist outer loop which visits operations. They then use the C++
+visitor pattern (or equivalent) to switch over the class of operation and
+dispatch to a method. That method contains a long list of hand-written C++ code
+that pattern-matches various special cases. LLVM introduced a "match" function
+that allows writing patterns in a somewhat more declarative style using template
+metaprogramming (MLIR has similar facilities). Here's a simple example:
+
+```c++
+  // Y - (X + 1) --> ~X + Y
+  if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
+    return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
+```
+
+Here is a somewhat more complicated one (this is not the biggest or most
+complicated :)
+
+```c++
+  // C2 is ODD
+  // LHS = XOR(Y,C1), Y = AND(Z,C2), C1==(C2+1) => LHS == NEG(OR(Z, ~C2))
+  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
+  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
+    if (C1->countTrailingZeros() == 0)
+      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
+        Value *NewOr = Builder.CreateOr(Z, ~(*C2));
+        return Builder.CreateSub(RHS, NewOr, "sub");
+      }
+```
+
+These systems are simple to set up, and pattern matching templates have some
+advantages (they are extensible for new sorts of sub-patterns, look compact at
+point of use). OTOH, they have lots of well known problems, for example:
+
+*   These patterns are very error prone to write, and contain lots of
+    redundancies.
+*   The IR being matched often has identities (e.g. when matching commutative
+    operators) and the C++ code has to handle it manually - take a look at
+    [the full code](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp?view=markup#l775)
+    for checkForNegativeOperand that defines the second pattern).
+*   The matching code compiles slowly, both because it generates tons of code
+    and because the templates instantiate slowly.
+*   Adding new patterns (e.g. for count leading zeros in the example above) is
+    awkward and doesn't often happen.
+*   The cost model for these patterns is not really defined - it is emergent
+    based on the order the patterns are matched in code.
+*   They are non-extensible without rebuilding the compiler.
+*   It isn't practical to apply theorem provers and other tools to these
+    patterns - they cannot be reused for other purposes.
+
+In addition to structured "combiners" like these, there are lots of ad-hoc
+systems like the
+[LLVM Machine code peephole optimizer](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/PeepholeOptimizer.cpp?view=markup)
+which are related.
+
+### LLVM's DAG-to-DAG Instruction Selection Infrastructure
+
+The instruction selection subsystem in LLVM is the result of many years worth of
+iteration and discovery, driven by the need for LLVM to support code generation
+for lots of targets, the complexity of code generators for modern instruction
+sets (e.g. X86), and the fanatical pursuit of reusing code across targets. Eli
+wrote a
+[nice short overview](https://eli.thegreenplace.net/2013/02/25/a-deeper-look-into-the-llvm-code-generator-part-1)
+of how this works, and the
+[LLVM documentation](https://llvm.org/docs/CodeGenerator.html#select-instructions-from-dag)
+describes it in more depth including its advantages and limitations. It allows
+writing patterns like this.
+
+```
+def : Pat<(or GR64:$src, (not (add GR64:$src, 1))),
+          (BLCI64rr GR64:$src)>;
+```
+
+This example defines a matcher for the
+["blci" instruction](https://en.wikipedia.org/wiki/Bit_Manipulation_Instruction_Sets#TBM_\(Trailing_Bit_Manipulation\))
+in the
+[X86 target description](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrInfo.td?view=markup),
+there are many others in that file (look for `Pat<>` patterns, since they aren't
+entangled in details of the compiler like assembler/disassembler generation
+logic).
+
+For our purposes, there is much to like about this system, for example:
+
+*   It is defined in a declarative format.
+*   It is extensible to target-defined operations.
+*   It automates matching across identities, like commutative patterns.
+*   It allows custom abstractions and intense factoring of target-specific
+    commonalities.
+*   It generates compact code - it compiles into a state machine, which is
+    interpreted.
+*   It allows the instruction patterns to be defined and reused for multiple
+    purposes.
+*   The patterns are "type checked" at compile time, detecting lots of bugs
+    early and eliminating redundancy from the pattern specifications.
+*   It allows the use of general C++ code for weird/complex cases.
+
+While there is a lot that is good here, there is also a lot of bad things:
+
+*   All of this machinery is only applicable to instruction selection. Even
+    directly adjacent problems like the DAGCombiner and Legalizer can't use it.
+*   This isn't extensible at compiler runtime, you have to rebuild the compiler
+    to extend it.
+*   The error messages when failing to match a pattern
+    [are not exactly optimal](https://www.google.com/search?q=llvm+cannot+select).
+*   It has lots of implementation problems and limitations (e.g. can't write a
+    pattern for a multi-result operation) as a result of working with the
+    awkward SelectionDAG representation and being designed and implemented
+    lazily.
+*   This stuff all grew organically over time and has lots of sharp edges.
+
+### Summary
+
+MLIR will face a wide range of pattern matching and graph rewrite problems, and
+one of the major advantages of having a common representation for code at
+multiple levels that it allows us to invest in - and highly leverage - a single
+infra for doing this sort of work.
+
+## Goals
+
+This proposal includes support for defining pattern matching and rewrite
+algorithms on MLIR. We'd like these algorithms to encompass many problems in the
+MLIR space, including 1-to-N expansions (e.g. as seen in the TF/XLA bridge when
+lowering a "tf.AddN" to multiple "add" HLOs), M-to-1 patterns (as seen in
+Grappler optimization passes, e.g. that convert multiple/add into a single
+muladd op), as well as general M-to-N patterns (e.g. instruction selection for
+target instructions). Patterns should have a cost associated with them, and the
+common infrastructure should be responsible for sorting out the lowest cost
+match for a given application.
+
+We separate the task of picking a particular locally optimal pattern from a
+given root node, the algorithm used to rewrite an entire graph given a
+particular set of goals, and the definition of the patterns themselves. We do
+this because DAG tile pattern matching is NP complete, which means that there
+are no known polynomial time algorithms to optimally solve this problem.
+Additionally, we would like to support iterative rewrite algorithms that
+progressively transform the input program through multiple steps. Furthermore,
+we would like to support many different sorts of clients across the MLIR stack,
+and they may have different tolerances for compile time cost, different demands
+for optimality, and other algorithmic goals or constraints.
+
+We aim for MLIR transformations to be easy to implement and reduce the
+likelihood for compiler bugs. We expect there to be a very very large number of
+patterns that are defined over time, and we believe that these sorts of patterns
+will have a very large number of legality/validity constraints - many of which
+are difficult to reason about in a consistent way, may be target specific, and
+whose implementation may be particularly bugpone. As such, we aim to design the
+API around pattern definition to be simple, resilient to programmer errors, and
+allow separation of concerns between the legality of the nodes generated from
+the idea of the pattern being defined.
+
+Finally, error handling is a topmost concern: in addition to allowing patterns
+to be defined in a target-independent way that may not apply for all hardware,
+we also want failure for any pattern to match to be diagnosable in a reasonable
+way. To be clear, this is not a solvable problem in general - the space of
+malfunction is too great to be fully enumerated and handled optimally, but there
+are better and worse ways to handle the situation. MLIR is already designed to
+represent the provenance of an operation well. This project aims to propagate
+that provenance information precisely, as well as diagnose pattern match
+failures with the rationale for why a set of patterns do not apply.
+
+### Non goals
+
+This proposal doesn't aim to solve all compiler problems, it is simply a
+DAG-to-DAG pattern matching system, starting with a greedy driver algorithm.
+Compiler algorithms that require global dataflow analysis (e.g. common
+subexpression elimination, conditional constant propagation, and many many
+others) will not be directly solved by this infrastructure.
+
+This proposal is limited to DAG patterns, which (by definition) prevent the
+patterns from seeing across cycles in a graph. In an SSA-based IR like MLIR,
+this means that these patterns don't see across PHI nodes / basic block
+arguments. We consider this acceptable given the set of problems we are trying
+to solve - we don't know of any other system that attempts to do so, and
+consider the payoff of worrying about this to be low.
+
+This design includes the ability for DAG patterns to have associated costs
+(benefits), but those costs are defined in terms of magic numbers (typically
+equal to the number of nodes being replaced). For any given application, the
+units of magic numbers will have to be defined.
+
+## Overall design
+
+We decompose the problem into four major pieces:
+
+1.  the code that is used to define patterns to match, cost, and their
+    replacement actions
+1.  the driver logic to pick the best match for a given root node
+1.  the client that is implementing some transformation (e.g. a combiner)
+1.  (future) the subsystem that allows patterns to be described with a
+    declarative syntax, which sugars step #1.
+
+We sketch the first three of these pieces, each in turn. This is not intended to
+be a concrete API proposal, merely to describe the design
+
+### Defining Patterns
+
+Each pattern will be an instance of a mlir::Pattern class, whose subclasses
+implement methods like this. Note that this API is meant for exposition, the
+actual details are different for efficiency and coding standards reasons (e.g.
+the memory management of `PatternState` is not specified below, etc):
+
+```c++
+class Pattern {
+  /// Return the benefit (the inverse of "cost") of matching this pattern.  The
+  /// benefit of a Pattern is always static - rewrites that may have dynamic
+  /// benefit can be instantiated multiple times (different Pattern instances)
+  /// for each benefit that they may return, and be guarded by different match
+  /// condition predicates.
+  PatternBenefit getBenefit() const { return benefit; }
+
+  /// Return the root node that this pattern matches.  Patterns that can
+  /// match multiple root types are instantiated once per root.
+  OperationName getRootKind() const { return rootKind; }
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().  On failure, this
+  /// returns a None value.  On success it a (possibly null) pattern-specific
+  /// state wrapped in a Some.  This state is passed back into its rewrite
+  /// function if this match is selected.
+  virtual Optional<PatternState*> match(Operation *op) const = 0;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
+  virtual void rewrite(Operation *op, PatternState *state,
+                       PatternRewriter &rewriter) const;
+};
+```
+
+In practice, the first patterns we implement will directly subclass and
+implement this stuff, but we will define some helpers to reduce boilerplate.
+When we have a declarative way to describe patterns, this should be
+automatically generated from the description.
+
+Instances of `Pattern` have a benefit that is static upon construction of the
+pattern instance, but may be computed dynamically at pattern initialization
+time, e.g. allowing the benefit to be derived from domain specific information,
+like the target architecture). This limitation allows us MLIR to (eventually)
+perform pattern fusion and compile patterns into an efficient state machine, and
+[Thier, Ertl, and Krall](https://dl.acm.org/citation.cfm?id=3179501) have shown
+that match predicates eliminate the need for dynamically computed costs in
+almost all cases: you can simply instantiate the same pattern one time for each
+possible cost and use the predicate to guard the match.
+
+The two phase nature of this API (match separate from rewrite) is important for
+two reasons: 1) some clients may want to explore different ways to tile the
+graph, and only rewrite after committing to one tiling. 2) We want to support
+runtime extensibility of the pattern sets, but want to be able to statically
+compile the bulk of known patterns into a state machine at "compiler compile
+time". Both of these reasons lead to us needing to match multiple patterns
+before committing to an answer.
+
+### Picking and performing a replacement
+
+In the short term, this API can be very simple, something like this can work and
+will be useful for many clients:
+
+```c++
+class PatternMatcher {
+   // Create a pattern matcher with a bunch of patterns.  This constructor
+   // looks across all of the specified patterns, and builds an internal
+   // data structure that allows efficient matching.
+   PatternMatcher(ArrayRef<Pattern*> patterns);
+
+   // Given a specific operation, see if there is some rewrite that is
+   // interesting.  If so, return success and return the list of new
+   // operations that were created.  If not, return failure.
+   bool matchAndRewrite(Operation *op,
+                        SmallVectorImpl<Operation*> &newlyCreatedOps);
+};
+```
+
+In practice the interesting part of this class is the acceleration structure it
+builds internally. It buckets up the patterns by root operation, and sorts them
+by their static benefit. When performing a match, it tests any dynamic patterns,
+then tests statically known patterns from highest to lowest benefit.
+
+### First Client: A Greedy Worklist Combiner
+
+We expect that there will be lots of clients for this, but a simple greedy
+worklist-driven combiner should be powerful enough to serve many important ones,
+including the
+[TF2XLA op expansion logic](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels),
+many of the pattern substitution passes of the
+[TOCO compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco)
+for TF-Lite, many
+[Grappler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/grappler)
+passes, and other general performance optimizations for applying identities.
+
+The structure of this algorithm is straight-forward, here is pseudo code:
+
+*   Walk a function in preorder, adding each operation to a worklist.
+*   While the worklist is non-empty, pull something off the back (processing
+    things generally in postorder)
+    *   Perform matchAndRewrite on the operation. If failed, continue to the
+        next operation.
+    *   On success, add the newly created ops to the worklist and continue.
+
+## Future directions
+
+It is important to get implementation and usage experience with this, and many
+patterns can be defined using this sort of framework. Over time, we can look to
+make it easier to declare patterns in a declarative form (e.g. with the LLVM
+tblgen tool or something newer/better). Once we have that, we can define an
+internal abstraction for describing the patterns to match, allowing better high
+level optimization of patterns (including fusion of the matching logic across
+patterns, which the LLVM instruction selector does) and allow the patterns to be
+defined without rebuilding the compiler itself.
index ddcb2d7f9c25161766f6c901cb85c8b1ac434bd9..d2bbc2d265295a61e6312c8ac483de31f94102a3 100644 (file)
@@ -73,12 +73,10 @@ protected:
   PatternState() {}
 };
 
-/// This is the type returned by a pattern match.  The first field indicates
-/// the benefit of the match, the second is a state token that can optionally
-/// be produced by a pattern match to maintain state between the match and
-/// rewrite phases.
-using PatternMatchResult =
-    std::pair<PatternBenefit, std::unique_ptr<PatternState>>;
+/// This is the type returned by a pattern match.  A match failure returns a
+/// None value.  A match success returns a Some value with any state the pattern
+/// may need to maintain (but may also be null).
+using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
 
 //===----------------------------------------------------------------------===//
 // Pattern class
@@ -86,40 +84,41 @@ using PatternMatchResult =
 
 class Pattern {
 public:
-  // Return the benefit (the inverse of "cost") of matching this pattern,
-  // if it is statically determinable.  The result is a PatternBenefit if known,
-  // or 'None' if the cost is dynamically computed.
-  Optional<PatternBenefit> getStaticBenefit() const;
+  /// Return the benefit (the inverse of "cost") of matching this pattern.  The
+  /// benefit of a Pattern is always static - rewrites that may have dynamic
+  /// benefit can be instantiated multiple times (different Pattern instances)
+  /// for each benefit that they may return, and be guarded by different match
+  /// condition predicates.
+  PatternBenefit getBenefit() const { return benefit; }
 
-  // Return the root node that this pattern matches.  Patterns that can
-  // match multiple root types are instantiated once per root.
-  OperationName getRootKind() const;
+  /// Return the root node that this pattern matches.  Patterns that can
+  /// match multiple root types are instantiated once per root.
+  OperationName getRootKind() const { return rootKind; }
 
   //===--------------------------------------------------------------------===//
   // Implementation hooks for patterns to implement.
   //===--------------------------------------------------------------------===//
 
-  // Attempt to match against code rooted at the specified operation,
-  // which is the same operation code as getRootKind().  On success it
-  // returns the benefit of the match along with an (optional)
-  // pattern-specific state which is passed back into its rewrite
-  // function if this match is selected.  On failure, this returns a
-  // sentinel indicating that it didn’t match.
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().  On failure, this
+  /// returns a None value.  On success it a (possibly null) pattern-specific
+  /// state wrapped in a Some.  This state is passed back into its rewrite
+  /// function if this match is selected.
   virtual PatternMatchResult match(Operation *op) const = 0;
 
-  // Rewrite the IR rooted at the specified operation with the result of
-  // this pattern, generating any new operations with the specified
-  // builder.  If an unexpected error is encountered (an internal
-  // compiler error), it is emitted through the normal MLIR diagnostic
-  // hooks and the IR is left in a valid state.
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
   virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                        PatternRewriter &rewriter) const;
 
-  // Rewrite the IR rooted at the specified operation with the result of
-  // this pattern, generating any new operations with the specified
-  // builder.  If an unexpected error is encountered (an internal
-  // compiler error), it is emitted through the normal MLIR diagnostic
-  // hooks and the IR is left in a valid state.
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// builder.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
   virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
 
   virtual ~Pattern() {}
@@ -129,29 +128,22 @@ public:
   //===--------------------------------------------------------------------===//
 
   /// This method indicates that no match was found.
-  static PatternMatchResult matchFailure();
+  static PatternMatchResult matchFailure() { return None; }
 
   /// This method indicates that a match was found and has the specified cost.
   PatternMatchResult
-  matchSuccess(PatternBenefit benefit,
-               std::unique_ptr<PatternState> state = {}) const;
-
-  /// This method indicates that a match was found for patterns that have a
-  /// known static benefit.
-  PatternMatchResult
-  matchSuccess(std::unique_ptr<PatternState> state = {}) const;
+  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
+    return PatternMatchResult(std::move(state));
+  }
 
 protected:
   /// Patterns must specify the root operation name they match against, and can
-  /// also optionally specify a static benefit of matching.
-  Pattern(StringRef rootName, MLIRContext *context,
-          Optional<PatternBenefit> staticBenefit = llvm::None);
-
-  Pattern(StringRef rootName, MLIRContext *context, unsigned staticBenefit);
+  /// also specify the benefit of the pattern matching.
+  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
 
 private:
   const OperationName rootKind;
-  const Optional<PatternBenefit> staticBenefit;
+  const PatternBenefit benefit;
 };
 
 //===----------------------------------------------------------------------===//
index 01a5e57686ace061a0be35a1787f8e5e72dfac79..a6a9c1ddaedfec114dd9e194402612f2307c19af 100644 (file)
@@ -47,20 +47,9 @@ bool PatternBenefit::operator!=(const PatternBenefit& other) {
 // Pattern implementation
 //===----------------------------------------------------------------------===//
 
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
-                 Optional<PatternBenefit> staticBenefit)
-    : rootKind(OperationName(rootName, context)), staticBenefit(staticBenefit) {
-}
-
-Pattern::Pattern(StringRef rootName, MLIRContext *context,
-                 unsigned staticBenefit)
-    : rootKind(rootName, context), staticBenefit(staticBenefit) {}
-
-Optional<PatternBenefit> Pattern::getStaticBenefit() const {
-  return staticBenefit;
-}
-
-OperationName Pattern::getRootKind() const { return rootKind; }
+Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
+                 MLIRContext *context)
+    : rootKind(OperationName(rootName, context)), benefit(benefit) {}
 
 void Pattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
                       PatternRewriter &rewriter) const {
@@ -71,32 +60,6 @@ void Pattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
   llvm_unreachable("need to implement one of the rewrite functions!");
 }
 
-/// This method indicates that no match was found.
-PatternMatchResult Pattern::matchFailure() {
-  return {PatternBenefit::impossibleToMatch(), std::unique_ptr<PatternState>()};
-}
-
-/// This method indicates that a match was found and has the specified cost.
-PatternMatchResult
-Pattern::matchSuccess(PatternBenefit benefit,
-                      std::unique_ptr<PatternState> state) const {
-  assert((!getStaticBenefit().hasValue() ||
-          getStaticBenefit().getValue() == benefit) &&
-         "This version of matchSuccess must be called with a benefit that "
-         "matches the static benefit if set!");
-
-  return {benefit, std::move(state)};
-}
-
-/// This method indicates that a match was found for patterns that have a
-/// known static benefit.
-PatternMatchResult
-Pattern::matchSuccess(std::unique_ptr<PatternState> state) const {
-  auto benefit = getStaticBenefit();
-  assert(benefit.hasValue() && "Pattern doesn't have a static benefit");
-  return matchSuccess(benefit.getValue(), std::move(state));
-}
-
 //===----------------------------------------------------------------------===//
 // PatternRewriter implementation
 //===----------------------------------------------------------------------===//
@@ -163,31 +126,26 @@ auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
     if (pattern->getRootKind() != op->getName())
       continue;
 
-    // If we know the static cost of the pattern is worse than what we've
-    // already found then don't run it.
-    auto staticBenefit = pattern->getStaticBenefit();
-    if (staticBenefit.hasValue() && bestBenefit.hasValue() &&
-        staticBenefit.getValue().getBenefit() <
-            bestBenefit.getValue().getBenefit())
+    auto benefit = pattern->getBenefit();
+    if (benefit.isImpossibleToMatch())
+      continue;
+
+    // If the benefit of the pattern is worse than what we've already found then
+    // don't run it.
+    if (bestBenefit.hasValue() &&
+        benefit.getBenefit() < bestBenefit.getValue().getBenefit())
       continue;
 
     // Check to see if this pattern matches this node.
     auto result = pattern->match(op);
-    auto benefit = result.first;
 
     // If this pattern failed to match, ignore it.
-    if (benefit.isImpossibleToMatch())
-      continue;
-
-    // If it matched but had lower benefit than our best match so far, then
-    // ignore it.
-    if (bestBenefit.hasValue() &&
-        benefit.getBenefit() < bestBenefit.getValue().getBenefit())
+    if (!result)
       continue;
 
     // Okay we found a match that is better than our previous one, remember it.
     bestBenefit = benefit;
-    bestMatch = {pattern.get(), std::move(result.second)};
+    bestMatch = {pattern.get(), std::move(result.getValue())};
   }
 
   // If we found any match, return it.
index 13dac7b03295050e17f505ef021abab672533238..10f62548abbfac06e8ce3f68924a0d3371e3fce0 100644 (file)
@@ -55,10 +55,9 @@ namespace {
 struct MemRefCastFolder : public Pattern {
   /// The rootOpName is the name of the root operation to match against.
   MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
-      : Pattern(rootOpName, context, 1) {}
+      : Pattern(rootOpName, 1, context) {}
 
-  std::pair<PatternBenefit, std::unique_ptr<PatternState>>
-  match(Operation *op) const override {
+  PatternMatchResult match(Operation *op) const override {
     for (auto *operand : op->getOperands())
       if (matchPattern(operand, m_Op<MemRefCastOp>()))
         return matchSuccess();
@@ -113,10 +112,9 @@ namespace {
 ///
 struct SimplifyAddX0 : public Pattern {
   SimplifyAddX0(MLIRContext *context)
-      : Pattern(AddIOp::getOperationName(), context, 1) {}
+      : Pattern(AddIOp::getOperationName(), 1, context) {}
 
-  std::pair<PatternBenefit, std::unique_ptr<PatternState>>
-  match(Operation *op) const override {
+  PatternMatchResult match(Operation *op) const override {
     auto addi = op->cast<AddIOp>();
 
     if (matchPattern(addi->getOperand(1), m_Zero()))
@@ -220,10 +218,9 @@ namespace {
 /// Fold constant dimensions into an alloc instruction.
 struct SimplifyAllocConst : public Pattern {
   SimplifyAllocConst(MLIRContext *context)
-      : Pattern(AllocOp::getOperationName(), context, 1) {}
+      : Pattern(AllocOp::getOperationName(), 1, context) {}
 
-  std::pair<PatternBenefit, std::unique_ptr<PatternState>>
-  match(Operation *op) const override {
+  PatternMatchResult match(Operation *op) const override {
     auto alloc = op->cast<AllocOp>();
 
     // Check to see if any dimensions operands are constants.  If so, we can
@@ -1059,10 +1056,9 @@ namespace {
 ///
 struct SimplifyMulX1 : public Pattern {
   SimplifyMulX1(MLIRContext *context)
-      : Pattern(MulIOp::getOperationName(), context, 1) {}
+      : Pattern(MulIOp::getOperationName(), 1, context) {}
 
-  std::pair<PatternBenefit, std::unique_ptr<PatternState>>
-  match(Operation *op) const override {
+  PatternMatchResult match(Operation *op) const override {
     auto muli = op->cast<MulIOp>();
 
     if (matchPattern(muli->getOperand(1), m_One()))
@@ -1192,10 +1188,9 @@ namespace {
 ///
 struct SimplifyXMinusX : public Pattern {
   SimplifyXMinusX(MLIRContext *context)
-      : Pattern(SubIOp::getOperationName(), context, 1) {}
+      : Pattern(SubIOp::getOperationName(), 1, context) {}
 
-  std::pair<PatternBenefit, std::unique_ptr<PatternState>>
-  match(Operation *op) const override {
+  PatternMatchResult match(Operation *op) const override {
     auto subi = op->cast<SubIOp>();
     if (subi->getOperand(0) == subi->getOperand(1))
       return matchSuccess();