Introduce ordered_inputs option to graph_matcher to allow simpler matching of commuta...
authorSuharsh Sivakumar <suharshs@google.com>
Fri, 11 May 2018 17:51:24 +0000 (10:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 17:54:15 +0000 (10:54 -0700)
#18919

PiperOrigin-RevId: 196276502

tensorflow/contrib/quantize/python/graph_matcher.py
tensorflow/contrib/quantize/python/graph_matcher_test.py
tensorflow/contrib/quantize/python/quantize.py

index bacc707..aa3ca99 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
+import itertools
 
 
 class Pattern(object):
@@ -33,7 +34,7 @@ class Pattern(object):
 class OpTypePattern(Pattern):
   """A tree pattern that matches TF expressions with certain op types."""
 
-  def __init__(self, op_type, name=None, inputs=None):
+  def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True):
     """Initializes an OpTypePattern.
 
     Args:
@@ -48,16 +49,25 @@ class OpTypePattern(Pattern):
       inputs: Optional list of `Pattern`s or strings that specify the
         patterns for the inputs of a matching op. If None, this pattern accepts
         any inputs of a matching op.
+      ordered_inputs: Defaults to True. If False, will match any op that
+        matches a permutation of the inputs.
+
+    Raises:
+      ValueError: if too many inputs are provided when order_inputs is False.
     """
     self._op_type = op_type
     self._name = name
     if inputs is None:
       inputs = []
+    if len(inputs) > 8:
+      raise ValueError(
+          'Only < 8 inputs are allowed when ordered_inputs is False.')
     self._inputs = [
         input_pattern
         if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
         for input_pattern in inputs
     ]
+    self._ordered_inputs = ordered_inputs
 
   @property
   def name(self):
@@ -78,12 +88,23 @@ class OpTypePattern(Pattern):
     if len(op.inputs) != len(self._inputs):
       return None
 
-    for input_tensor, input_pattern in zip(op.inputs, self._inputs):
-      input_match_result = input_pattern.match(input_tensor.op, input_tensor)
-      if input_match_result is None:
-        return None
-      match_result.merge_from(input_match_result)
-    return match_result
+    input_patterns_list = [self._inputs]
+    # If order doesn't matter for the inputs, then make sure we match at least
+    # one permutation of the inputs.
+    if not self._ordered_inputs:
+      input_patterns_list = list(itertools.permutations(self._inputs))
+
+    for input_patterns in input_patterns_list:
+      match_failed = False
+      for input_tensor, input_pattern in zip(op.inputs, input_patterns):
+        input_match_result = input_pattern.match(input_tensor.op, input_tensor)
+        if input_match_result is None:
+          match_failed = True
+          break
+        match_result.merge_from(input_match_result)
+      if not match_failed:
+        return match_result
+    return None
 
 
 class OneofPattern(Pattern):
index 6d58757..be74164 100644 (file)
@@ -22,6 +22,7 @@ from tensorflow.contrib.framework.python import ops as contrib_ops
 from tensorflow.contrib.layers.python.layers import initializers
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.quantize.python import graph_matcher
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
@@ -163,6 +164,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase):
       self.assertEqual(match_result.get_tensor('slice'), slicing)
       self.assertEqual(match_result.get_op('transpose'), transpose.op)
 
+  def test_ordered_pattern(self):
+    #   +            +
+    #  / \          / \
+    # x   y  and   y   x  should both match when ordered inputs is False.
+    # Even when x and y are different operations.
+    g = ops.Graph()
+    with g.as_default():
+      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
+      y = constant_op.constant(1.0, dtype=dtypes.float32)
+      plus = x + y
+
+    add_pattern_a = graph_matcher.OpTypePattern(
+        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
+    add_pattern_b = graph_matcher.OpTypePattern(
+        'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
+    add_pattern_fail = graph_matcher.OpTypePattern(
+        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)
+    # Both add_pattern_a and add_pattern_b should match the graph since
+    # ordered_input was set False.
+    matcher_a = graph_matcher.GraphMatcher(add_pattern_a)
+    self.assertEqual([
+        match_result.get_op(add_pattern_a)
+        for match_result in matcher_a.match_graph(g)
+    ], [plus.op])
+    matcher_b = graph_matcher.GraphMatcher(add_pattern_b)
+    self.assertEqual([
+        match_result.get_op(add_pattern_b)
+        for match_result in matcher_b.match_graph(g)
+    ], [plus.op])
+    # But if ordered_inputs is True, the inputs list match should fail if not
+    # specified in the right order.
+    matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail)
+    self.assertEqual(
+        len([
+            match_result.get_op(add_pattern_fail)
+            for match_result in matcher_fail.match_graph(g)
+        ]), 0)
+
 
 if __name__ == '__main__':
   googletest.main()
index 60616ea..4e0de24 100644 (file)
@@ -233,37 +233,37 @@ def _FindLayersToQuantize(graph):
               weight_identity_pattern, weight_resource_var_pattern,
               folded_weight_pattern
           ])
