Update the rewriter options with the optimizer options
authorBenoit Steiner <bsteiner@google.com>
Fri, 6 Apr 2018 19:22:17 +0000 (12:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 19:30:14 +0000 (12:30 -0700)
PiperOrigin-RevId: 191923287

tensorflow/python/framework/function_test.py

index 83d256f..c05396b 100644 (file)
@@ -58,12 +58,32 @@ def _OptimizerOptions():
   for cse in [False, True]:
     for inline in [False, True]:
       for cfold in [False, True]:
-        yield config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
-            optimizer_options=config_pb2.OptimizerOptions(
-                opt_level=config_pb2.OptimizerOptions.L0,
-                do_common_subexpression_elimination=cse,
-                do_function_inlining=inline,
-                do_constant_folding=cfold)))
+        cfg = config_pb2.ConfigProto(
+            graph_options=config_pb2.GraphOptions(
+                optimizer_options=config_pb2.OptimizerOptions(
+                    opt_level=config_pb2.OptimizerOptions.L0,
+                    do_common_subexpression_elimination=cse,
+                    do_function_inlining=inline,
+                    do_constant_folding=cfold)))
+        if cse:
+          cfg.graph_options.rewrite_options.arithmetic_optimization = (
+              rewriter_config_pb2.RewriterConfig.ON)
+        else:
+          cfg.graph_options.rewrite_options.arithmetic_optimization = (
+              rewriter_config_pb2.RewriterConfig.OFF)
+        if inline:
+          cfg.graph_options.rewrite_options.function_optimization = (
+              rewriter_config_pb2.RewriterConfig.ON)
+        else:
+          cfg.graph_options.rewrite_options.function_optimization = (
+              rewriter_config_pb2.RewriterConfig.OFF)
+        if cfold:
+          cfg.graph_options.rewrite_options.constant_folding = (
+              rewriter_config_pb2.RewriterConfig.ON)
+        else:
+          cfg.graph_options.rewrite_options.constant_folding = (
+              rewriter_config_pb2.RewriterConfig.OFF)
+        yield cfg
 
 
 @test_util.with_c_api