Run 2 passes of rewrites by default
authorBenoit Steiner <bsteiner@google.com>
Thu, 26 Apr 2018 20:19:39 +0000 (13:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 20:21:59 +0000 (13:21 -0700)
PiperOrigin-RevId: 194443770

tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/python/estimator/estimator.py
tensorflow/python/grappler/memory_optimizer_test.py

index c42d614..2edc4da 100644 (file)
@@ -39,7 +39,7 @@ namespace grappler {
 
 namespace {
 
-constexpr int kDefaultNumberOfIterations = 1;
+constexpr int kDefaultNumberOfIterations = 2;
 
 int64 NumEdges(const GraphDef& graph) {
   int64 num_edges = 0;
@@ -63,7 +63,10 @@ int NumIterations(const RewriterConfig& cfg) {
 }
 
 // Check if optimizer is allowed to run only once.
-bool IsRunOnceOptimizer(const string& name) { return name == "layout"; }
+bool IsRunOnceOptimizer(const string& name) {
+  return name == "layout" || name == "memory_optimizer" ||
+         name == "arithmetic_optimizer" || name == "loop_optimizer";
+}
 
 }  // namespace
 
index 2f1212d..2363845 100644 (file)
@@ -30,6 +30,7 @@ import six
 from google.protobuf import message
 from tensorflow.core.framework import summary_pb2
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session as tf_session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
@@ -203,7 +204,11 @@ class Estimator(object):
     logging.info('Using config: %s', str(vars(self._config)))
 
     if self._config.session_config is None:
-      self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+      rewrite_opts = rewriter_config_pb2.RewriterConfig(
+          meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+      graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
+      self._session_config = config_pb2.ConfigProto(
+          allow_soft_placement=True, graph_options=graph_opts)
     else:
       self._session_config = self._config.session_config
 
index 4df959c..3f9d886 100644 (file)
@@ -76,6 +76,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
     rewriter_config = rewriter_config_pb2.RewriterConfig(
         disable_model_pruning=True,
+        meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
         constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
         memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
     graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)