-      ])
+      ],
+      ordered_inputs=False)
 
   folded_bias_mul_pattern = graph_matcher.OpTypePattern(
-      'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
+      'Mul',
+      inputs=[graph_matcher.OpTypePattern('*'), layer_pattern],
+      ordered_inputs=False)
   post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
-      'Add', inputs=[folded_bias_mul_pattern,
-                     graph_matcher.OpTypePattern('*')])
+      'Add',
+      inputs=[folded_bias_mul_pattern,
+              graph_matcher.OpTypePattern('*')],
+      ordered_inputs=False)
   folded_bias_add_pattern = graph_matcher.OpTypePattern(
       'Add',
       inputs=[
           post_layer_op_correction_pattern,
           graph_matcher.OpTypePattern('*')
-      ])
+      ],
+      ordered_inputs=False)
 
   bias_add_pattern = graph_matcher.OpTypePattern(
-      'Add|BiasAdd', inputs=[layer_pattern, '*'])
+      'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False)
 
   # The bias can come from the bias add or the folded bias add.
-  bypass_pattern_a = graph_matcher.OpTypePattern(
+  bypass_pattern = graph_matcher.OpTypePattern(
       'Add',
       inputs=[
           graph_matcher.OneofPattern(
               [bias_add_pattern, folded_bias_add_pattern]), '*'
-      ])
-  bypass_pattern_b = graph_matcher.OpTypePattern(
-      'Add',
-      inputs=[
-          '*',
-          graph_matcher.OneofPattern(
-              [bias_add_pattern, folded_bias_add_pattern])
-      ])
+      ],
+      ordered_inputs=False)
 
   # The input to the activation can come from bias add, fold bias add, the
   # bypasses.
@@ -273,15 +273,14 @@ def _FindLayersToQuantize(graph):
       '|'.join(_ACTIVATION_TYPES) + '|Identity',
       inputs=[
           graph_matcher.OneofPattern([
-              bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
-              bypass_pattern_b
+              bias_add_pattern,
+              folded_bias_add_pattern,
+              bypass_pattern,
           ])
       ])
 
-  post_activation_bypass_pattern_a = graph_matcher.OpTypePattern(
-      'Add', inputs=['*', activation_pattern])
-  post_activation_bypass_pattern_b = graph_matcher.OpTypePattern(
-      'Add', inputs=[activation_pattern, '*'])
+  post_activation_bypass_pattern = graph_matcher.OpTypePattern(
+      'Add', inputs=['*', activation_pattern], ordered_inputs=False)
 
   # The order of the following matching blocks is very important. Since matches
   # aren't guaranteed to be disjoint, we structure matches from largest to
@@ -297,10 +296,7 @@ def _FindLayersToQuantize(graph):
   # to ensure we don't match only the first part of this layer, missing the
   # post activation bypass node.
   post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
-      graph_matcher.OneofPattern([
-          post_activation_bypass_pattern_a,
-          post_activation_bypass_pattern_b,
-      ]))
+      post_activation_bypass_pattern)
   for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
     layer_op = match_result.get_op(layer_pattern)
     weight_tensor = match_result.get_tensor(weight_identity_pattern)
@@ -312,14 +308,9 @@ def _FindLayersToQuantize(graph):
     bias_add_op = match_result.get_op(bias_add_pattern)
     if bias_add_op is None:
       bias_add_op = match_result.get_op(folded_bias_add_pattern)
-    bypass_op = match_result.get_op(bypass_pattern_a)
-    if bypass_op is None:
-      bypass_op = match_result.get_op(bypass_pattern_b)
+    bypass_op = match_result.get_op(bypass_pattern)
     post_activation_bypass_op = match_result.get_op(
-        post_activation_bypass_pattern_a)
-    if post_activation_bypass_op is None:
-      post_activation_bypass_op = match_result.get_op(
-          post_activation_bypass_pattern_b)
+        post_activation_bypass_pattern)
     if layer_op not in matched_layer_set:
       matched_layer_set.add(layer_op)
       layer_matches.append(
@@ -340,9 +331,7 @@ def _FindLayersToQuantize(graph):
     bias_add_op = match_result.get_op(bias_add_pattern)
     if bias_add_op is None:
       bias_add_op = match_result.get_op(folded_bias_add_pattern)
-    bypass_op = match_result.get_op(bypass_pattern_a)
-    if bypass_op is None:
-      bypass_op = match_result.get_op(bypass_pattern_b)
+    bypass_op = match_result.get_op(bypass_pattern)
     if layer_op not in matched_layer_set:
       matched_layer_set.add(layer_op)
       layer_matches.append(