nir/search: Add automaton-based pre-searching
authorConnor Abbott <cwabbott0@gmail.com>
Mon, 18 Feb 2019 13:20:34 +0000 (14:20 +0100)
committerConnor Abbott <cwabbott0@gmail.com>
Thu, 2 May 2019 14:14:06 +0000 (16:14 +0200)
nir_opt_algebraic is currently one of the most expensive NIR passes,
because of the many different patterns we've added over the years. Even
though patterns are already sorted by opcode, there are still way too
many patterns for common opcodes like bcsel and fadd, which means that
many patterns are tried but only a few actually match. One way to fix
this is to add a pre-pass over the code that scans it using an automaton
constructed beforehand, similar to the automatons produced by lex and
yacc for parsing source code. This automaton has to walk the SSA graph
and recognize possible pattern matches.

It turns out that the theory to do this is quite mature already, having
been developed for instruction selection as well as other non-compiler
things. I followed the presentation in the dissertation cited in the
code, "Tree algorithms: Two Taxonomies and a Toolkit," trying to keep
the naming similar. To create the automaton, we have to perform
something like the classical NFA to DFA subset construction used by lex,
but it turns out that actually computing the transition table for all
possible states would be way too expensive, with the dissertation
reporting times of almost half an hour for an example of size similar to
nir_opt_algebraic. Instead, we adopt one of the "filter" approaches
explained in the dissertation, which trade much faster table generation
and table size for a few more table lookups per instruction at runtime.
I chose the filter which resulted the fastest table generation time,
with medium table size. Right now, the table generation takes around .5
seconds, despite being implemented in pure Python, which I think is good
enough. Based on the numbers in the dissertation, the other choice might
make table compilation time 25x slower to get 4x smaller table size, but
I don't think that's worth it. As of now, we get the following binary
size before and after this patch:

    text   data     bss      dec    hex filename
11979455 464720  730864 13175039 c908ff before i965_dri.so
   text    data     bss     dec            hex filename
12037835 616244  791792 13445871 cd2aef after i965_dri.so

There are a number of places where I've simplified the automaton by
getting rid of details in the LHS patterns rather than complicate things
to deal with them. For example, right now the automaton doesn't
distinguish between constants with different values. This means that it
isn't as precise as it could be, but the decrease in compile time is
still worth it -- these are the compilation time numbers for a shader-db
run with my (admittedly old) database on Intel skylake:

