from __future__ import print_function
import abc
+import itertools
class Pattern(object):
class OpTypePattern(Pattern):
"""A tree pattern that matches TF expressions with certain op types."""
- def __init__(self, op_type, name=None, inputs=None):
+ def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True):
"""Initializes an OpTypePattern.
Args:
inputs: Optional list of `Pattern`s or strings that specify the
patterns for the inputs of a matching op. If None, this pattern accepts
any inputs of a matching op.
+ ordered_inputs: Defaults to True. If False, will match any op that
+ matches a permutation of the inputs.
+
+ Raises:
+ ValueError: if too many inputs are provided when order_inputs is False.
"""
self._op_type = op_type
self._name = name
if inputs is None:
inputs = []
+ if len(inputs) > 8:
+ raise ValueError(
+ 'Only < 8 inputs are allowed when ordered_inputs is False.')
self._inputs = [
input_pattern
if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
for input_pattern in inputs
]
+ self._ordered_inputs = ordered_inputs
@property
def name(self):
if len(op.inputs) != len(self._inputs):
return None
- for input_tensor, input_pattern in zip(op.inputs, self._inputs):
- input_match_result = input_pattern.match(input_tensor.op, input_tensor)
- if input_match_result is None:
- return None
- match_result.merge_from(input_match_result)
- return match_result
+ input_patterns_list = [self._inputs]
+ # If order doesn't matter for the inputs, then make sure we match at least
+ # one permutation of the inputs.
+ if not self._ordered_inputs:
+ input_patterns_list = list(itertools.permutations(self._inputs))
+
+ for input_patterns in input_patterns_list:
+ match_failed = False
+ for input_tensor, input_pattern in zip(op.inputs, input_patterns):
+ input_match_result = input_pattern.match(input_tensor.op, input_tensor)
+ if input_match_result is None:
+ match_failed = True
+ break
+ match_result.merge_from(input_match_result)
+ if not match_failed:
+ return match_result
+ return None
class OneofPattern(Pattern):
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
self.assertEqual(match_result.get_tensor('slice'), slicing)
self.assertEqual(match_result.get_op('transpose'), transpose.op)
+ def test_ordered_pattern(self):
+ # + +
+ # / \ / \
+ # x y and y x should both match when ordered inputs is False.
+ # Even when x and y are different operations.
+ g = ops.Graph()
+ with g.as_default():
+ x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
+ y = constant_op.constant(1.0, dtype=dtypes.float32)
+ plus = x + y
+
+ add_pattern_a = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
+ add_pattern_b = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
+ add_pattern_fail = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)
+ # Both add_pattern_a and add_pattern_b should match the graph since
+ # ordered_input was set False.
+ matcher_a = graph_matcher.GraphMatcher(add_pattern_a)
+ self.assertEqual([
+ match_result.get_op(add_pattern_a)
+ for match_result in matcher_a.match_graph(g)
+ ], [plus.op])
+ matcher_b = graph_matcher.GraphMatcher(add_pattern_b)
+ self.assertEqual([
+ match_result.get_op(add_pattern_b)
+ for match_result in matcher_b.match_graph(g)
+ ], [plus.op])
+ # But if ordered_inputs is True, the inputs list match should fail if not
+ # specified in the right order.
+ matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail)
+ self.assertEqual(
+ len([
+ match_result.get_op(add_pattern_fail)
+ for match_result in matcher_fail.match_graph(g)
+ ]), 0)
+
if __name__ == '__main__':
googletest.main()
weight_identity_pattern, weight_resource_var_pattern,
folded_weight_pattern
])
- ])
+ ],
+ ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
- 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
+ 'Mul',
+ inputs=[graph_matcher.OpTypePattern('*'), layer_pattern],
+ ordered_inputs=False)
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
- 'Add', inputs=[folded_bias_mul_pattern,
- graph_matcher.OpTypePattern('*')])
+ 'Add',
+ inputs=[folded_bias_mul_pattern,
+ graph_matcher.OpTypePattern('*')],
+ ordered_inputs=False)
folded_bias_add_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
post_layer_op_correction_pattern,
graph_matcher.OpTypePattern('*')
- ])
+ ],
+ ordered_inputs=False)
bias_add_pattern = graph_matcher.OpTypePattern(
- 'Add|BiasAdd', inputs=[layer_pattern, '*'])
+ 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False)
# The bias can come from the bias add or the folded bias add.
- bypass_pattern_a = graph_matcher.OpTypePattern(
+ bypass_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
graph_matcher.OneofPattern(
[bias_add_pattern, folded_bias_add_pattern]), '*'
- ])
- bypass_pattern_b = graph_matcher.OpTypePattern(
- 'Add',
- inputs=[
- '*',
- graph_matcher.OneofPattern(
- [bias_add_pattern, folded_bias_add_pattern])
- ])
+ ],
+ ordered_inputs=False)
# The input to the activation can come from bias add, fold bias add, the
# bypasses.
'|'.join(_ACTIVATION_TYPES) + '|Identity',
inputs=[
graph_matcher.OneofPattern([
- bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
- bypass_pattern_b
+ bias_add_pattern,
+ folded_bias_add_pattern,
+ bypass_pattern,
])
])
- post_activation_bypass_pattern_a = graph_matcher.OpTypePattern(
- 'Add', inputs=['*', activation_pattern])
- post_activation_bypass_pattern_b = graph_matcher.OpTypePattern(
- 'Add', inputs=[activation_pattern, '*'])
+ post_activation_bypass_pattern = graph_matcher.OpTypePattern(
+ 'Add', inputs=['*', activation_pattern], ordered_inputs=False)
# The order of the following matching blocks is very important. Since matches
# aren't guaranteed to be disjoint, we structure matches from largest to
# to ensure we don't match only the first part of this layer, missing the
# post activation bypass node.
post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
- graph_matcher.OneofPattern([
- post_activation_bypass_pattern_a,
- post_activation_bypass_pattern_b,
- ]))
+ post_activation_bypass_pattern)
for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
- bypass_op = match_result.get_op(bypass_pattern_a)
- if bypass_op is None:
- bypass_op = match_result.get_op(bypass_pattern_b)
+ bypass_op = match_result.get_op(bypass_pattern)
post_activation_bypass_op = match_result.get_op(
- post_activation_bypass_pattern_a)
- if post_activation_bypass_op is None:
- post_activation_bypass_op = match_result.get_op(
- post_activation_bypass_pattern_b)
+ post_activation_bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
- bypass_op = match_result.get_op(bypass_pattern_a)
- if bypass_op is None:
- bypass_op = match_result.get_op(bypass_pattern_b)
+ bypass_op = match_result.get_op(bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(