Make quantization rewrites happen in place.
authorSuharsh Sivakumar <suharshs@google.com>
Thu, 8 Feb 2018 17:34:49 +0000 (09:34 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 17:38:08 +0000 (09:38 -0800)
If no graph is provided, then the default graph is used.

PiperOrigin-RevId: 185007107

tensorflow/contrib/quantize/BUILD
tensorflow/contrib/quantize/python/copy_graph.py [deleted file]
tensorflow/contrib/quantize/python/copy_graph_test.py [deleted file]
tensorflow/contrib/quantize/python/fold_batch_norms_test.py
tensorflow/contrib/quantize/python/quantize_graph.py
tensorflow/contrib/quantize/python/quantize_graph_test.py

index ada336e623561c4ff3246bb96102ac7f626addd2..42e295e622ed59c34791a6ebdd10258228fbd088 100644 (file)
@@ -95,7 +95,6 @@ py_test(
     srcs = ["python/fold_batch_norms_test.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":copy_graph",
         ":fold_batch_norms",
         "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/python:array_ops",
@@ -110,31 +109,7 @@ py_test(
         "//tensorflow/python:random_ops",
         "//tensorflow/python:random_seed",
         "//tensorflow/python:session",
-        "//tensorflow/python:variables",
-    ],
-)
-
-py_library(
-    name = "copy_graph",
-    srcs = ["python/copy_graph.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
-    ],
-)
-
-py_test(
-    name = "copy_graph_test",
-    size = "small",
-    srcs = ["python/copy_graph_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":copy_graph",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:platform_test",
         "//tensorflow/python:variables",
     ],
 )
@@ -235,12 +210,9 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
-        ":copy_graph",
         ":fold_batch_norms",
         ":quantize",
-        "//tensorflow/python:framework_ops",
         "//tensorflow/python:util",
-        "//tensorflow/python:variables",
     ],
 )
 
@@ -253,13 +225,11 @@ py_test(
         ":quantize_graph",
         "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
-        "//tensorflow/python:variables",
     ],
 )
 
diff --git a/tensorflow/contrib/quantize/python/copy_graph.py b/tensorflow/contrib/quantize/python/copy_graph.py
deleted file mode 100644 (file)
index 0376fcb..0000000
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility to copy a tf.Graph."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.training import saver as saver_lib
-
-
-def CopyGraph(graph):
-  """Return a copy of graph."""
-  meta_graph = saver_lib.export_meta_graph(
-      graph=graph, collection_list=graph.get_all_collection_keys())
-  graph_copy = ops.Graph()
-  with graph_copy.as_default():
-    _ = saver_lib.import_meta_graph(meta_graph)
-  return graph_copy
diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py
deleted file mode 100644 (file)
index 7ff9ad9..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for copy_graph."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.quantize.python import copy_graph
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import googletest
-
-
-class CopyGraphTest(test_util.TensorFlowTestCase):
-
-  def _CompareNodeInGraph(self, node, graph):
-    graph_node = graph.get_operation_by_name(node.name)
-    self.assertEqual(str(node.node_def), str(graph_node.node_def))
-
-  def testCopyGraph(self):
-    graph = ops.Graph()
-    with graph.as_default():
-      a = constant_op.constant(1.0)
-      b = variables.Variable(2.0)
-      c = a + b
-    graph_copy = copy_graph.CopyGraph(graph)
-    # Ensure that the three original nodes are in the new graph.
-    # import_meta_graph also adds a saver node to the graph which we don't care
-    # about in this specific use case.
-    for tensor in [a, b, c]:
-      self._CompareNodeInGraph(tensor.op, graph_copy)
-    # Test that the graph collections are the same.
-    for key in graph.get_all_collection_keys():
-      self.assertEqual(
-          len(graph.get_collection(key)),
-          len(graph_copy.get_collection(key)), 'Collection %s differs.')
-
-
-if __name__ == '__main__':
-  googletest.main()
index 330bd8a6474c18b236b635d930e7a1df9594d84f..c90a18ab0357f1bcbc5d8ccd48edf894d7baf5f9 100644 (file)
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.layers.python.layers import layers
-from tensorflow.contrib.quantize.python import copy_graph
 from tensorflow.contrib.quantize.python import fold_batch_norms
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
@@ -34,6 +33,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
+from tensorflow.python.training import saver as saver_lib
 
 batch_norm = layers.batch_norm
 conv2d = layers.conv2d
