Automated g4 rollback of changelist 187563544
authorGunhan Gulsoy <gunan@google.com>
Fri, 2 Mar 2018 06:25:41 +0000 (22:25 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 06:29:38 +0000 (22:29 -0800)
PiperOrigin-RevId: 187582263

tensorflow/core/BUILD
tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer.h
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/protobuf/rewriter_config.proto
tensorflow/python/grappler/memory_optimizer_test.py

index 96e30ca3c0a436c2afd1910d6e9e17139bdb7f50..3271825251385387c0f085c675cbae1655c1af75 100644 (file)
@@ -2231,7 +2231,6 @@ cc_library(
     ],
     visibility = [
         "//tensorflow/compiler:__subpackages__",
-        "//tensorflow/core/grappler:__subpackages__",
         "//tensorflow/core/profiler:__subpackages__",
     ],
     deps = [":lib_internal"],
index 0a4330b524864147b4566ee834758bd588d930ae..037438ee751987adbec522f5366258d2cc790a54 100644 (file)
@@ -363,7 +363,6 @@ cc_library(
         ":graph_rewriter",
         ":static_schedule",
         "//tensorflow/core:framework",
-        "//tensorflow/core:regexp_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
index d73050ac4d5a7ba749abf82397ac747b1c4612ec..694139fa5033410375fcfae2f1141c82fa9d550c 100644 (file)
@@ -36,7 +36,6 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/grappler/utils/traversal.h"
-#include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 
 namespace tensorflow {
@@ -414,7 +413,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
-                                const string& recomputation_targets_name_regexp,
+                                const string& recomputation_targets_name_prefix,
                                 GraphDef* graph, const GrapplerItem& item) {
   if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
       optimization_level != RewriterConfig::HEURISTICS &&
@@ -438,19 +437,16 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   for (const auto& feed : item.feed) {
     feeds.insert(NodeName(feed.first));
   }
-  RE2 recomputation_targets_re(recomputation_targets_name_regexp);
   std::function<bool(const NodeDef&)> is_target =
-      [&recomputation_targets_re](const NodeDef& node) {
-        // Nodes whose inputs we may want to recompute. This does a prefix
-        // regexp match, and typically one sets regexp="gradients/" meaning
-        // it will match all node names with scope beginning with "gradients/".
-        // If used within scopes, one may want to set regexp="(.+/)?gradients/".
+      [&recomputation_targets_name_prefix](const NodeDef& node) {
+        // Nodes whose inputs we may want to recompute. Typically targets will
+        // be gradients (recomputation_targets_name_prefix="gradients/"),
+        // although the prefix is configurable since gradients may be created
+        // in a name scope.
         // TODO(allenl): Use a static schedule
         // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
         // whose outputs will sit around for a while.
-        bool match = recomputation_targets_re.Match(
-            node.name(), 0, node.name().size(), RE2::ANCHOR_START, nullptr, 0);
-        return match;
+        return node.name().find(recomputation_targets_name_prefix) == 0;
       };
 
   if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
@@ -1229,7 +1225,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   *optimized_graph = item.graph;
 
   RecomputationRewritingPass(optimization_level_,
-                             recomputation_targets_name_regexp_,
+                             recomputation_targets_name_prefix_,
                              optimized_graph, item);
 
   GrapplerItem optimized_item(item, std::move(*optimized_graph));
index 62ab969848412ad7badd09d14b769b01e243b373..c3dd0c45c6c524ef850ce7cfb9f6543d22e783ec 100644 (file)
@@ -27,14 +27,14 @@ class MemoryOptimizer : public GraphOptimizer {
  public:
   // optimization_level: Controls the level of autonomy for the memory
   //   optimizer. See RewriterConfig::memory_optimization.
-  // recomputation_targets_name_regxp: Name regxp for potential outputs of
+  // recomputation_targets_name_prefix: Name prefix for potential outputs of
   //   recomputations. See
-  //   RewriterConfig::memory_optimizer_target_node_name_regxp.
+  //   RewriterConfig::memory_optimizer_target_node_name_prefix.
   explicit MemoryOptimizer(
       RewriterConfig::MemOptType optimization_level,
-      const string& recomputation_targets_name_regexp = "gradients/")
+      const string& recomputation_targets_name_prefix = "gradients/")
       : optimization_level_(optimization_level),
-        recomputation_targets_name_regexp_(recomputation_targets_name_regexp) {}
+        recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
   ~MemoryOptimizer() override {}
 
   string name() const override { return "memory_optimizer"; };
@@ -47,7 +47,7 @@ class MemoryOptimizer : public GraphOptimizer {
 
  private:
   RewriterConfig::MemOptType optimization_level_;
-  string recomputation_targets_name_regexp_;
+  string recomputation_targets_name_prefix_;
 };
 
 }  // end namespace grappler
index 979f3e716168f58e6b0ac2393b788d117c2057af..72d7b94dc8c27b6843be94ebbb715c8c5d215de3 100644 (file)
@@ -119,7 +119,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
     }
     if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
-      if (cfg_.memory_optimizer_target_node_name_regexp().empty()) {
+      if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
         optimizers.push_back(std::unique_ptr<GraphOptimizer>(
             // Use the default target node name prefix "gradients/"
             new MemoryOptimizer(cfg_.memory_optimization())));
@@ -127,7 +127,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         optimizers.push_back(
             std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
                 cfg_.memory_optimization(),
-                cfg_.memory_optimizer_target_node_name_regexp())));
+                cfg_.memory_optimizer_target_node_name_prefix())));
       }
     }
     if (cfg_.auto_parallel().enable()) {
index 63303fa96841fa90b3ae51c61a1778e37771a8e7..9ebf21781110c2d4005dda0b9c6cdf2608f80535 100644 (file)
@@ -78,14 +78,16 @@ message RewriterConfig {
   // effect on manually requested memory optimization passes in the optimizers
   // field.
   MemOptType memory_optimization = 4;
-  // A regexp for node names which are valid outputs of recomputations. Inputs
-  // to nodes that match this regexp may be recomputed (subject either to manual
+  // The prefix for nodes which are valid outputs of recomputations. Inputs to
+  // nodes with this name prefix may be recomputed (subject either to manual
   // annotation of those input nodes or to manual annotation and heuristics
-  // depending on memory_optimization), but the nodes themselves will not be
-  // recomputed. This is a prefix match, meaning it matches any node name that
-  // contains a prefix that matches this regexp. Defaults to "gradients/" if
-  // not provided, but can be changed if used within scopes.
-  string memory_optimizer_target_node_name_regexp = 6;
+  // depending on memory_optimization), but the prefixed nodes themselves will
+  // not be recomputed. Typically this will be "gradients/", indicating that
+  // activations from the forward pass of a graph may be recomputed as inputs to
+  // gradients, but may be adjusted if gradients are inside a name scope or if
+  // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
+  // empty or not set.
+  string memory_optimizer_target_node_name_prefix = 6;
 
   // Configures AutoParallel optimization passes either through the
   // meta-optimizer or when manually specified through the optimizers field.
index 58d3c1e85f0a4fc977c468e9f41521b1fa7aa1ae..948911f099674af4c6dd19bfdac75e5fc1f75c78 100644 (file)
@@ -162,34 +162,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
             arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
             memory_optimization=rewriter_config_pb2.RewriterConfig.
             RECOMPUTATION_HEURISTICS,
-            memory_optimizer_target_node_name_regexp='optimizer/gradients/'),
-        original_metagraph)
-    self.assertGreater(
-        len(rewritten_graph_def.node),
-        len(original_metagraph.graph_def.node))
-    self.assertEqual(
-        0,
-        len([node for node in original_metagraph.graph_def.node
-             if 'Recomputed/' in node.name]))
-    self.assertEqual(
-        20,  # Two per layer
-        len([node for node in rewritten_graph_def.node
-             if 'Recomputed/' in node.name]))
-
-  def testRewritingNameScopedGradientNamesRegexp(self):
-    """Tests that rewriting occurs with non-standard gradient names."""
-    (original_metagraph, _, _, _) = self._GetMetaGraph(
-        optimizer_scope_name='foo/bar')
-    rewritten_graph_def = tf_optimizer.OptimizeGraph(
-        rewriter_config_pb2.RewriterConfig(
-            disable_model_pruning=True,
-            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
-            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
-            layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
-            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
-            memory_optimization=rewriter_config_pb2.RewriterConfig.
-            RECOMPUTATION_HEURISTICS,
-            memory_optimizer_target_node_name_regexp='(.+/)gradients/'),
+            memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
         original_metagraph)
     self.assertGreater(
         len(rewritten_graph_def.node),