Make optimize_for_inference_test.py work the C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Mon, 29 Jan 2018 19:20:39 +0000 (11:20 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 29 Jan 2018 19:24:40 +0000 (11:24 -0800)
PiperOrigin-RevId: 183696425

tensorflow/python/tools/optimize_for_inference_test.py

index 6dd24c0..7686bb0 100644 (file)
@@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import image_ops
@@ -38,6 +39,7 @@ from tensorflow.python.platform import test
 from tensorflow.python.tools import optimize_for_inference_lib
 
 
+@test_util.with_c_api
 class OptimizeForInferenceTest(test.TestCase):
 
   def create_node_def(self, op, name, inputs):
@@ -145,7 +147,7 @@ class OptimizeForInferenceTest(test.TestCase):
           np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
       gamma_op = constant_op.constant(
           np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
-      ops.get_default_graph().graph_def_versions.producer = 8
+      test_util.set_producer_version(ops.get_default_graph(), 8)
       gen_nn_ops._batch_norm_with_global_normalization(
           conv_op,
           mean_op,