@@ -379,7 +379,7 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       if with_bypass:
         node = math_ops.add(inputs, node, name='test/Add')
       relu_node = relu(node, name='test/' + relu_op_name)
-    folded_g = copy_graph.CopyGraph(unfolded_g)
+    folded_g = self._CopyGraph(unfolded_g)
     with folded_g.as_default():
       fold_batch_norms.FoldBatchNorms(
           folded_g,
@@ -462,5 +462,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       out_op = graph.get_operation_by_name(out_op_name)
       self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])
 
+  def _CopyGraph(self, graph):
+    """Return a copy of graph."""
+    meta_graph = saver_lib.export_meta_graph(
+        graph=graph, collection_list=graph.get_all_collection_keys())
+    graph_copy = ops.Graph()
+    with graph_copy.as_default():
+      _ = saver_lib.import_meta_graph(meta_graph)
+    return graph_copy
+
+
 if __name__ == '__main__':
   googletest.main()
index 89b744c559170e7d9e502d3d8610afaca2c549b7..81471d4c50bc3b5ddb2064a78a03abb9cdb976f3 100644 (file)
@@ -18,40 +18,28 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.quantize.python import copy_graph
 from tensorflow.contrib.quantize.python import fold_batch_norms
 from tensorflow.contrib.quantize.python import quantize
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import variables
 
 
-def _create_graph(input_graph,
-                  is_training,
-                  elements=None,
-                  device_name_or_function=None):
-  """Returns a transformed training input_graph for simulated quantization.
+def _create_graph(input_graph=None, is_training=True):
+  """Rewrites input_graph in place for simulated quantization.
 
-  The forward pass has fake quantization ops inserted to simulate the error
-  introduced by quantization.
+  The graph has fake quantization ops inserted to simulate the error
+  introduced by quantization. Since the graph is transformed in place,
+  the expected behavior of previously held references to nodes and tensors may
+  change.
 
   Args:
-    input_graph: The tf.Graph to be transformed.
+    input_graph: The tf.Graph to be transformed, if None then defaults to the
+      default graph.
     is_training: Whether quantizing training or eval graph.
-    elements: (Optional) List of Tensors and Operations in input_graph whose
-        corresponding elements in the new graph will be returned.
-    device_name_or_function: (Optional) The device name or function to use.
-
-  Returns:
-    g is new tf.Graph that is rewritten for simulated quantization.
-    l is a list of Tensors/Operations in g corresponding to the provided input
-        elements, if elements is not None.
 
   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
         tf.Operation.
   """
-  # TODO(suharshs): Describe the process in more detail in the doc string.
-  g = copy_graph.CopyGraph(input_graph)
   if is_training:
     # TODO(raghuramank): Need to make freeze_batch_norm_delay
     # a function of the batch size. For now setting this to 250 epochs
@@ -59,146 +47,87 @@ def _create_graph(input_graph,
     freeze_batch_norm_delay = 5000000
   else:
     freeze_batch_norm_delay = None
-  with g.as_default():
-    with ops.device(device_name_or_function):
-      fold_batch_norms.FoldBatchNorms(
-          g,
-          freeze_batch_norm_delay=freeze_batch_norm_delay,
-          is_training=is_training)
-      quantize.Quantize(g, is_training=is_training)
-  if elements is None:
-    return g
-
-  return_elements = []
-  for element in elements:
-    if isinstance(element, (ops.Tensor, variables.Variable)):
-      return_elements.append(g.get_tensor_by_name(element.name))
-    elif isinstance(element, ops.Operation):
-      return_elements.append(g.get_operation_by_name(element.name))
-    else:
-      raise ValueError(
-          'elements must consist of Tensor or Operation objects, got: ',
-          str(element))
-  return g, return_elements
-
-
-def create_training_graph(input_graph,
-                          elements=None,
-                          device_name_or_function=None):
-  """Returns a transformed training input_graph for simulated quantization.
-
-  The forward pass has fake quantization ops inserted to simulate the error
-  introduced by quantization.
+  if input_graph is None:
+    input_graph = ops.get_default_graph()
+  with input_graph.as_default():
+    fold_batch_norms.FoldBatchNorms(
+        input_graph,
+        freeze_batch_norm_delay=freeze_batch_norm_delay,
+        is_training=is_training)
+    quantize.Quantize(input_graph, is_training=is_training)
+
+
+def create_training_graph(input_graph=None):
+  """Rewrites a training input_graph in place for simulated quantization.
+
+  The graph has fake quantization ops inserted to simulate the error
+  introduced by quantization. Since the graph is transformed in place,
+  the expected behavior of previously held references to nodes and tensors may
+  change.
 
   Args:
     input_graph: The tf.Graph to be transformed.
-    elements: (Optional) List of Tensors and Operations in input_graph whose
-        corresponding elements in the new graph will be returned.
-    device_name_or_function: (Optional) The device name or function to use.
-
-  Returns:
-    g is new tf.Graph that is rewritten for simulated quantization.
-    l is a list of Tensors/Operations in g corresponding to the provided input
-        elements, if elements is not None.
 
   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
         tf.Operation.
   """
