From a4f0b3afb631f40024996c16a8bf2a146fb3dc8c Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 8 Feb 2018 09:34:49 -0800 Subject: [PATCH] Make quantization rewrites happen in place. If no graph is provided, then the default graph is used. PiperOrigin-RevId: 185007107 --- tensorflow/contrib/quantize/BUILD | 30 ---- tensorflow/contrib/quantize/python/copy_graph.py | 32 ---- .../contrib/quantize/python/copy_graph_test.py | 55 ------- .../quantize/python/fold_batch_norms_test.py | 14 +- .../contrib/quantize/python/quantize_graph.py | 177 ++++++--------------- .../contrib/quantize/python/quantize_graph_test.py | 82 ++++------ 6 files changed, 98 insertions(+), 292 deletions(-) delete mode 100644 tensorflow/contrib/quantize/python/copy_graph.py delete mode 100644 tensorflow/contrib/quantize/python/copy_graph_test.py diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index ada336e..42e295e 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -95,7 +95,6 @@ py_test( srcs = ["python/fold_batch_norms_test.py"], srcs_version = "PY2AND3", deps = [ - ":copy_graph", ":fold_batch_norms", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", @@ -110,31 +109,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:session", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "copy_graph", - srcs = ["python/copy_graph.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:framework_ops", "//tensorflow/python:training", - ], -) - -py_test( - name = "copy_graph_test", - size = "small", - srcs = ["python/copy_graph_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":copy_graph", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], ) @@ -235,12 +210,9 @@ py_library( ], 
srcs_version = "PY2AND3", deps = [ - ":copy_graph", ":fold_batch_norms", ":quantize", - "//tensorflow/python:framework_ops", "//tensorflow/python:util", - "//tensorflow/python:variables", ], ) @@ -253,13 +225,11 @@ py_test( ":quantize_graph", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/quantize/python/copy_graph.py b/tensorflow/contrib/quantize/python/copy_graph.py deleted file mode 100644 index 0376fcb..0000000 --- a/tensorflow/contrib/quantize/python/copy_graph.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Utility to copy a tf.Graph.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.training import saver as saver_lib - - -def CopyGraph(graph): - """Return a copy of graph.""" - meta_graph = saver_lib.export_meta_graph( - graph=graph, collection_list=graph.get_all_collection_keys()) - graph_copy = ops.Graph() - with graph_copy.as_default(): - _ = saver_lib.import_meta_graph(meta_graph) - return graph_copy diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py deleted file mode 100644 index 7ff9ad9..0000000 --- a/tensorflow/contrib/quantize/python/copy_graph_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for copy_graph.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.quantize.python import copy_graph -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import googletest - - -class CopyGraphTest(test_util.TensorFlowTestCase): - - def _CompareNodeInGraph(self, node, graph): - graph_node = graph.get_operation_by_name(node.name) - self.assertEqual(str(node.node_def), str(graph_node.node_def)) - - def testCopyGraph(self): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - c = a + b - graph_copy = copy_graph.CopyGraph(graph) - # Ensure that the three original nodes are in the new graph. - # import_meta_graph also adds a saver node to the graph which we don't care - # about in this specific use case. - for tensor in [a, b, c]: - self._CompareNodeInGraph(tensor.op, graph_copy) - # Test that the graph collections are the same. 
- for key in graph.get_all_collection_keys(): - self.assertEqual( - len(graph.get_collection(key)), - len(graph_copy.get_collection(key)), 'Collection %s differs.') - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 330bd8a..c90a18a 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers -from tensorflow.contrib.quantize.python import copy_graph from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -34,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.training import saver as saver_lib batch_norm = layers.batch_norm conv2d = layers.conv2d @@ -379,7 +379,7 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu_node = relu(node, name='test/' + relu_op_name) - folded_g = copy_graph.CopyGraph(unfolded_g) + folded_g = self._CopyGraph(unfolded_g) with folded_g.as_default(): fold_batch_norms.FoldBatchNorms( folded_g, @@ -462,5 +462,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_op = graph.get_operation_by_name(out_op_name) self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + def _CopyGraph(self, graph): + """Return a copy of graph.""" + meta_graph = saver_lib.export_meta_graph( + graph=graph, collection_list=graph.get_all_collection_keys()) + graph_copy = ops.Graph() + with graph_copy.as_default(): + _ = 
saver_lib.import_meta_graph(meta_graph) + return graph_copy + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index 89b744c..81471d4 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -18,40 +18,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.quantize.python import copy_graph from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -def _create_graph(input_graph, - is_training, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. +def _create_graph(input_graph=None, is_training=True): + """Rewrites input_graph in place for simulated quantization. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. is_training: Whether quantizing training or eval graph. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. 
- l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - # TODO(suharshs): Describe the process in more detail in the doc string. - g = copy_graph.CopyGraph(input_graph) if is_training: # TODO(raghuramank): Need to make freeze_batch_norm_delay # a function of the batch size. For now setting this to 250 epochs @@ -59,146 +47,87 @@ def _create_graph(input_graph, freeze_batch_norm_delay = 5000000 else: freeze_batch_norm_delay = None - with g.as_default(): - with ops.device(device_name_or_function): - fold_batch_norms.FoldBatchNorms( - g, - freeze_batch_norm_delay=freeze_batch_norm_delay, - is_training=is_training) - quantize.Quantize(g, is_training=is_training) - if elements is None: - return g - - return_elements = [] - for element in elements: - if isinstance(element, (ops.Tensor, variables.Variable)): - return_elements.append(g.get_tensor_by_name(element.name)) - elif isinstance(element, ops.Operation): - return_elements.append(g.get_operation_by_name(element.name)) - else: - raise ValueError( - 'elements must consist of Tensor or Operation objects, got: ', - str(element)) - return g, return_elements - - -def create_training_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. - - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + if input_graph is None: + input_graph = ops.get_default_graph() + with input_graph.as_default(): + fold_batch_norms.FoldBatchNorms( + input_graph, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) + quantize.Quantize(input_graph, is_training=is_training) + + +def create_training_graph(input_graph=None): + """Rewrites a training input_graph in place for simulated quantization. 
+ + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - return _create_graph( - input_graph=input_graph, - is_training=True, - elements=elements, - device_name_or_function=device_name_or_function) + _create_graph(input_graph=input_graph, is_training=True) -def create_eval_graph(input_graph, elements=None, device_name_or_function=None): - """Returns a transformed eval input_graph for simulated quantization. +def create_eval_graph(input_graph=None): + """Rewrites an eval input_graph in place for simulated quantization. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. 
- Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - return _create_graph( - input_graph=input_graph, - is_training=False, - elements=elements, - device_name_or_function=device_name_or_function) + _create_graph(input_graph=input_graph, is_training=False) -def experimental_create_training_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. +def experimental_create_training_graph(input_graph=None): + """Rewrites a training input_graph in place for simulated quantization. - This function has additional experimental options not (yet) available to - create_training_graph. The resulting behavior may be undefined. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. 
""" - return _create_graph( - input_graph=input_graph, - is_training=True, - elements=elements, - device_name_or_function=device_name_or_function) + _create_graph(input_graph=input_graph, is_training=True) -def experimental_create_eval_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed eval input_graph for simulated quantization. +def experimental_create_eval_graph(input_graph=None): + """Rewrites an eval input_graph in place for simulated quantization. - This function has additional experimental options not (yet) available to - create_eval_graph. The resulting behavior may be undefined. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. 
""" - return _create_graph( - input_graph=input_graph, - is_training=False, - elements=elements, - device_name_or_function=device_name_or_function) + _create_graph(input_graph=input_graph, is_training=False) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 514862a..7e08ebc 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -20,13 +20,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize_graph -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -44,46 +42,10 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): for fn in rewrite_fns: test_fn(fn) - def testReturnedElements(self): - self._RunTestOverParameters(self._TestReturnElements) + def testRewrite(self): + self._RunTestOverParameters(self._TestRewrite) - def _TestReturnElements(self, fn): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - c = a + b - elements = [a, b, c.op] - q_graph, returned_elements = fn(graph, elements=elements) - # Make sure q_graph is different from graph. - self.assertTrue(graph != q_graph) - # Check that the returned elements are part of the new graph. - for returned_element in returned_elements: - self.assertEqual(q_graph, returned_element.graph) - # Check that the elements match with the one from the input graph. 
- for element, returned_element in zip(elements, returned_elements): - self.assertEqual(element.name, returned_element.name) - - def testNoReturnElements(self): - self._RunTestOverParameters(self._TestNoReturnElements) - - def _TestNoReturnElements(self, fn): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - _ = a + b - q_graph = fn(graph) - # Check that quantize_graph didn't return a tuple when elements isn't - # provided. - self.assertTrue(isinstance(q_graph, ops.Graph)) - # Make sure q_graph is different from graph. - self.assertTrue(graph != q_graph) - - def testDeviceName(self): - self._RunTestOverParameters(self._TestDeviceName) - - def _TestDeviceName(self, fn): + def _TestRewrite(self, fn): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -98,18 +60,40 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): scope='test') _ = nn_ops.relu6(conv) - device_name = '/job:oink/task:0/device:CPU:0' - q_graph = fn(graph, device_name_or_function=device_name) - orig_variable_names = set( [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + + fn(graph) + + q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) # Ensure that variables were added. self.assertTrue(len(orig_variable_names) < len(q_variables)) - # All added variables should have the specified device name. 
- for var in q_variables: - if var.name not in orig_variable_names: - self.assertEqual(var.device, device_name) + + def testDefaultGraph(self): + self._RunTestOverParameters(self._TestDefaultGraph) + + def _TestDefaultGraph(self, fn): + with ops.Graph().as_default() as g: + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + conv = layers.conv2d( + inputs, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + _ = nn_ops.relu6(conv) + + orig_variable_names = set( + [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + + fn() + + q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # Ensure that variables were added. + self.assertTrue(len(orig_variable_names) < len(q_variables)) def _WeightInit(self, stddev): """Returns truncated normal variable initializer. -- 2.7.4