Automated g4 rollback of changelist 187582263
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Mar 2018 20:58:08 +0000 (12:58 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 21:02:11 +0000 (13:02 -0800)
PiperOrigin-RevId: 187657654

tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer.h
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/protobuf/rewriter_config.proto
tensorflow/python/grappler/memory_optimizer_test.py

index 694139fa5033410375fcfae2f1141c82fa9d550c..27e9d2c78d0456e61d31f7f772172fb8d17a11ac 100644 (file)
@@ -413,7 +413,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
-                                const string& recomputation_targets_name_prefix,
+                                const string& recomputation_targets_name_scope,
                                 GraphDef* graph, const GrapplerItem& item) {
   if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
       optimization_level != RewriterConfig::HEURISTICS &&
@@ -438,15 +438,14 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
     feeds.insert(NodeName(feed.first));
   }
   std::function<bool(const NodeDef&)> is_target =
-      [&recomputation_targets_name_prefix](const NodeDef& node) {
-        // Nodes whose inputs we may want to recompute. Typically targets will
-        // be gradients (recomputation_targets_name_prefix="gradients/"),
-        // although the prefix is configurable since gradients may be created
-        // in a name scope.
-        // TODO(allenl): Use a static schedule
-        // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
-        // whose outputs will sit around for a while.
-        return node.name().find(recomputation_targets_name_prefix) == 0;
+      [&recomputation_targets_name_scope](const NodeDef& node) {
+        // Nodes whose inputs we may want to recompute. This matches node names
+        // that contain recomputation_targets_name_scope as a name scope,
+        // meaning it either begins with or contains the name scope.
+        // Defaults to "gradients/" which will match any node name that begins
+        // with "gradients/" or contains "/gradients/".
+        return node.name().find(recomputation_targets_name_scope) == 0 ||
+               node.name().find("/" + recomputation_targets_name_scope) != -1;
       };
 
   if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
@@ -1225,8 +1224,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   *optimized_graph = item.graph;
 
   RecomputationRewritingPass(optimization_level_,
-                             recomputation_targets_name_prefix_,
-                             optimized_graph, item);
+                             recomputation_targets_name_scope_, optimized_graph,
+                             item);
 
   GrapplerItem optimized_item(item, std::move(*optimized_graph));
   std::unordered_set<string> skip_list;
index c3dd0c45c6c524ef850ce7cfb9f6543d22e783ec..5c555a26746b759500f3d778ce137d6d9bedb67b 100644 (file)
@@ -27,14 +27,14 @@ class MemoryOptimizer : public GraphOptimizer {
  public:
   // optimization_level: Controls the level of autonomy for the memory
   //   optimizer. See RewriterConfig::memory_optimization.
-  // recomputation_targets_name_prefix: Name prefix for potential outputs of
+  // recomputation_targets_name_scope: Name scope for potential outputs of
   //   recomputations. See
-  //   RewriterConfig::memory_optimizer_target_node_name_prefix.
+  //   RewriterConfig::memory_optimizer_target_node_name_scope.
   explicit MemoryOptimizer(
       RewriterConfig::MemOptType optimization_level,
-      const string& recomputation_targets_name_prefix = "gradients/")
+      const string& recomputation_targets_name_scope = "gradients/")
       : optimization_level_(optimization_level),
-        recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
+        recomputation_targets_name_scope_(recomputation_targets_name_scope) {}
   ~MemoryOptimizer() override {}
 
   string name() const override { return "memory_optimizer"; };
@@ -47,7 +47,7 @@ class MemoryOptimizer : public GraphOptimizer {
 
  private:
   RewriterConfig::MemOptType optimization_level_;
-  string recomputation_targets_name_prefix_;
+  string recomputation_targets_name_scope_;
 };
 
 }  // end namespace grappler
index 72d7b94dc8c27b6843be94ebbb715c8c5d215de3..fff1e354f47e0f0ea8f450a1dd26ad5cc54d8bbe 100644 (file)
@@ -119,7 +119,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
     }
     if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
