nir: Add support for 2src_commutative ops that have 3 sources
author Ian Romanick <ian.d.romanick@intel.com>
Thu, 9 May 2019 22:33:11 +0000 (15:33 -0700)
committer Ian Romanick <ian.d.romanick@intel.com>
Tue, 14 May 2019 18:25:02 +0000 (11:25 -0700)
v2: Instead of handling 3 sources as a special case, generalize with
loops to N sources.  Suggested by Jason.

v3: Further generalize by only checking that the number of sources is >= 2.
Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
src/compiler/nir/nir_algebraic.py
src/compiler/nir/nir_instr_set.c
src/compiler/nir/nir_search.c

index d945c1a..aa4e977 100644 (file)
@@ -796,12 +796,12 @@ class TreeAutomaton(object):
       self.opcodes = self.IndexMap()
 
       def get_item(opcode, children, pattern=None):
-         commutative = len(children) == 2 \
+         commutative = len(children) >= 2 \
                and "2src_commutative" in opcodes[opcode].algebraic_properties
          item = self.items.setdefault((opcode, children),
                                       self.Item(opcode, children))
          if commutative:
-            self.items[opcode, (children[1], children[0])] = item
+            self.items[opcode, (children[1], children[0]) + children[2:]] = item
          if pattern is not None:
             item.patterns.append(pattern)
          return item
index c6a69d3..80c0312 100644 (file)
@@ -57,7 +57,8 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
    /* We explicitly don't hash instr->dest.dest.exact */
 
    if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
-      assert(nir_op_infos[instr->op].num_inputs == 2);
+      assert(nir_op_infos[instr->op].num_inputs >= 2);
+
       uint32_t hash0 = hash_alu_src(hash, &instr->src[0],
                                     nir_ssa_alu_instr_src_components(instr, 0));
       uint32_t hash1 = hash_alu_src(hash, &instr->src[1],
@@ -69,6 +70,11 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
        * collision.  Either addition or multiplication will also work.
        */
       hash = hash0 * hash1;
+
+      for (unsigned i = 2; i < nir_op_infos[instr->op].num_inputs; i++) {
+         hash = hash_alu_src(hash, &instr->src[i],
+                             nir_ssa_alu_instr_src_components(instr, i));
+      }
    } else {
       for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
          hash = hash_alu_src(hash, &instr->src[i],
@@ -529,11 +535,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
       /* We explicitly don't hash instr->dest.dest.exact */
 
       if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
-         assert(nir_op_infos[alu1->op].num_inputs == 2);
-         return (nir_alu_srcs_equal(alu1, alu2, 0, 0) &&
-                 nir_alu_srcs_equal(alu1, alu2, 1, 1)) ||
-                (nir_alu_srcs_equal(alu1, alu2, 0, 1) &&
-                 nir_alu_srcs_equal(alu1, alu2, 1, 0));
+         if ((!nir_alu_srcs_equal(alu1, alu2, 0, 0) ||
+              !nir_alu_srcs_equal(alu1, alu2, 1, 1)) &&
+             (!nir_alu_srcs_equal(alu1, alu2, 0, 1) ||
+              !nir_alu_srcs_equal(alu1, alu2, 1, 0)))
+            return false;
+
+         for (unsigned i = 2; i < nir_op_infos[alu1->op].num_inputs; i++) {
+            if (!nir_alu_srcs_equal(alu1, alu2, i, i))
+               return false;
+         }
       } else {
          for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
             if (!nir_alu_srcs_equal(alu1, alu2, i, i))
index 6d3fbf7..3ddda7c 100644 (file)
@@ -408,7 +408,11 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
 
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
+      /* 2src_commutative instructions that have 3 sources are only commutative
+       * in the first two sources.  Source 2 is always source 2.
+       */
+      if (!match_value(expr->srcs[i], instr,
+                       i < 2 ? i ^ comm_op_flip : i,
                        num_components, swizzle, state)) {
          matched = false;
          break;