Relax limitations on rerouting graph outputs.
authorSuharsh Sivakumar <suharshs@google.com>
Thu, 29 Mar 2018 02:21:08 +0000 (19:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 02:23:05 +0000 (19:23 -0700)
- Allow multiple outputs of output_tensors in fold_batch_norms.
- Allow duplicate consumers in quantize.
- I also quick a fix issue for matching final layers that have batch norm.

PiperOrigin-RevId: 190873003

tensorflow/contrib/quantize/python/fold_batch_norms.py
tensorflow/contrib/quantize/python/quantize.py

index 5750be6..4a8f8a0 100644 (file)
@@ -134,9 +134,9 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
 
       nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                      match.output_tensor)
-      if nodes_modified_count != 1:
-        raise ValueError(
-            'Unexpected inputs to op: %s' % match.output_tensor.name)
+      if nodes_modified_count == 0:
+        raise ValueError('Folding batch norms failed, %s had no outputs.' %
+                         match.output_tensor.name)
 
 
 def _FindFusedBatchNorms(graph):
index 019d123..2889016 100644 (file)
@@ -305,7 +305,8 @@ def _FindLayersToQuantize(graph):
   # the output of the final BiasAdd must be quantized. So we treat the BiasAdd
   # as the 'activation_op' in the _LayerMatch, to ensure that it's output is
   # quantized.
-  final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern)
+  final_layer_matcher = graph_matcher.GraphMatcher(
+      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
   for match_result in final_layer_matcher.match_graph(graph):
     layer_op = match_result.get_op(layer_pattern)
     weight_tensor = match_result.get_tensor(weight_identity_pattern)
@@ -463,11 +464,16 @@ def _InsertQuantOp(context,
         lambda: inputs,
         name=name_prefix + '/delayed_quant')
 
-  nodes_modified_count = graph_editor.reroute_ts(
-      [quant], [inputs], can_modify=consumers)
-  if nodes_modified_count != len(consumers):
-    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
-        [consumer.name for consumer in consumers]))
+  if consumers:
+    tensors_modified_count = graph_editor.reroute_ts(
+        [quant], [inputs], can_modify=consumers)
+    # Some operations can have multiple output tensors going to the same
+    # consumer. Since consumers is a set, we need to ensure that
+    # tensors_modified_count is greater than or equal to the length of the set
+    # of consumers.
+    if tensors_modified_count < len(consumers):
+      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
+          [consumer.name for consumer in consumers]))
 
 
 def _GetContextFromOp(op):