-      if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
+      if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
         optimizers.push_back(std::unique_ptr<GraphOptimizer>(
             // Use the default target node name prefix "gradients/"
             new MemoryOptimizer(cfg_.memory_optimization())));
@@ -127,7 +127,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         optimizers.push_back(
             std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
                 cfg_.memory_optimization(),
-                cfg_.memory_optimizer_target_node_name_prefix())));
+                cfg_.memory_optimizer_target_node_name_scope())));
       }
     }
     if (cfg_.auto_parallel().enable()) {
index 9ebf21781110c2d4005dda0b9c6cdf2608f80535..0ccf2149f2cd4865627f7ab42441b8d94dc5fff4 100644 (file)
@@ -78,16 +78,15 @@ message RewriterConfig {
   // effect on manually requested memory optimization passes in the optimizers
   // field.
   MemOptType memory_optimization = 4;
-  // The prefix for nodes which are valid outputs of recomputations. Inputs to
-  // nodes with this name prefix may be recomputed (subject either to manual
-  // annotation of those input nodes or to manual annotation and heuristics
-  // depending on memory_optimization), but the prefixed nodes themselves will
-  // not be recomputed. Typically this will be "gradients/", indicating that
-  // activations from the forward pass of a graph may be recomputed as inputs to
-  // gradients, but may be adjusted if gradients are inside a name scope or if
-  // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
-  // empty or not set.
-  string memory_optimizer_target_node_name_prefix = 6;
+  // A node name scope for node names which are valid outputs of recomputations.
+  // Inputs to nodes that match this scope may be recomputed (subject either to
+  // manual annotation of those input nodes or to manual annotation and
+  // heuristics depending on memory_optimization), but the nodes themselves will
+  // not be recomputed. This matches any sub-scopes as well, meaning the scope
+  // can appear not just as a top-level scope. For example, if the value is
+  // "gradients/", the default, it will match node name "gradients/foo",
+  // "foo/gradients/bar", but not "foo_gradients/"
+  string memory_optimizer_target_node_name_scope = 6;
 
   // Configures AutoParallel optimization passes either through the
   // meta-optimizer or when manually specified through the optimizers field.
index 948911f099674af4c6dd19bfdac75e5fc1f75c78..4df959ce04169395589aeebaef9e3e7839e2300c 100644 (file)
@@ -162,7 +162,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
             arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
             memory_optimization=rewriter_config_pb2.RewriterConfig.
             RECOMPUTATION_HEURISTICS,
-            memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+            # Checks that name scope "gradients/" also matches sub-scopes.
+            memory_optimizer_target_node_name_scope='gradients/'),
         original_metagraph)
     self.assertGreater(
         len(rewritten_graph_def.node),
@@ -176,6 +177,35 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
         len([node for node in rewritten_graph_def.node
              if 'Recomputed/' in node.name]))
 
+  def testRewritingNameScopedGradientNamesScope(self):
+    """Tests that rewriting occurs with non-standard gradient names."""
+    (original_metagraph, _, _,
+     _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            disable_model_pruning=True,
+            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
+            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+            memory_optimization=rewriter_config_pb2.RewriterConfig.
+            RECOMPUTATION_HEURISTICS,
+            # This should not match anything.
+            memory_optimizer_target_node_name_scope='r/gradients/'),
+        original_metagraph)
+    self.assertEqual(
+        len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
+    self.assertEqual(0,
+                     len([
+                         node for node in original_metagraph.graph_def.node
+                         if 'Recomputed/' in node.name
+                     ]))
+    self.assertEqual(0,
+                     len([
+                         node for node in rewritten_graph_def.node
+                         if 'Recomputed/' in node.name
+                     ]))
+
   def _GetMemoryOptimizerSessionConfig(self):
     rewrite_options = rewriter_config_pb2.RewriterConfig(
         disable_model_pruning=True,