--- /dev/null
+# MLIR Generic DAG Rewriter Infrastructure
+
+## Introduction and Motivation
+
+The goal of a compiler IR is to represent code at various levels of
+abstraction, which pose different sets of tradeoffs in terms of representational
+capabilities and ease of transformation. However, the ability to represent code
+is not itself very useful - you also need to be able to implement those
+transformations.
+
+There are many different sorts of compiler transformations, but this document
+focuses on a particularly important class of transformation that comes up
+repeatedly at scale and is central to the immediate goals of MLIR: that of
+pattern matching on a set of operations and replacing them with another set.
+This is
+the key algorithm required to implement the "op fission" algorithm used by the
+tf2xla bridge, pattern matching rewrites from TF ops to TF/Lite, peephole
+optimizations like "eliminate identity nodes" or "replace x+0 with x", as well
+as a useful abstraction to implement optimization algorithms for MLIR graphs at
+all levels.
+
+A particular strength of MLIR (and a major difference vs other compiler
+infrastructures like LLVM, GCC, XLA, TensorFlow, etc) is that it uses a single
+compiler IR to represent code at multiple levels of abstraction: an MLIR
+operation can be a "TensorFlow operation", an "XLA HLO", a "TF Lite
+FlatBufferModel op", a TPU LLO instruction, an LLVM IR instruction (transitively
+including X86, Lanai, CUDA, and other target specific instructions), or anything
+else that the MLIR type system can reasonably express. Because MLIR spans such a
+wide range of different problems, a single infrastructure for performing
+graph-to-graph rewrites can help solve many diverse domain challenges, from the
+TensorFlow graph level down to the machine code level.
+
+[Static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form)
+(SSA) representations like MLIR make it easy to access the operands and "users"
+of an operation. As such, a natural abstraction for these graph-to-graph
+rewrites is that of DAG pattern matching: clients define DAG tile patterns, and
+each pattern includes a result DAG to produce and the cost of the result (or,
+inversely, the benefit of doing the replacement). A common infrastructure
+efficiently finds and performs the rewrites.
+
+While this concept is simple, the details are more nuanced. This proposal
+defines and explores a set of abstractions that we feel can solve a wide range
+of different problems that MLIR faces now and is expected to face over time. We
+do this by separating the pattern definition and matching algorithm from the
+"driver" of the computation loop, and by making space for the patterns to be
+defined declaratively in the future.
+
+## Related Work
+
+There is a huge amount of related work to consider, given that pretty much every
+compiler in existence has to solve this problem many times over. Here are a few
+graph rewrite systems we have used, along with the pros and cons of this related
+work. One unifying problem is that each of these systems solves only one
+particular and usually narrow problem, whereas our proposal aims to solve many
+of these problems with a single infrastructure. Of these, the
+most similar design to our proposal is the LLVM DAG-to-DAG instruction selection
+algorithm at the end.
+
+### Constant folding
+
+A degenerate but pervasive case of DAG-to-DAG pattern matching is constant
+folding: an operation whose operands contain constants can often be folded to a
+constant result value.
+
+MLIR already has constant folding routines which provide a simpler API than a
+general DAG-to-DAG pattern matcher, and we expect them to remain because the
+simpler contract makes them applicable in some cases that a generic matcher
+would not be. For example, a DAG-rewrite can remove arbitrary nodes in the
+function, which could invalidate iterators. Constant folding as an API does not
+remove any nodes; it just provides a (list of) constant values and allows the
+clients to update their data structures as necessary.
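+
+To make the difference in contract concrete, here is a rough sketch of the two
+kinds of hooks; the names and signatures are illustrative only and are not
+MLIR's actual API:
+
+```c++
+// Illustrative sketch only - not MLIR's real API. A constant folding hook
+// computes result values without touching the IR; the caller decides what to
+// do with them (e.g. replace uses and erase the op, or just record the value).
+LogicalResult tryConstantFold(Operation *op,
+                              ArrayRef<Attribute> constantOperands,
+                              SmallVectorImpl<Attribute> &resultConstants);
+
+// A general DAG rewrite, in contrast, is allowed to create and erase
+// operations itself, which is why it can invalidate a client's iterators.
+void rewriteDAGRootedAt(Operation *op, PatternRewriter &rewriter);
+```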
+
+### AST-Level Pattern Matchers
+
+The literature is full of source-to-source translators which transform
+identities in order to improve performance (e.g. transforming `X*0` into `0`).
+One large example that I'm aware of is the GCC `fold` function, which performs
+[many optimizations](https://github.com/gcc-mirror/gcc/blob/master/gcc/fold-const.c)
+on ASTs. Clang has
+[similar routines](http://releases.llvm.org/3.5.0/tools/clang/docs/InternalsManual.html#constant-folding-in-the-clang-ast)
+for simple constant folding of expressions (as required by the C++ standard) but
+doesn't perform general optimizations on its ASTs.
+
+The primary downside of tree optimizers is that you can't see across operations
+that have multiple uses. It is
+[well known in literature](https://llvm.org/pubs/2008-06-LCTES-ISelUsingSSAGraphs.pdf)
+that DAG pattern matching is more powerful than tree pattern matching, but OTOH,
+DAG pattern matching can lead to duplication of computation which needs to be
+checked for.
+
+### "Combiners" and other peephole optimizers
+
+Compilers end up with a lot of peephole optimizers for various things, e.g. the
+GCC
+["combine" routines](https://github.com/gcc-mirror/gcc/blob/master/gcc/combine.c)
+(which try to merge two machine instructions into a single one), the LLVM
+[Inst Combine](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/)
+[pass](https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions),
+LLVM's
+[DAG Combiner](https://github.com/llvm-mirror/llvm/blob/master/lib/CodeGen/SelectionDAG/DAGCombiner.cpp),
+the Swift compiler's
+[SIL Combiner](https://github.com/apple/swift/tree/master/lib/SILOptimizer/SILCombiner),
+etc. These generally match one or more operations and produce zero or more
+operations as a result. The LLVM
+[Legalization](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/)
+infrastructure has a different outer loop but otherwise works the same way.
+
+These passes have a lot of diversity, but also share a unifying structure: they
+mostly have a worklist outer loop which visits operations. They then use the C++
+visitor pattern (or equivalent) to switch over the class of operation and
+dispatch to a method. That method contains a long list of hand-written C++ code
+that pattern-matches various special cases. LLVM introduced a "match" function
+that allows writing patterns in a somewhat more declarative style using template
+metaprogramming (MLIR has similar facilities). Here's a simple example:
+
+```c++
+ // Y - (X + 1) --> ~X + Y
+ if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
+ return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
+```
+
+Here is a somewhat more complicated one (this is not the biggest or most
+complicated :)
+
+```c++
+ // C2 is ODD
+ // LHS = XOR(Y,C1), Y = AND(Z,C2), C1==(C2+1) => LHS == NEG(OR(Z, ~C2))
+ // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
+ if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
+ if (C1->countTrailingZeros() == 0)
+ if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
+ Value *NewOr = Builder.CreateOr(Z, ~(*C2));
+ return Builder.CreateSub(RHS, NewOr, "sub");
+ }
+```
+
+These systems are simple to set up, and pattern matching templates have some
+advantages (they are extensible for new sorts of sub-patterns and look compact
+at the point of use). OTOH, they have lots of well-known problems, for example:
+
+* These patterns are very error prone to write, and contain lots of
+ redundancies.
+* The IR being matched often has identities (e.g. when matching commutative
+  operators) and the C++ code has to handle this manually - take a look at
+  [the full code](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp?view=markup#l775)
+  for `checkForNegativeOperand`, which defines the second pattern.
+* The matching code compiles slowly, both because it generates tons of code
+ and because the templates instantiate slowly.
+* Adding new patterns (e.g. for count leading zeros in the example above) is
+ awkward and doesn't often happen.
+* The cost model for these patterns is not really defined - it is emergent
+ based on the order the patterns are matched in code.
+* They are non-extensible without rebuilding the compiler.
+* It isn't practical to apply theorem provers and other tools to these
+ patterns - they cannot be reused for other purposes.
+
+In addition to structured "combiners" like these, there are lots of ad-hoc
+systems like the
+[LLVM Machine code peephole optimizer](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/PeepholeOptimizer.cpp?view=markup)
+which are related.
+
+### LLVM's DAG-to-DAG Instruction Selection Infrastructure
+
+The instruction selection subsystem in LLVM is the result of many years' worth of
+iteration and discovery, driven by the need for LLVM to support code generation
+for lots of targets, the complexity of code generators for modern instruction
+sets (e.g. X86), and the fanatical pursuit of reusing code across targets. Eli
+wrote a
+[nice short overview](https://eli.thegreenplace.net/2013/02/25/a-deeper-look-into-the-llvm-code-generator-part-1)
+of how this works, and the
+[LLVM documentation](https://llvm.org/docs/CodeGenerator.html#select-instructions-from-dag)
+describes it in more depth including its advantages and limitations. It allows
+writing patterns like this:
+
+```
+def : Pat<(or GR64:$src, (not (add GR64:$src, 1))),
+ (BLCI64rr GR64:$src)>;
+```
+
+This example defines a matcher for the
+["blci" instruction](https://en.wikipedia.org/wiki/Bit_Manipulation_Instruction_Sets#TBM_\(Trailing_Bit_Manipulation\))
+in the
+[X86 target description](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrInfo.td?view=markup);
+there are many others in that file (look for `Pat<>` patterns, since they aren't
+entangled in details of the compiler like assembler/disassembler generation
+logic).
+
+For our purposes, there is much to like about this system, for example:
+
+* It is defined in a declarative format.
+* It is extensible to target-defined operations.
+* It automates matching across identities, like commutative patterns.
+* It allows custom abstractions and intense factoring of target-specific
+ commonalities.
+* It generates compact code - it compiles into a state machine, which is
+ interpreted.
+* It allows the instruction patterns to be defined and reused for multiple
+ purposes.
+* The patterns are "type checked" at compile time, detecting lots of bugs
+ early and eliminating redundancy from the pattern specifications.
+* It allows the use of general C++ code for weird/complex cases.
+
+While there is a lot that is good here, there is also a lot that is bad:
+
+* All of this machinery is only applicable to instruction selection. Even
+ directly adjacent problems like the DAGCombiner and Legalizer can't use it.
+* This isn't extensible at compiler runtime; you have to rebuild the compiler
+  to extend it.
+* The error messages when failing to match a pattern
+ [are not exactly optimal](https://www.google.com/search?q=llvm+cannot+select).
+* It has lots of implementation problems and limitations (e.g. can't write a
+ pattern for a multi-result operation) as a result of working with the
+ awkward SelectionDAG representation and being designed and implemented
+ lazily.
+* This stuff all grew organically over time and has lots of sharp edges.
+
+### Summary
+
+MLIR will face a wide range of pattern matching and graph rewrite problems, and
+one of the major advantages of having a common representation for code at
+multiple levels is that it allows us to invest in - and highly leverage - a
+single infrastructure for doing this sort of work.
+
+## Goals
+
+This proposal includes support for defining pattern matching and rewrite
+algorithms on MLIR. We'd like these algorithms to encompass many problems in the
+MLIR space, including 1-to-N expansions (e.g. as seen in the TF/XLA bridge when
+lowering a "tf.AddN" to multiple "add" HLOs), M-to-1 patterns (as seen in
+Grappler optimization passes, e.g. that convert multiply/add into a single
+muladd op), as well as general M-to-N patterns (e.g. instruction selection for
+target instructions). Patterns should have a cost associated with them, and the
+common infrastructure should be responsible for sorting out the lowest cost
+match for a given application.
+
+We separate the task of picking a particular locally optimal pattern from a
+given root node, the algorithm used to rewrite an entire graph given a
+particular set of goals, and the definition of the patterns themselves. We do
+this because DAG tile pattern matching is NP-complete, which means that there
+are no known polynomial time algorithms to optimally solve this problem.
+Additionally, we would like to support iterative rewrite algorithms that
+progressively transform the input program through multiple steps. Furthermore,
+we would like to support many different sorts of clients across the MLIR stack,
+and they may have different tolerances for compile time cost, different demands
+for optimality, and other algorithmic goals or constraints.
+
+We aim for MLIR transformations to be easy to implement and to reduce the
+likelihood of compiler bugs. We expect there to be a very large number of
+patterns defined over time, and we believe that these sorts of patterns will
+have a very large number of legality/validity constraints - many of which are
+difficult to reason about in a consistent way, may be target specific, and whose
+implementation may be particularly bug-prone. As such, we aim to design the API
+around pattern definition to be simple, resilient to programmer errors, and to
+allow a separation of concerns between the legality of the generated nodes and
+the idea of the pattern being defined.
+
+Finally, error handling is a topmost concern: in addition to allowing patterns
+to be defined in a target-independent way that may not apply for all hardware,
+we also want the failure of any pattern to match to be diagnosable in a reasonable
+way. To be clear, this is not a solvable problem in general - the space of
+malfunction is too great to be fully enumerated and handled optimally, but there
+are better and worse ways to handle the situation. MLIR is already designed to
+represent the provenance of an operation well. This project aims to propagate
+that provenance information precisely, as well as to diagnose pattern match
+failures with the rationale for why a set of patterns does not apply.
+
+### Non goals
+
+This proposal doesn't aim to solve all compiler problems; it is simply a
+DAG-to-DAG pattern matching system, starting with a greedy driver algorithm.
+Compiler algorithms that require global dataflow analysis (e.g. common
+subexpression elimination, conditional constant propagation, and many many
+others) will not be directly solved by this infrastructure.
+
+This proposal is limited to DAG patterns, which (by definition) prevent the
+patterns from seeing across cycles in a graph. In an SSA-based IR like MLIR,
+this means that these patterns don't see across PHI nodes / basic block
+arguments. We consider this acceptable given the set of problems we are trying
+to solve - we don't know of any other system that attempts to do so, and
+consider the payoff of worrying about this to be low.
+
+This design includes the ability for DAG patterns to have associated costs
+(benefits), but those costs are defined in terms of magic numbers (typically
+equal to the number of nodes being replaced). For any given application, the
+units of these magic numbers will have to be defined.
+
+## Overall design
+
+We decompose the problem into four major pieces:
+
+1. the code that is used to define the patterns: what to match, the cost, and
+   the replacement actions
+1. the driver logic to pick the best match for a given root node
+1. the client that is implementing some transformation (e.g. a combiner)
+1. (future) the subsystem that allows patterns to be described with a
+ declarative syntax, which sugars step #1.
+
+We sketch the first three of these pieces in turn. This is not intended to be a
+concrete API proposal, merely to describe the design.
+
+### Defining Patterns
+
+Each pattern will be an instance of the `mlir::Pattern` class, whose subclasses
+implement methods like the ones below. Note that this API is meant for
+exposition; the actual details differ for efficiency and coding standards
+reasons (e.g. the memory management of `PatternState` is not specified below,
+etc.):
+
+```c++
+class Pattern {
+ /// Return the benefit (the inverse of "cost") of matching this pattern. The
+ /// benefit of a Pattern is always static - rewrites that may have dynamic
+ /// benefit can be instantiated multiple times (different Pattern instances)
+ /// for each benefit that they may return, and be guarded by different match
+ /// condition predicates.
+ PatternBenefit getBenefit() const { return benefit; }
+
+ /// Return the root node that this pattern matches. Patterns that can
+ /// match multiple root types are instantiated once per root.
+ OperationName getRootKind() const { return rootKind; }
+
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). On failure, this
+  /// returns a None value. On success, it returns a (possibly null)
+  /// pattern-specific state wrapped in a Some. This state is passed back into
+  /// its rewrite function if this match is selected.
+ virtual Optional<PatternState*> match(Operation *op) const = 0;
+
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter. If an unexpected error is encountered (an internal
+ /// compiler error), it is emitted through the normal MLIR diagnostic
+ /// hooks and the IR is left in a valid state.
+ virtual void rewrite(Operation *op, PatternState *state,
+ PatternRewriter &rewriter) const;
+};
+```
+
+In practice, the first patterns we implement will directly subclass and
+implement this stuff, but we will define some helpers to reduce boilerplate.
+When we have a declarative way to describe patterns, this should be
+automatically generated from the description.
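+
+As a rough illustration of what such a subclass could look like, here is a
+sketch written against the expository API above; the op name `mydialect.add`,
+the `isZeroConstant` helper, and the `replaceOp` call on the rewriter are all
+invented placeholders for functionality the real API would provide, not part of
+the proposal:
+
+```c++
+/// Sketch only: replace "mydialect.add(x, 0)" with "x".
+struct SimplifyAddZero : public Pattern {
+  SimplifyAddZero(MLIRContext *context)
+      // The benefit is 1 because a single operation is removed.
+      : Pattern("mydialect.add", /*benefit=*/1, context) {}
+
+  Optional<PatternState *> match(Operation *op) const override {
+    // Match only when the right operand is a known zero constant; this check
+    // stands in for whatever constant detection the dialect provides.
+    if (!isZeroConstant(op->getOperand(1)))
+      return None;
+    // Match succeeded; this pattern needs no extra state.
+    return Optional<PatternState *>(nullptr);
+  }
+
+  void rewrite(Operation *op, PatternState *state,
+               PatternRewriter &rewriter) const override {
+    // Replace all uses of the result with the left operand and erase the op.
+    rewriter.replaceOp(op, op->getOperand(0));
+  }
+};
+```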
+
+Instances of `Pattern` have a benefit that is static upon construction of the
+pattern instance, but may be computed dynamically at pattern initialization
+time, e.g. allowing the benefit to be derived from domain specific information
+(like the target architecture). This limitation allows MLIR to (eventually)
+perform pattern fusion and compile patterns into an efficient state machine, and
+[Thier, Ertl, and Krall](https://dl.acm.org/citation.cfm?id=3179501) have shown
+that match predicates eliminate the need for dynamically computed costs in
+almost all cases: you can simply instantiate the same pattern one time for each
+possible cost and use the predicate to guard the match.
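+
+For example, a pattern whose payoff depends on the target could be registered
+twice with different static benefits, each guarded by a predicate in its
+`match` hook. The sketch below follows the expository API above; `hasFusedMulAdd`,
+`isMulAddShape`, and the op name are invented for illustration:
+
+```c++
+/// Sketch only: the same mul+add fusion registered with two static benefits.
+struct FuseMulAddFast : public Pattern {
+  FuseMulAddFast(MLIRContext *context)
+      : Pattern("mydialect.add", /*benefit=*/3, context) {}
+  Optional<PatternState *> match(Operation *op) const override {
+    // High benefit, but only matches on targets with a fused multiply-add.
+    if (!hasFusedMulAdd(op) || !isMulAddShape(op))
+      return None;
+    return Optional<PatternState *>(nullptr);
+  }
+  // rewrite() is identical in both instances and omitted from this sketch.
+};
+
+struct FuseMulAddGeneric : public Pattern {
+  FuseMulAddGeneric(MLIRContext *context)
+      : Pattern("mydialect.add", /*benefit=*/1, context) {}
+  Optional<PatternState *> match(Operation *op) const override {
+    // Lower-benefit fallback that matches everywhere.
+    if (!isMulAddShape(op))
+      return None;
+    return Optional<PatternState *>(nullptr);
+  }
+  // rewrite() is identical in both instances and omitted from this sketch.
+};
+```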
+
+The two-phase nature of this API (match separate from rewrite) is important for
+two reasons: 1) some clients may want to explore different ways to tile the
+graph, and only rewrite after committing to one tiling; 2) we want to support
+runtime extensibility of the pattern sets, but want to be able to statically
+compile the bulk of known patterns into a state machine at "compiler compile
+time". Both of these reasons lead to us needing to match multiple patterns
+before committing to an answer.
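+
+To illustrate reason (1), a client could, in terms of the expository API,
+collect matches from several candidate patterns and only rewrite after choosing
+one. This is purely a sketch; `patternsRootedAt` is an assumed lookup helper:
+
+```c++
+// Sketch only: collect matches from several candidate patterns first, and only
+// mutate the IR once one of them has been chosen.
+void tileAndRewrite(Operation *op, PatternRewriter &rewriter) {
+  struct Candidate {
+    Pattern *pattern;
+    PatternState *state;
+  };
+  SmallVector<Candidate, 4> candidates;
+
+  // Phase 1: match. Nothing in the IR changes here.
+  for (Pattern *pattern : patternsRootedAt(op->getName()))
+    if (Optional<PatternState *> state = pattern->match(op))
+      candidates.push_back({pattern, *state});
+  if (candidates.empty())
+    return;
+
+  // Phase 2: commit to the highest-benefit candidate and rewrite.
+  Candidate best = *std::max_element(
+      candidates.begin(), candidates.end(),
+      [](const Candidate &a, const Candidate &b) {
+        return a.pattern->getBenefit() < b.pattern->getBenefit();
+      });
+  best.pattern->rewrite(op, best.state, rewriter);
+}
+```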
+
+### Picking and performing a replacement
+
+In the short term, this API can be very simple; something like this can work and
+will be useful for many clients:
+
+```c++
+class PatternMatcher {
+ // Create a pattern matcher with a bunch of patterns. This constructor
+ // looks across all of the specified patterns, and builds an internal
+ // data structure that allows efficient matching.
+ PatternMatcher(ArrayRef<Pattern*> patterns);
+
+  // Given a specific operation, see if there is some rewrite that is
+  // interesting. If so, perform it, return success, and populate the list of
+  // newly created operations. If not, return failure.
+ bool matchAndRewrite(Operation *op,
+ SmallVectorImpl<Operation*> &newlyCreatedOps);
+};
+```
+
+In practice the interesting part of this class is the acceleration structure it
+builds internally. It buckets up the patterns by root operation, and sorts them
+by their static benefit. When performing a match, it tests any dynamic patterns,
+then tests statically known patterns from highest to lowest benefit.
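+
+For illustration, the bucketing could look roughly like the following; the
+container and member names are not part of the proposal:
+
+```c++
+// Sketch of one possible internal layout for the expository PatternMatcher.
+class PatternMatcher {
+public:
+  PatternMatcher(ArrayRef<Pattern *> patterns) {
+    // Bucket the patterns by the root operation kind they match...
+    for (Pattern *pattern : patterns)
+      patternsByRoot[pattern->getRootKind()].push_back(pattern);
+    // ...and keep each bucket sorted from highest to lowest static benefit, so
+    // matching can stop at the first (i.e. best) pattern that succeeds.
+    for (auto &bucket : patternsByRoot)
+      llvm::sort(bucket.second, [](Pattern *lhs, Pattern *rhs) {
+        return rhs->getBenefit() < lhs->getBenefit();
+      });
+  }
+
+  bool matchAndRewrite(Operation *op,
+                       SmallVectorImpl<Operation *> &newlyCreatedOps);
+
+private:
+  // Only the bucket for op's kind is ever consulted when matching op.
+  DenseMap<OperationName, std::vector<Pattern *>> patternsByRoot;
+};
+```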
+
+### First Client: A Greedy Worklist Combiner
+
+We expect that there will be lots of clients for this, but a simple greedy
+worklist-driven combiner should be powerful enough to serve many important ones,
+including the
+[TF2XLA op expansion logic](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels),
+many of the pattern substitution passes of the
+[TOCO compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco)
+for TF-Lite, many
+[Grappler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/grappler)
+passes, and other general performance optimizations for applying identities.
+
+The structure of this algorithm is straightforward; here is pseudo code (a rough
+C++ sketch of the same loop follows the list):
+
+* Walk a function in preorder, adding each operation to a worklist.
+* While the worklist is non-empty, pull something off the back (processing
+  things generally in postorder):
+    * Perform matchAndRewrite on the operation. If it fails, continue to the
+      next operation.
+ * On success, add the newly created ops to the worklist and continue.
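+
+A rough C++ sketch of that loop, in terms of the expository `PatternMatcher`
+above; the `walk` traversal helper and the handling of erased operations are
+simplified here:
+
+```c++
+// Sketch only: a greedy, worklist-driven application of a set of patterns.
+void applyPatternsGreedily(Function &function, PatternMatcher &matcher) {
+  // Seed the worklist with every operation in the function, in preorder.
+  std::vector<Operation *> worklist;
+  function.walk([&](Operation *op) { worklist.push_back(op); });
+
+  SmallVector<Operation *, 8> newlyCreatedOps;
+  while (!worklist.empty()) {
+    // Pull from the back, so operations are generally processed in postorder.
+    Operation *op = worklist.back();
+    worklist.pop_back();
+
+    // Try the best applicable pattern; if none matches, move on.
+    newlyCreatedOps.clear();
+    if (!matcher.matchAndRewrite(op, newlyCreatedOps))
+      continue;
+
+    // Revisit anything the rewrite produced, since the new operations may
+    // enable further matches. A real driver must also cope with operations
+    // that were erased while still sitting on the worklist.
+    worklist.insert(worklist.end(), newlyCreatedOps.begin(),
+                    newlyCreatedOps.end());
+  }
+}
+```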
+
+## Future directions
+
+It is important to get implementation and usage experience with this, and many
+patterns can be defined using this sort of framework. Over time, we can look to
+make it easier to define patterns in a declarative form (e.g. with the LLVM
+tblgen tool or something newer/better). Once we have that, we can define an
+internal abstraction for describing the patterns to match, allowing better high
+level optimization of patterns (including fusion of the matching logic across
+patterns, which the LLVM instruction selector does) and allowing the patterns to
+be defined without rebuilding the compiler itself.
PatternState() {}
};
-/// This is the type returned by a pattern match. The first field indicates
-/// the benefit of the match, the second is a state token that can optionally
-/// be produced by a pattern match to maintain state between the match and
-/// rewrite phases.
-using PatternMatchResult =
- std::pair<PatternBenefit, std::unique_ptr<PatternState>>;
+/// This is the type returned by a pattern match. A match failure returns a
+/// None value. A match success returns a Some value with any state the pattern
+/// may need to maintain (but may also be null).
+using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
//===----------------------------------------------------------------------===//
// Pattern class
class Pattern {
public:
- // Return the benefit (the inverse of "cost") of matching this pattern,
- // if it is statically determinable. The result is a PatternBenefit if known,
- // or 'None' if the cost is dynamically computed.
- Optional<PatternBenefit> getStaticBenefit() const;
+ /// Return the benefit (the inverse of "cost") of matching this pattern. The
+ /// benefit of a Pattern is always static - rewrites that may have dynamic
+ /// benefit can be instantiated multiple times (different Pattern instances)
+ /// for each benefit that they may return, and be guarded by different match
+ /// condition predicates.
+ PatternBenefit getBenefit() const { return benefit; }
- // Return the root node that this pattern matches. Patterns that can
- // match multiple root types are instantiated once per root.
- OperationName getRootKind() const;
+ /// Return the root node that this pattern matches. Patterns that can
+ /// match multiple root types are instantiated once per root.
+ OperationName getRootKind() const { return rootKind; }
//===--------------------------------------------------------------------===//
// Implementation hooks for patterns to implement.
//===--------------------------------------------------------------------===//
- // Attempt to match against code rooted at the specified operation,
- // which is the same operation code as getRootKind(). On success it
- // returns the benefit of the match along with an (optional)
- // pattern-specific state which is passed back into its rewrite
- // function if this match is selected. On failure, this returns a
- // sentinel indicating that it didn’t match.
+ /// Attempt to match against code rooted at the specified operation,
+ /// which is the same operation code as getRootKind(). On failure, this
+  /// returns a None value. On success, it returns a (possibly null)
+  /// pattern-specific state wrapped in a Some. This state is passed back into
+  /// its rewrite function if this match is selected.
virtual PatternMatchResult match(Operation *op) const = 0;
- // Rewrite the IR rooted at the specified operation with the result of
- // this pattern, generating any new operations with the specified
- // builder. If an unexpected error is encountered (an internal
- // compiler error), it is emitted through the normal MLIR diagnostic
- // hooks and the IR is left in a valid state.
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// rewriter. If an unexpected error is encountered (an internal
+ /// compiler error), it is emitted through the normal MLIR diagnostic
+ /// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const;
- // Rewrite the IR rooted at the specified operation with the result of
- // this pattern, generating any new operations with the specified
- // builder. If an unexpected error is encountered (an internal
- // compiler error), it is emitted through the normal MLIR diagnostic
- // hooks and the IR is left in a valid state.
+ /// Rewrite the IR rooted at the specified operation with the result of
+ /// this pattern, generating any new operations with the specified
+ /// builder. If an unexpected error is encountered (an internal
+ /// compiler error), it is emitted through the normal MLIR diagnostic
+ /// hooks and the IR is left in a valid state.
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
virtual ~Pattern() {}
//===--------------------------------------------------------------------===//
/// This method indicates that no match was found.
- static PatternMatchResult matchFailure();
+ static PatternMatchResult matchFailure() { return None; }
/// This method indicates that a match was found and has the specified cost.
PatternMatchResult
- matchSuccess(PatternBenefit benefit,
- std::unique_ptr<PatternState> state = {}) const;
-
- /// This method indicates that a match was found for patterns that have a
- /// known static benefit.
- PatternMatchResult
- matchSuccess(std::unique_ptr<PatternState> state = {}) const;
+ matchSuccess(std::unique_ptr<PatternState> state = {}) const {
+ return PatternMatchResult(std::move(state));
+ }
protected:
/// Patterns must specify the root operation name they match against, and can
- /// also optionally specify a static benefit of matching.
- Pattern(StringRef rootName, MLIRContext *context,
- Optional<PatternBenefit> staticBenefit = llvm::None);
-
- Pattern(StringRef rootName, MLIRContext *context, unsigned staticBenefit);
+ /// also specify the benefit of the pattern matching.
+ Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
private:
const OperationName rootKind;
- const Optional<PatternBenefit> staticBenefit;
+ const PatternBenefit benefit;
};
//===----------------------------------------------------------------------===//