From ef59be7e91a2b61c73b71086a43cfc7d96374e99 Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar
Date: Mon, 12 Feb 2018 14:05:08 -0800
Subject: [PATCH] Add tests for visible API arguments in quantize_graph.

PiperOrigin-RevId: 185432142
---
 .../contrib/quantize/python/quantize_graph_test.py | 140 ++++++++++++++-------
 1 file changed, 96 insertions(+), 44 deletions(-)

diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index c57fcd4..5c65a16 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -28,13 +28,11 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import googletest
 
 
-# TODO(suharshs): Add tests for testing experimental APIs and additional
-# input arguments
 class QuantizeGraphTest(test_util.TensorFlowTestCase):
   # We have a lot of other tests that test the details of the rewrite; here we
   # test just the specific features of the quantize_graph API.
 
-  def _RunTestOverParameters(self, test_fn):
+  def _RunTestOverAllRewrites(self, test_fn):
     rewrite_fns = [
         quantize_graph.create_training_graph,
         quantize_graph.create_eval_graph,
@@ -44,71 +42,125 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
     for fn in rewrite_fns:
       test_fn(fn)
 
+  def _RunTestOverTrainingRewrites(self, test_fn):
+    rewrite_fns = [
+        quantize_graph.create_training_graph,
+        quantize_graph.experimental_create_training_graph,
+    ]
+    for fn in rewrite_fns:
+      test_fn(fn)
+
+  def _RunTestOverExperimentalRewrites(self, test_fn):
+    rewrite_fns = [
+        quantize_graph.experimental_create_training_graph,
+        quantize_graph.experimental_create_eval_graph,
+    ]
+    for fn in rewrite_fns:
+      test_fn(fn)
+
   def testRewrite(self):
-    self._RunTestOverParameters(self._TestRewrite)
+    self._RunTestOverAllRewrites(self._TestRewrite)
 
-  def _TestRewrite(self, fn):
+  def _TestRewrite(self, rewrite_fn):
     graph = ops.Graph()
     with graph.as_default():
-      batch_size, height, width, depth = 5, 128, 128, 3
-      inputs = array_ops.zeros((batch_size, height, width, depth))
-      conv = layers.conv2d(
-          inputs,
-          32, [5, 5],
-          stride=2,
-          padding='SAME',
-          weights_initializer=self._WeightInit(0.09),
-          activation_fn=None,
-          scope='test')
-      _ = nn_ops.relu6(conv)
+      self._ConvLayer()
 
     orig_variable_names = set(
         [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
 
-    fn(graph)
+    rewrite_fn(graph)
 
     q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
     # Ensure that variables were added.
     self.assertTrue(len(orig_variable_names) < len(q_variables))
 
   def testDefaultGraph(self):
-    self._RunTestOverParameters(self._TestRewrite)
+    self._RunTestOverAllRewrites(self._TestDefaultGraph)
 
-  def _TestDefaultGraph(self, fn):
+  def _TestDefaultGraph(self, rewrite_fn):
+    # Tests that the default graph is correctly used when no args are provided
+    # to rewrite_fn.
     with ops.Graph().as_default() as g:
-      batch_size, height, width, depth = 5, 128, 128, 3
-      inputs = array_ops.zeros((batch_size, height, width, depth))
-      conv = layers.conv2d(
-          inputs,
-          32, [5, 5],
-          stride=2,
-          padding='SAME',
-          weights_initializer=self._WeightInit(0.09),
-          activation_fn=None,
-          scope='test')
-      _ = nn_ops.relu6(conv)
-
+      self._ConvLayer()
       orig_variable_names = set(
           [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
-
-      fn()
+      rewrite_fn()
 
       q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
       # Ensure that variables were added.
       self.assertTrue(len(orig_variable_names) < len(q_variables))
 
-  def _WeightInit(self, stddev):
-    """Returns truncated normal variable initializer.
+  def testQuantDelay(self):
+    self._RunTestOverTrainingRewrites(self._TestQuantDelay)
 
-    Function is defined purely to shorten the name so that it stops wrapping.
-
-    Args:
-      stddev: Standard deviation of normal variable.
-
-    Returns:
-      An initialized that initialzes with a truncated normal variable.
-    """
-    return init_ops.truncated_normal_initializer(stddev=stddev)
+  def _TestQuantDelay(self, rewrite_fn):
+    with ops.Graph().as_default() as g:
+      self._ConvLayer()
+      quant_delay = 100
+      rewrite_fn(quant_delay=quant_delay)
+
+      quant_delay_found = False
+      for op in g.get_operations():
+        # Check to see if the quant_delay is correctly set.
+        if 'activate_quant' in op.name and op.type == 'Const':
+          quant_delay_found = True
+          const_value = str(op.get_attr('value'))
+          self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
+      self.assertTrue(quant_delay_found)
+
+  def testWeightBits(self):
+    self._RunTestOverExperimentalRewrites(self._TestWeightBits)
+
+  def _TestWeightBits(self, rewrite_fn):
+    with ops.Graph().as_default() as g:
+      self._ConvLayer()
+      weight_bits = 4
+      rewrite_fn(weight_bits=weight_bits)
+
+      weights_quant_found = False
+      for op in g.get_operations():
+        # Check to see if FakeQuant operations for weights have the right bits
+        # set.
+        if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
+          weights_quant_found = True
+          self.assertEqual(op.get_attr('num_bits'), weight_bits)
+      self.assertTrue(weights_quant_found)
+
+  def testActivationBits(self):
+    self._RunTestOverExperimentalRewrites(self._TestActivationBits)
+
+  def _TestActivationBits(self, rewrite_fn):
+    with ops.Graph().as_default() as g:
+      self._ConvLayer()
+      activation_bits = 4
+      rewrite_fn(activation_bits=activation_bits)
+
+      act_quant_found = False
+      for op in g.get_operations():
+        # Check to see if FakeQuant operations for activations have the right
+        # bits set.
+        act_quant_names = ['act_quant', 'conv_quant', 'add_quant']
+        if any(s in op.name
+               for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars':
+          act_quant_found = True
+          self.assertEqual(op.get_attr('num_bits'), activation_bits)
+      self.assertTrue(act_quant_found)
+
+  def _ConvLayer(self):
+    """Add a basic convolution layer to the default graph."""
+    batch_size, height, width, depth = 5, 128, 128, 3
+    inputs = array_ops.zeros((batch_size, height, width, depth))
+    weight_init = init_ops.truncated_normal_initializer
+    conv = layers.conv2d(
+        inputs,
+        32, [5, 5],
+        stride=2,
+        padding='SAME',
+        weights_initializer=weight_init(0.09),
+        activation_fn=None,
+        scope='test')
+    _ = nn_ops.relu6(conv)
 
 
 if __name__ == '__main__':
-- 
2.7.4
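
For reference, the arguments exercised above are the user-visible knobs of the
rewrite entry points. The sketch below shows how they are passed in practice;
it is not part of the patch, assumes a TF 1.x build where
tensorflow.contrib.quantize is importable, and uses only the call patterns the
tests themselves confirm (a positional graph argument plus the quant_delay,
weight_bits, and activation_bits keywords).

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize_graph

    def conv_layer():
      # Same shape as the tests' _ConvLayer: a conv2d followed by relu6.
      inputs = tf.zeros((5, 128, 128, 3))
      conv = tf.layers.conv2d(inputs, 32, (5, 5), strides=2, padding='same')
      tf.nn.relu6(conv)

    # Stable training rewrite: quant_delay is the only extra knob.
    train_graph = tf.Graph()
    with train_graph.as_default():
      conv_layer()
    quantize_graph.create_training_graph(train_graph, quant_delay=100)

    # Experimental rewrites additionally expose the FakeQuant bit widths.
    exp_graph = tf.Graph()
    with exp_graph.as_default():
      conv_layer()
    quantize_graph.experimental_create_training_graph(
        exp_graph, weight_bits=4, activation_bits=4, quant_delay=100)

    # Called with no graph argument, the rewrites target the default graph,
    # which is the behavior _TestDefaultGraph verifies.
    with tf.Graph().as_default():
      conv_layer()
      quantize_graph.create_eval_graph()

Mirroring the tests, each rewrite above gets a fresh graph rather than
reapplying a rewrite to an already-rewritten graph.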
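
The assertions in the patch identify the rewrite's output purely by op type
and the num_bits attribute, and the same pattern can serve as a standalone
sanity check outside the test harness. A minimal sketch, relying only on the
FakeQuantWithMinMaxVars op type and num_bits attribute the tests inspect:

    def fake_quant_bit_widths(graph):
      # Map each FakeQuant op the rewrite inserted to its num_bits attribute.
      return {
          op.name: op.get_attr('num_bits')
          for op in graph.get_operations()
          if op.type == 'FakeQuantWithMinMaxVars'
      }

    # e.g. after the 4-bit experimental rewrite sketched above:
    # assert all(b == 4 for b in fake_quant_bit_widths(exp_graph).values())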