-  return _create_graph(
-      input_graph=input_graph,
-      is_training=True,
-      elements=elements,
-      device_name_or_function=device_name_or_function)
+  _create_graph(input_graph=input_graph, is_training=True)
 
 
-def create_eval_graph(input_graph, elements=None, device_name_or_function=None):
-  """Returns a transformed eval input_graph for simulated quantization.
+def create_eval_graph(input_graph=None):
+  """Rewrites an eval input_graph in place for simulated quantization.
 
-  The forward pass has fake quantization ops inserted to simulate the error
-  introduced by quantization.
+  The graph has fake quantization ops inserted to simulate the error
+  introduced by quantization. Since the graph is transformed in place,
+  the expected behavior of previously held references to nodes and tensors may
+  change.
 
   Args:
-    input_graph: The tf.Graph to be transformed.
-    elements: (Optional) List of Tensors and Operations in input_graph whose
-        corresponding elements in the new graph will be returned.
-    device_name_or_function: (Optional) The device name or function to use.
+    input_graph: The tf.Graph to be transformed, if None then defaults to the
+      default graph.
 
-  Returns:
-    g is new tf.Graph that is rewritten for simulated quantization.
-    l is a list of Tensors/Operations in g corresponding to the provided input
-        elements, if elements is not None.
 
   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
         tf.Operation.
   """
-  return _create_graph(
-      input_graph=input_graph,
-      is_training=False,
-      elements=elements,
-      device_name_or_function=device_name_or_function)
+  _create_graph(input_graph=input_graph, is_training=False)
 
 
-def experimental_create_training_graph(input_graph,
-                                       elements=None,
-                                       device_name_or_function=None):
-  """Returns a transformed training input_graph for simulated quantization.
+def experimental_create_training_graph(input_graph=None):
+  """Rewrites a training input_graph in place for simulated quantization.
 
-  This function has additional experimental options not (yet) available to
-  create_training_graph. The resulting behavior may be undefined.
-  The forward pass has fake quantization ops inserted to simulate the error
-  introduced by quantization.
+  The graph has fake quantization ops inserted to simulate the error
+  introduced by quantization. Since the graph is transformed in place,
+  the expected behavior of previously held references to nodes and tensors may
+  change.
 
   Args:
-    input_graph: The tf.Graph to be transformed.
-    elements: (Optional) List of Tensors and Operations in input_graph whose
-        corresponding elements in the new graph will be returned.
-    device_name_or_function: (Optional) The device name or function to use.
-
-  Returns:
-    g is new tf.Graph that is rewritten for simulated quantization.
-    l is a list of Tensors/Operations in g corresponding to the provided input
-        elements, if elements is not None.
+    input_graph: The tf.Graph to be transformed, if None then defaults to the
+      default graph.
 
   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
         tf.Operation.
   """
-  return _create_graph(
-      input_graph=input_graph,
-      is_training=True,
-      elements=elements,
-      device_name_or_function=device_name_or_function)
+  _create_graph(input_graph=input_graph, is_training=True)
 
 
-def experimental_create_eval_graph(input_graph,
-                                   elements=None,
-                                   device_name_or_function=None):
-  """Returns a transformed eval input_graph for simulated quantization.
+def experimental_create_eval_graph(input_graph=None):
+  """Rewrites an eval input_graph in place for simulated quantization.
 
-  This function has additional experimental options not (yet) available to
-  create_eval_graph. The resulting behavior may be undefined.
-  The forward pass has fake quantization ops inserted to simulate the error
-  introduced by quantization.
+  The graph has fake quantization ops inserted to simulate the error
+  introduced by quantization. Since the graph is transformed in place,
+  the expected behavior of previously held references to nodes and tensors may
+  change.
 
   Args:
