srcs = ["python/fold_batch_norms_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":copy_graph",
":fold_batch_norms",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:session",
- "//tensorflow/python:variables",
- ],
-)
-
-py_library(
- name = "copy_graph",
- srcs = ["python/copy_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:framework_ops",
"//tensorflow/python:training",
- ],
-)
-
-py_test(
- name = "copy_graph_test",
- size = "small",
- srcs = ["python/copy_graph_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":copy_graph",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
],
srcs_version = "PY2AND3",
deps = [
- ":copy_graph",
":fold_batch_norms",
":quantize",
- "//tensorflow/python:framework_ops",
"//tensorflow/python:util",
- "//tensorflow/python:variables",
],
)
":quantize_graph",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
- "//tensorflow/python:variables",
],
)
--- a/tensorflow/contrib/quantize/python/copy_graph.py
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility to copy a tf.Graph."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.training import saver as saver_lib
-
-
-def CopyGraph(graph):
- """Return a copy of graph."""
- meta_graph = saver_lib.export_meta_graph(
- graph=graph, collection_list=graph.get_all_collection_keys())
- graph_copy = ops.Graph()
- with graph_copy.as_default():
- _ = saver_lib.import_meta_graph(meta_graph)
- return graph_copy
--- a/tensorflow/contrib/quantize/python/copy_graph_test.py
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for copy_graph."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.quantize.python import copy_graph
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import googletest
-
-
-class CopyGraphTest(test_util.TensorFlowTestCase):
-
- def _CompareNodeInGraph(self, node, graph):
- graph_node = graph.get_operation_by_name(node.name)
- self.assertEqual(str(node.node_def), str(graph_node.node_def))
-
- def testCopyGraph(self):
- graph = ops.Graph()
- with graph.as_default():
- a = constant_op.constant(1.0)
- b = variables.Variable(2.0)
- c = a + b
- graph_copy = copy_graph.CopyGraph(graph)
- # Ensure that the three original nodes are in the new graph.
- # import_meta_graph also adds a saver node to the graph which we don't care
- # about in this specific use case.
- for tensor in [a, b, c]:
- self._CompareNodeInGraph(tensor.op, graph_copy)
- # Test that the graph collections are the same.
- for key in graph.get_all_collection_keys():
- self.assertEqual(
- len(graph.get_collection(key)),
- len(graph_copy.get_collection(key)), 'Collection %s differs.')
-
-
-if __name__ == '__main__':
- googletest.main()
from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
-from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+from tensorflow.python.training import saver as saver_lib
batch_norm = layers.batch_norm
conv2d = layers.conv2d
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu_node = relu(node, name='test/' + relu_op_name)
- folded_g = copy_graph.CopyGraph(unfolded_g)
+ folded_g = self._CopyGraph(unfolded_g)
with folded_g.as_default():
fold_batch_norms.FoldBatchNorms(
folded_g,
out_op = graph.get_operation_by_name(out_op_name)
self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])
+ def _CopyGraph(self, graph):
+ """Return a copy of graph."""
+ meta_graph = saver_lib.export_meta_graph(
+ graph=graph, collection_list=graph.get_all_collection_keys())
+ graph_copy = ops.Graph()
+ with graph_copy.as_default():
+ _ = saver_lib.import_meta_graph(meta_graph)
+ return graph_copy
+
+
if __name__ == '__main__':
googletest.main()
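
For reference, the MetaGraphDef round-trip that `_CopyGraph` relies on can be exercised standalone; a minimal sketch using the public tf.train equivalents of the saver_lib calls (the node name 'c' is illustrative):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0, name='c')

    # export_meta_graph serializes the GraphDef together with its collections;
    # importing it into a fresh Graph yields an independent deep copy.
    meta = tf.train.export_meta_graph(graph=g)
    g_copy = tf.Graph()
    with g_copy.as_default():
      tf.train.import_meta_graph(meta)

    # Same nodes by name, but no objects shared with the original graph.
    assert g_copy.get_operation_by_name('c') is not g.get_operation_by_name('c')
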
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.quantize.python import copy_graph
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
-from tensorflow.python.ops import variables
-def _create_graph(input_graph,
- is_training,
- elements=None,
- device_name_or_function=None):
- """Returns a transformed training input_graph for simulated quantization.
+def _create_graph(input_graph=None, is_training=True):
+ """Rewrites input_graph in place for simulated quantization.
- The forward pass has fake quantization ops inserted to simulate the error
- introduced by quantization.
+ The graph has fake quantization ops inserted to simulate the error
+ introduced by quantization. Since the graph is transformed in place,
+ the expected behavior of previously held references to nodes and tensors may
+ change.
Args:
- input_graph: The tf.Graph to be transformed.
+    input_graph: The tf.Graph to be transformed; if None, the default graph
+        is used.
is_training: Whether quantizing training or eval graph.
- elements: (Optional) List of Tensors and Operations in input_graph whose
- corresponding elements in the new graph will be returned.
- device_name_or_function: (Optional) The device name or function to use.
-
- Returns:
- g is new tf.Graph that is rewritten for simulated quantization.
- l is a list of Tensors/Operations in g corresponding to the provided input
- elements, if elements is not None.
-  Raises:
-    ValueError: If elements contains an element that isn't a tf.Tensor or
-        tf.Operation.
+  Raises:
+    ValueError: If the graph cannot be transformed, e.g. if batch norm
+        folding fails.
"""
- # TODO(suharshs): Describe the process in more detail in the doc string.
- g = copy_graph.CopyGraph(input_graph)
if is_training:
# TODO(raghuramank): Need to make freeze_batch_norm_delay
# a function of the batch size. For now setting this to 250 epochs
freeze_batch_norm_delay = 5000000
else:
freeze_batch_norm_delay = None
- with g.as_default():
- with ops.device(device_name_or_function):
- fold_batch_norms.FoldBatchNorms(
- g,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- is_training=is_training)
- quantize.Quantize(g, is_training=is_training)
- if elements is None:
- return g
-
- return_elements = []
- for element in elements:
- if isinstance(element, (ops.Tensor, variables.Variable)):
- return_elements.append(g.get_tensor_by_name(element.name))
- elif isinstance(element, ops.Operation):
- return_elements.append(g.get_operation_by_name(element.name))
- else:
- raise ValueError(
- 'elements must consist of Tensor or Operation objects, got: ',
- str(element))
- return g, return_elements
-
-
-def create_training_graph(input_graph,
- elements=None,
- device_name_or_function=None):
- """Returns a transformed training input_graph for simulated quantization.
-
- The forward pass has fake quantization ops inserted to simulate the error
- introduced by quantization.
+ if input_graph is None:
+ input_graph = ops.get_default_graph()
+ with input_graph.as_default():
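+    # Batch norms must be folded into the preceding conv/matmul weights
+    # before Quantize() runs, so the inserted fake quantization ops observe
+    # the folded weights.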
+ fold_batch_norms.FoldBatchNorms(
+ input_graph,
+ freeze_batch_norm_delay=freeze_batch_norm_delay,
+ is_training=is_training)
+ quantize.Quantize(input_graph, is_training=is_training)
+
+
+def create_training_graph(input_graph=None):
+ """Rewrites a training input_graph in place for simulated quantization.
+
+ The graph has fake quantization ops inserted to simulate the error
+ introduced by quantization. Since the graph is transformed in place,
+ the expected behavior of previously held references to nodes and tensors may
+ change.
Args:
-    input_graph: The tf.Graph to be transformed.
+    input_graph: The tf.Graph to be transformed; if None, the default graph
+        is used.
- elements: (Optional) List of Tensors and Operations in input_graph whose
- corresponding elements in the new graph will be returned.
- device_name_or_function: (Optional) The device name or function to use.
-
- Returns:
- g is new tf.Graph that is rewritten for simulated quantization.
- l is a list of Tensors/Operations in g corresponding to the provided input
- elements, if elements is not None.
-  Raises:
-    ValueError: If elements contains an element that isn't a tf.Tensor or
-        tf.Operation.
+  Raises:
+    ValueError: If the graph cannot be transformed, e.g. if batch norm
+        folding fails.
"""
- return _create_graph(
- input_graph=input_graph,
- is_training=True,
- elements=elements,
- device_name_or_function=device_name_or_function)
+ _create_graph(input_graph=input_graph, is_training=True)
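
Since the rewrite now happens in place and returns nothing, callers keep their existing tensor references (or re-fetch them by name) instead of receiving `elements` back. A minimal usage sketch of the new signature; the model is an illustrative stand-in:

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize_graph

    g = tf.Graph()
    with g.as_default():
      images = tf.zeros((5, 128, 128, 3))
      conv = tf.contrib.layers.conv2d(
          images, 32, [5, 5], activation_fn=None, scope='test')
      loss = tf.reduce_mean(tf.nn.relu6(conv))

      # Old API: q_g, [q_loss] = create_training_graph(g, elements=[loss])
      # New API: g itself is rewritten, and `loss` stays a valid handle.
      quantize_graph.create_training_graph(g)
      train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
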
-def create_eval_graph(input_graph, elements=None, device_name_or_function=None):
- """Returns a transformed eval input_graph for simulated quantization.
+def create_eval_graph(input_graph=None):
+ """Rewrites an eval input_graph in place for simulated quantization.
- The forward pass has fake quantization ops inserted to simulate the error
- introduced by quantization.
+ The graph has fake quantization ops inserted to simulate the error
+ introduced by quantization. Since the graph is transformed in place,
+ the expected behavior of previously held references to nodes and tensors may
+ change.
Args:
- input_graph: The tf.Graph to be transformed.
- elements: (Optional) List of Tensors and Operations in input_graph whose
- corresponding elements in the new graph will be returned.
- device_name_or_function: (Optional) The device name or function to use.
+    input_graph: The tf.Graph to be transformed; if None, the default graph
+        is used.
- Returns:
- g is new tf.Graph that is rewritten for simulated quantization.
- l is a list of Tensors/Operations in g corresponding to the provided input
- elements, if elements is not None.
-  Raises:
-    ValueError: If elements contains an element that isn't a tf.Tensor or
-        tf.Operation.
+  Raises:
+    ValueError: If the graph cannot be transformed, e.g. if batch norm
+        folding fails.
"""
- return _create_graph(
- input_graph=input_graph,
- is_training=False,
- elements=elements,
- device_name_or_function=device_name_or_function)
+ _create_graph(input_graph=input_graph, is_training=False)
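
The eval rewrite follows the same in-place pattern, typically on a separately constructed inference graph whose weights are later restored from a training checkpoint; a sketch (the checkpoint path is a placeholder):

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize_graph

    eval_g = tf.Graph()
    with eval_g.as_default():
      images = tf.placeholder(tf.float32, [1, 128, 128, 3])
      conv = tf.contrib.layers.conv2d(
          images, 32, [5, 5], activation_fn=None, scope='test')
      _ = tf.nn.relu6(conv)

      # Rewrites eval_g in place; create_eval_graph() with no argument would
      # target the current default graph instead.
      quantize_graph.create_eval_graph(eval_g)
      saver = tf.train.Saver()

    with tf.Session(graph=eval_g) as sess:
      saver.restore(sess, '/tmp/model.ckpt')  # placeholder checkpoint path
      # The forward pass now executes with fake quantization ops in place.
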
-def experimental_create_training_graph(input_graph,
- elements=None,
- device_name_or_function=None):
- """Returns a transformed training input_graph for simulated quantization.
+def experimental_create_training_graph(input_graph=None):
+ """Rewrites a training input_graph in place for simulated quantization.
- This function has additional experimental options not (yet) available to
- create_training_graph. The resulting behavior may be undefined.
- The forward pass has fake quantization ops inserted to simulate the error
- introduced by quantization.
+ The graph has fake quantization ops inserted to simulate the error
+ introduced by quantization. Since the graph is transformed in place,
+ the expected behavior of previously held references to nodes and tensors may
+ change.
Args:
- input_graph: The tf.Graph to be transformed.
- elements: (Optional) List of Tensors and Operations in input_graph whose
- corresponding elements in the new graph will be returned.
- device_name_or_function: (Optional) The device name or function to use.
-
- Returns:
- g is new tf.Graph that is rewritten for simulated quantization.
- l is a list of Tensors/Operations in g corresponding to the provided input
- elements, if elements is not None.
+    input_graph: The tf.Graph to be transformed; if None, the default graph
+        is used.
-  Raises:
-    ValueError: If elements contains an element that isn't a tf.Tensor or
-        tf.Operation.
+  Raises:
+    ValueError: If the graph cannot be transformed, e.g. if batch norm
+        folding fails.
"""
- return _create_graph(
- input_graph=input_graph,
- is_training=True,
- elements=elements,
- device_name_or_function=device_name_or_function)
+ _create_graph(input_graph=input_graph, is_training=True)
-def experimental_create_eval_graph(input_graph,
- elements=None,
- device_name_or_function=None):
- """Returns a transformed eval input_graph for simulated quantization.
+def experimental_create_eval_graph(input_graph=None):
+ """Rewrites an eval input_graph in place for simulated quantization.
- This function has additional experimental options not (yet) available to
- create_eval_graph. The resulting behavior may be undefined.
- The forward pass has fake quantization ops inserted to simulate the error
- introduced by quantization.
+ The graph has fake quantization ops inserted to simulate the error
+ introduced by quantization. Since the graph is transformed in place,
+ the expected behavior of previously held references to nodes and tensors may
+ change.
Args:
- input_graph: The tf.Graph to be transformed.
- elements: (Optional) List of Tensors and Operations in input_graph whose
- corresponding elements in the new graph will be returned.
- device_name_or_function: (Optional) The device name or function to use.
-
- Returns:
- g is new tf.Graph that is rewritten for simulated quantization.
- l is a list of Tensors/Operations in g corresponding to the provided input
- elements, if elements is not None.
+    input_graph: The tf.Graph to be transformed; if None, the default graph
+        is used.
-  Raises:
-    ValueError: If elements contains an element that isn't a tf.Tensor or
-        tf.Operation.
+  Raises:
+    ValueError: If the graph cannot be transformed, e.g. if batch norm
+        folding fails.
"""
- return _create_graph(
- input_graph=input_graph,
- is_training=False,
- elements=elements,
- device_name_or_function=device_name_or_function)
+ _create_graph(input_graph=input_graph, is_training=False)
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
for fn in rewrite_fns:
test_fn(fn)
- def testReturnedElements(self):
- self._RunTestOverParameters(self._TestReturnElements)
+ def testRewrite(self):
+ self._RunTestOverParameters(self._TestRewrite)
- def _TestReturnElements(self, fn):
- graph = ops.Graph()
- with graph.as_default():
- a = constant_op.constant(1.0)
- b = variables.Variable(2.0)
- c = a + b
- elements = [a, b, c.op]
- q_graph, returned_elements = fn(graph, elements=elements)
- # Make sure q_graph is different from graph.
- self.assertTrue(graph != q_graph)
- # Check that the returned elements are part of the new graph.
- for returned_element in returned_elements:
- self.assertEqual(q_graph, returned_element.graph)
- # Check that the elements match with the one from the input graph.
- for element, returned_element in zip(elements, returned_elements):
- self.assertEqual(element.name, returned_element.name)
-
- def testNoReturnElements(self):
- self._RunTestOverParameters(self._TestNoReturnElements)
-
- def _TestNoReturnElements(self, fn):
- graph = ops.Graph()
- with graph.as_default():
- a = constant_op.constant(1.0)
- b = variables.Variable(2.0)
- _ = a + b
- q_graph = fn(graph)
- # Check that quantize_graph didn't return a tuple when elements isn't
- # provided.
- self.assertTrue(isinstance(q_graph, ops.Graph))
- # Make sure q_graph is different from graph.
- self.assertTrue(graph != q_graph)
-
- def testDeviceName(self):
- self._RunTestOverParameters(self._TestDeviceName)
-
- def _TestDeviceName(self, fn):
+ def _TestRewrite(self, fn):
graph = ops.Graph()
with graph.as_default():
batch_size, height, width, depth = 5, 128, 128, 3
scope='test')
_ = nn_ops.relu6(conv)
- device_name = '/job:oink/task:0/device:CPU:0'
- q_graph = fn(graph, device_name_or_function=device_name)
-
orig_variable_names = set(
[v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+
+ fn(graph)
+
+ q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
- # All added variables should have the specified device name.
- for var in q_variables:
- if var.name not in orig_variable_names:
- self.assertEqual(var.device, device_name)
+
+ def testDefaultGraph(self):
+    self._RunTestOverParameters(self._TestDefaultGraph)
+
+ def _TestDefaultGraph(self, fn):
+ with ops.Graph().as_default() as g:
+ batch_size, height, width, depth = 5, 128, 128, 3
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ conv = layers.conv2d(
+ inputs,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+ _ = nn_ops.relu6(conv)
+
+ orig_variable_names = set(
+ [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ fn()
+
+ q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ # Ensure that variables were added.
+ self.assertTrue(len(orig_variable_names) < len(q_variables))
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.