Difference at 95.0% confidence
-42.3485 +/- 1.375
-7.20383% +/- 0.229926%
(Student's t, pooled s = 1.69843)

We can always experiment with making it more precise later.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
src/compiler/nir/nir_algebraic.py
src/compiler/nir/nir_search.c
src/compiler/nir/nir_search.h

index 4779507..6db749e 100644 (file)
@@ -51,6 +51,13 @@ conv_opcode_types = {
     'f2b' : 'bool',
 }
 
+def get_c_opcode(op):
+      if op in conv_opcode_types:
+         return 'nir_search_op_' + op
+      else:
+         return 'nir_op_' + op
+
+
 if sys.version_info < (3, 0):
     integer_types = (int, long)
     string_type = unicode
@@ -347,10 +354,7 @@ class Expression(Value):
       return self.comm_exprs
 
    def c_opcode(self):
-      if self.opcode in conv_opcode_types:
-         return 'nir_search_op_' + self.opcode
-      else:
-         return 'nir_op_' + self.opcode
+      return get_c_opcode(self.opcode)
 
    def render(self, cache):
       srcs = "\n".join(src.render(cache) for src in self.sources)
@@ -692,6 +696,266 @@ class SearchAndReplace(object):
 
       BitSizeValidator(varset).validate(self.search, self.replace)
 
+class TreeAutomaton(object):
+   """This class calculates a bottom-up tree automaton to quickly search for
+   the left-hand sides of tranforms. Tree automatons are a generalization of
+   classical NFA's and DFA's, where the transition function determines the
+   state of the parent node based on the state of its children. We construct a
+   deterministic automaton to match patterns, using a similar algorithm to the
+   classical NFA to DFA construction. At the moment, it only matches opcodes
+   and constants (without checking the actual value), leaving more detailed
+   checking to the search function which actually checks the leaves. The
+   automaton acts as a quick filter for the search function, requiring only n
+   + 1 table lookups for each n-source operation. The implementation is based
+   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
+   In the language of that reference, this is a frontier-to-root deterministic
+   automaton using only symbol filtering. The filtering is crucial to reduce
+   both the time taken to generate the tables and the size of the tables.
+   """
+   def __init__(self, transforms):
+      self.patterns = [t.search for t in transforms]
+      self._compute_items()
+      self._build_table()
+      #print('num items: {}'.format(len(set(self.items.values()))))
+      #print('num states: {}'.format(len(self.states)))
+      #for state, patterns in zip(self.states, self.patterns):
+      #   print('{}: num patterns: {}'.format(state, len(patterns)))
+
+   class IndexMap(object):
+      """An indexed list of objects, where one can either lookup an object by
+      index or find the index associated to an object quickly using a hash
+      table. Compared to a list, it has a constant time index(). Compared to a
+      set, it provides a stable iteration order.
+      """
+      def __init__(self, iterable=()):
+         self.objects = []
+         self.map = {}
+         for obj in iterable:
+            self.add(obj)
+
+      def __getitem__(self, i):
+         return self.objects[i]
+
+      def __contains__(self, obj):
+         return obj in self.map
+
+      def __len__(self):
+         return len(self.objects)
+
+      def __iter__(self):
+         return iter(self.objects)
+
+      def clear(self):
+         self.objects = []
+         self.map.clear()
+
+      def index(self, obj):
+         return self.map[obj]
+
+      def add(self, obj):
+         if obj in self.map:
+            return self.map[obj]
+         else:
+            index = len(self.objects)
+            self.objects.append(obj)
+            self.map[obj] = index
+            return index
+
+      def __repr__(self):
+         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
+
+   class Item(object):
+      """This represents an "item" in the language of "Tree Automatons." This
+      is just a subtree of some pattern, which represents a potential partial
+      match at runtime. We deduplicate them, so that identical subtrees of
+      different patterns share the same object, and store some extra
+      information needed for the main algorithm as well.
+      """
+      def __init__(self, opcode, children):
+         self.opcode = opcode
+         self.children = children
+         # These are the indices of patterns for which this item is the root node.
+         self.patterns = []
+         # This the set of opcodes for parents of this item. Used to speed up
+         # filtering.
+         self.parent_ops = set()
+
+      def __str__(self):
+         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
+
+      def __repr__(self):
+         return str(self)
+
+   def _compute_items(self):
+      """Build a set of all possible items, deduplicating them."""
+      # This is a map from (opcode, sources) to item.
+      self.items = {}
+
+      # The set of all opcodes used by the patterns. Used later to avoid
+      # building and emitting all the tables for opcodes that aren't used.
+      self.opcodes = self.IndexMap()
+
+      def get_item(opcode, children, pattern=None):
+         commutative = len(children) == 2 \
+               and "commutative" in opcodes[opcode].algebraic_properties
+         item = self.items.setdefault((opcode, children),
+                                      self.Item(opcode, children))
+         if commutative:
+            self.items[opcode, (children[1], children[0])] = item
+         if pattern is not None:
+            item.patterns.append(pattern)
+         return item
+
+      self.wildcard = get_item("__wildcard", ())
+      self.const = get_item("__const", ())
+
+      def process_subpattern(src, pattern=None):
+         if isinstance(src, Constant):
+            # Note: we throw away the actual constant value!
+            return self.const
+         elif isinstance(src, Variable):
+            if src.is_constant:
+               return self.const
+            else:
+               # Note: we throw away which variable it is here! This special
+               # item is equivalent to nu in "Tree Automatons."
+               return self.wildcard
+         else:
+            assert isinstance(src, Expression)
+            opcode = src.opcode
+            stripped = opcode.rstrip('0123456789')
+            if stripped in conv_opcode_types:
+               # Matches that use conversion opcodes with a specific type,
+               # like f2b1, are tricky.  Either we construct the automaton to
+               # match specific NIR opcodes like nir_op_f2b1, in which case we
+               # need to create separate items for each possible NIR opcode
+               # for patterns that have a generic opcode like f2b, or we
+               # construct it to match the search opcode, in which case we
+               # need to map f2b1 to f2b when constructing the automaton. Here
+               # we do the latter.
+               opcode = stripped
+            self.opcodes.add(opcode)
+            children = tuple(process_subpattern(c) for c in src.sources)
+            item = get_item(opcode, children, pattern)
+            for i, child in enumerate(children):
+               child.parent_ops.add(opcode)
+            return item
+
+      for i, pattern in enumerate(self.patterns):
+         process_subpattern(pattern, i)
+
+   def _build_table(self):
+      """This is the core algorithm which builds up the transition table. It
+      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
+      Comp_a and Filt_{a,i} using integers to identify match sets." It
+      simultaneously builds up a list of all possible "match sets" or
+      "states", where each match set represents the set of Item's that match a
+      given instruction, and builds up the transition table between states.
+      """
+      # Map from opcode + filtered state indices to transitioned state.
+      self.table = defaultdict(dict)
+      # Bijection from state to index. q in the original algorithm is
+      # len(self.states)
+      self.states = self.IndexMap()
+      # List of pattern matches for each state index.
+      self.state_patterns = []
+      # Map from state index to filtered state index for each opcode.
+      self.filter = defaultdict(list)
+      # Bijections from filtered state to filtered state index for each
+      # opcode, called the "representor sets" in the original algorithm.
+      # q_{a,j} in the original algorithm is len(self.rep[op]).
+      self.rep = defaultdict(self.IndexMap)
+
+      # Everything in self.states with a index at least worklist_index is part
+      # of the worklist of newly created states. There is also a worklist of
+      # newly fitered states for each opcode, for which worklist_indices
+      # serves a similar purpose. worklist_index corresponds to p in the
+      # original algorithm, while worklist_indices is p_{a,j} (although since
+      # we only filter by opcode/symbol, it's really just p_a).
+      self.worklist_index = 0
+      worklist_indices = defaultdict(lambda: 0)
+
+      # This is the set of opcodes for which the filtered worklist is non-empty.
+      # It's used to avoid scanning opcodes for which there is nothing to
+      # process when building the transition table. It corresponds to new_a in
+      # the original algorithm.
+      new_opcodes = self.IndexMap()
+
+      # Process states on the global worklist, filtering them for each opcode,
+      # updating the filter tables, and updating the filtered worklists if any
+      # new filtered states are found. Similar to ComputeRepresenterSets() in
+      # the original algorithm, although that only processes a single state.
+      def process_new_states():
+         while self.worklist_index < len(self.states):
+            state = self.states[self.worklist_index]
+
+            # Calculate pattern matches for this state. Each pattern is
+            # assigned to a unique item, so we don't have to worry about
+            # deduplicating them here. However, we do have to sort them so
+            # that they're visited at runtime in the order they're specified
+            # in the source.
+            patterns = list(sorted(p for item in state for p in item.patterns))
+            assert len(self.state_patterns) == self.worklist_index
+            self.state_patterns.append(patterns)
+
+            # calculate filter table for this state, and update filtered
+            # worklists.
+            for op in self.opcodes:
+               filt = self.filter[op]
+               rep = self.rep[op]
+               filtered = frozenset(item for item in state if \
+                  op in item.parent_ops)
+               if filtered in rep:
+                  rep_index = rep.index(filtered)
+               else:
+                  rep_index = rep.add(filtered)
+                  new_opcodes.add(op)
+               assert len(filt) == self.worklist_index
+               filt.append(rep_index)
+            self.worklist_index += 1
+
+      # There are two start states: one which can only match as a wildcard,
+      # and one which can match as a wildcard or constant. These will be the
+      # states of intrinsics/other instructions and load_const instructions,
+      # respectively. The indices of these must match the definitions of
+      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
+      # initialize things correctly.
+      self.states.add(frozenset((self.wildcard,)))
+      self.states.add(frozenset((self.const,self.wildcard)))
+      process_new_states()
+
+      while len(new_opcodes) > 0:
+         for op in new_opcodes:
+            rep = self.rep[op]
+            table = self.table[op]
+            op_worklist_index = worklist_indices[op]
+            if op in conv_opcode_types:
+               num_srcs = 1
+            else:
+               num_srcs = opcodes[op].num_inputs
+
+            # Iterate over all possible source combinations where at least one
+            # is on the worklist.
+            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
+               if all(src_idx < op_worklist_index for src_idx in src_indices):
+                  continue
+
+               srcs = tuple(rep[src_idx] for src_idx in src_indices)
+
+               # Try all possible pairings of source items and add the
+               # corresponding parent items. This is Comp_a from the paper.
+               parent = set(self.items[op, item_srcs] for item_srcs in
+                  itertools.product(*srcs) if (op, item_srcs) in self.items)
+
+               # We could always start matching something else with a
+               # wildcard. This is Cl from the paper.
+               parent.add(self.wildcard)
+
+               table[src_indices] = self.states.add(frozenset(parent))
+            worklist_indices[op] = len(rep)
+         new_opcodes.clear()
+         process_new_states()
+
 _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
 #include "nir_builder.h"
@@ -707,6 +971,19 @@ struct transform {
    unsigned condition_offset;
 };
 
+struct per_op_table {
+   const uint16_t *filter;
+   unsigned num_filtered_states;
+   const uint16_t *table;
+};
+
+/* Note: these must match the start states created in
+ * TreeAutomaton._build_table()
+ */
+
+/* WILDCARD_STATE = 0 is set by zeroing the state array */
+static const uint16_t CONST_STATE = 1;
+
 #endif
 
 <% cache = {} %>
@@ -715,17 +992,80 @@ struct transform {
    ${xform.replace.render(cache)}
 % endfor
 
-% for (opcode, xform_list) in sorted(opcode_xforms.items()):
-static const struct transform ${pass_name}_${opcode}_xforms[] = {
-% for xform in xform_list:
-   { ${xform.search.c_ptr(cache)}, ${xform.replace.c_value_ptr(cache)}, ${xform.condition_index} },
+% for state_id, state_xforms in enumerate(automaton.state_patterns):
+static const struct transform ${pass_name}_state${state_id}_xforms[] = {
+% for i in state_xforms:
+  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
 % endfor
 };
 % endfor
 
+static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
+% for op in automaton.opcodes:
+   [${get_c_opcode(op)}] = {
+      .filter = (uint16_t []) {
+      % for e in automaton.filter[op]:
+         ${e},
+      % endfor
+      },
+      <%
+        num_filtered = len(automaton.rep[op])
+      %>
+      .num_filtered_states = ${num_filtered},
+      .table = (uint16_t []) {
+      <%
+        num_srcs = len(next(iter(automaton.table[op])))
+      %>
+      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
+         ${automaton.table[op][indices]},
+      % endfor
+      },
+   },
+% endfor
+};
+
+static void
+${pass_name}_pre_block(nir_block *block, uint16_t *states)
+{
+   nir_foreach_instr(instr, block) {
+      switch (instr->type) {
+      case nir_instr_type_alu: {
+         nir_alu_instr *alu = nir_instr_as_alu(instr);
+         nir_op op = alu->op;
+         uint16_t search_op = nir_search_op_for_nir_op(op);
+         const struct per_op_table *tbl = &${pass_name}_table[search_op];
+         if (tbl->num_filtered_states == 0)
+            continue;
+
+         /* Calculate the index into the transition table. Note the index
+          * calculated must match the iteration order of Python's
+          * itertools.product(), which was used to emit the transition
+          * table.
+          */
+         uint16_t index = 0;
+         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+            index *= tbl->num_filtered_states;
+            index += tbl->filter[states[alu->src[i].src.ssa->index]];
+         }
+         states[alu->dest.dest.ssa.index] = tbl->table[index];
+         break;
+      }
+
+      case nir_instr_type_load_const: {
+         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
+         states[load_const->def.index] = CONST_STATE;
+         break;
+      }
+
+      default:
+         break;
+      }
+   }
+}
+
 static bool
 ${pass_name}_block(nir_builder *build, nir_block *block,
-                   const bool *condition_flags)
+                   const uint16_t *states, const bool *condition_flags)
 {
    bool progress = false;
 
@@ -737,11 +1077,11 @@ ${pass_name}_block(nir_builder *build, nir_block *block,
       if (!alu->dest.dest.is_ssa)
          continue;
 
-      switch (alu->op) {
-      % for opcode in sorted(opcode_xforms.keys()):
-      case nir_op_${opcode}:
-         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
-            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
+      switch (states[alu->dest.dest.ssa.index]) {
+% for i in range(len(automaton.state_patterns)):
+      case ${i}:
+         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
+            const struct transform *xform = &${pass_name}_state${i}_xforms[i];
             if (condition_flags[xform->condition_offset] &&
                 nir_replace_instr(build, alu, xform->search, xform->replace)) {
                progress = true;
@@ -749,9 +1089,8 @@ ${pass_name}_block(nir_builder *build, nir_block *block,
             }
          }
          break;
-      % endfor
-      default:
-         break;
+% endfor
+      default: assert(0);
       }
    }
 
@@ -766,10 +1105,22 @@ ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
    nir_builder build;
    nir_builder_init(&build, impl);
 
+   /* Note: it's important here that we're allocating a zeroed array, since
+    * state 0 is the default state, which means we don't have to visit
+    * anything other than constants and ALU instructions.
+    */
+   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
+
+   nir_foreach_block(block, impl) {
+      ${pass_name}_pre_block(block, states);
+   }
+
    nir_foreach_block_reverse(block, impl) {
-      progress |= ${pass_name}_block(&build, block, condition_flags);
+      progress |= ${pass_name}_block(&build, block, states, condition_flags);
    }
 
+   free(states);
+
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_block_index |
                                   nir_metadata_dominance);
@@ -806,6 +1157,8 @@ ${pass_name}(nir_shader *shader)
 }
 """)
 
+      
+
 class AlgebraicPass(object):
    def __init__(self, pass_name, transforms):
       self.xforms = []
@@ -835,6 +1188,8 @@ class AlgebraicPass(object):
          else:
             self.opcode_xforms[xform.search.opcode].append(xform)
 
+      self.automaton = TreeAutomaton(self.xforms)
+
       if error:
          sys.exit(1)
 
@@ -843,4 +1198,7 @@ class AlgebraicPass(object):
       return _algebraic_pass_template.render(pass_name=self.pass_name,
                                              xforms=self.xforms,
                                              opcode_xforms=self.opcode_xforms,
-                                             condition_list=condition_list)
+                                             condition_list=condition_list,
+                                             automaton=self.automaton,
+                                             get_c_opcode=get_c_opcode,
+                                             itertools=itertools)
index df27a24..c8acdfb 100644 (file)
@@ -134,6 +134,50 @@ nir_op_matches_search_op(nir_op nop, uint16_t sop)
 
 #undef MATCH_FCONV_CASE
 #undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
+}
+
+uint16_t
+nir_search_op_for_nir_op(nir_op nop)
+{
+#define MATCH_FCONV_CASE(op) \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_ICONV_CASE(op) \
+   case nir_op_##op##8: \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_BCONV_CASE(op) \
+   case nir_op_##op##1: \
+   case nir_op_##op##32: \
+      return nir_search_op_##op;
+
+
+   switch (nop) {
+   MATCH_FCONV_CASE(i2f)
+   MATCH_FCONV_CASE(u2f)
+   MATCH_FCONV_CASE(f2f)
+   MATCH_ICONV_CASE(f2u)
+   MATCH_ICONV_CASE(f2i)
+   MATCH_ICONV_CASE(u2u)
+   MATCH_ICONV_CASE(i2i)
+   MATCH_FCONV_CASE(b2f)
+   MATCH_ICONV_CASE(b2i)
+   MATCH_BCONV_CASE(i2b)
+   MATCH_BCONV_CASE(f2b)
+   default:
+      return nop;
+   }
+
+#undef MATCH_FCONV_CASE
+#undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
 }
 
 static nir_op
@@ -187,6 +231,7 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size)
 
 #undef RET_FCONV_CASE
 #undef RET_ICONV_CASE
+#undef RET_BCONV_CASE
 }
 
 static bool
index 9dc09d2..526a498 100644 (file)
@@ -121,8 +121,11 @@ enum nir_search_op {
    nir_search_op_b2i,
    nir_search_op_i2b,
    nir_search_op_f2b,
+   nir_num_search_ops,
 };
 
+uint16_t nir_search_op_for_nir_op(nir_op op);
+
 typedef struct {
    nir_search_value value;