-    input_graph: The tf.Graph to be transformed.
-    elements: (Optional) List of Tensors and Operations in input_graph whose
-        corresponding elements in the new graph will be returned.
-    device_name_or_function: (Optional) The device name or function to use.
-
-  Returns:
-    g is new tf.Graph that is rewritten for simulated quantization.
-    l is a list of Tensors/Operations in g corresponding to the provided input
-        elements, if elements is not None.
+    input_graph: The tf.Graph to be transformed, if None then defaults to the
+      default graph.
 
   Raises:
     ValueError: If elements contains an element that isn't a tf.Tensor or
         tf.Operation.
   """
-  return _create_graph(
-      input_graph=input_graph,
-      is_training=False,
-      elements=elements,
-      device_name_or_function=device_name_or_function)
+  _create_graph(input_graph=input_graph, is_training=False)
index 514862a0ab5b796718a04aa65a46e7a7e3b86330..7e08ebcb5c122049810f89d1cfdbd149e211839d 100644 (file)
@@ -20,13 +20,11 @@ from __future__ import print_function
 
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.quantize.python import quantize_graph
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
 
@@ -44,46 +42,10 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
     for fn in rewrite_fns:
       test_fn(fn)
 
-  def testReturnedElements(self):
-    self._RunTestOverParameters(self._TestReturnElements)
+  def testRewrite(self):
+    self._RunTestOverParameters(self._TestRewrite)
 
-  def _TestReturnElements(self, fn):
-    graph = ops.Graph()
-    with graph.as_default():
-      a = constant_op.constant(1.0)
-      b = variables.Variable(2.0)
-      c = a + b
-    elements = [a, b, c.op]
-    q_graph, returned_elements = fn(graph, elements=elements)
-    # Make sure q_graph is different from graph.
-    self.assertTrue(graph != q_graph)
-    # Check that the returned elements are part of the new graph.
-    for returned_element in returned_elements:
-      self.assertEqual(q_graph, returned_element.graph)
-    # Check that the elements match with the one from the input graph.
-    for element, returned_element in zip(elements, returned_elements):
-      self.assertEqual(element.name, returned_element.name)
-
-  def testNoReturnElements(self):
-    self._RunTestOverParameters(self._TestNoReturnElements)
-
-  def _TestNoReturnElements(self, fn):
-    graph = ops.Graph()
-    with graph.as_default():
-      a = constant_op.constant(1.0)
-      b = variables.Variable(2.0)
-      _ = a + b
-    q_graph = fn(graph)
-    # Check that quantize_graph didn't return a tuple when elements isn't
-    # provided.
-    self.assertTrue(isinstance(q_graph, ops.Graph))
-    # Make sure q_graph is different from graph.
-    self.assertTrue(graph != q_graph)
-
-  def testDeviceName(self):
-    self._RunTestOverParameters(self._TestDeviceName)
-
-  def _TestDeviceName(self, fn):
+  def _TestRewrite(self, fn):
     graph = ops.Graph()
     with graph.as_default():
       batch_size, height, width, depth = 5, 128, 128, 3
@@ -98,18 +60,40 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
           scope='test')
       _ = nn_ops.relu6(conv)
 
-    device_name = '/job:oink/task:0/device:CPU:0'
-    q_graph = fn(graph, device_name_or_function=device_name)
-
     orig_variable_names = set(
         [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
-    q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+
+    fn(graph)
+
+    q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
     # Ensure that variables were added.
     self.assertTrue(len(orig_variable_names) < len(q_variables))
-    # All added variables should have the specified device name.
-    for var in q_variables:
-      if var.name not in orig_variable_names:
-        self.assertEqual(var.device, device_name)
+
+  def testDefaultGraph(self):
+    self._RunTestOverParameters(self._TestRewrite)
+
+  def _TestDefaultGraph(self, fn):
+    with ops.Graph().as_default() as g:
+      batch_size, height, width, depth = 5, 128, 128, 3
+      inputs = array_ops.zeros((batch_size, height, width, depth))
+      conv = layers.conv2d(
+          inputs,
+          32, [5, 5],
+          stride=2,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=None,
+          scope='test')
+      _ = nn_ops.relu6(conv)
+
+      orig_variable_names = set(
+          [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+      fn()
+
+      q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+      # Ensure that variables were added.
+      self.assertTrue(len(orig_variable_names) < len(q_variables))
 
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.