Enable the use of scheduling heuristics to reduce peak memory usage by default
author Benoit Steiner <bsteiner@google.com>
Mon, 12 Feb 2018 20:05:49 +0000 (12:05 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 20:09:29 +0000 (12:09 -0800)
PiperOrigin-RevId: 185413855

tensorflow/core/grappler/costs/virtual_scheduler.cc
tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/grappler/optimizers/model_pruner.cc
tensorflow/core/protobuf/rewriter_config.proto
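
For reference, a minimal sketch of how a client would opt out now that the scheduling heuristics run by default. This assumes the usual ConfigProto/RewriterConfig plumbing and is not part of this change:

#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

int main() {
  tensorflow::ConfigProto config;
  tensorflow::RewriterConfig* rewriter =
      config.mutable_graph_options()->mutable_rewrite_options();
  // Leaving memory_optimization at DEFAULT_MEM_OPT now runs the scheduling
  // heuristics; NO_MEM_OPT stays the explicit opt-out.
  rewriter->set_memory_optimization(tensorflow::RewriterConfig::NO_MEM_OPT);
  return 0;
}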

tensorflow/core/grappler/costs/virtual_scheduler.cc
index 020492a..14b4ed7 100644
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/utils.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -446,13 +447,14 @@ Status VirtualScheduler::Init() {
   }
 
   if (ready_nodes_->Empty()) {
-    return Status(error::UNAVAILABLE, "No ready nodes in the graph.");
+    return errors::InvalidArgument("No ready nodes in the graph.");
   }
 
-  if (!feed_nodes.empty())
-    LOG(ERROR) << "Some feed nodes were not found in the graph: "
-               << str_util::Join(feed_nodes, ",");
-
+  if (!feed_nodes.empty()) {
+    return errors::InvalidArgument(
+        strings::StrCat("Some feed nodes were not found in the graph: ",
+                        str_util::Join(feed_nodes, ",")));
+  }
   initialized_ = true;
   return Status::OK();
 }
tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 9f3e940..ef178ad 100644
@@ -511,6 +511,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
     }
   }
 
+  if (addn_list.empty()) {
+    return false;
+  }
+
   GraphMemory memory(*item);
   const std::unordered_map<string, DeviceProperties>& devices =
       cluster->GetDevices();
@@ -560,6 +564,13 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
       VLOG(1) << "Missing properties for " << node->name();
       continue;
     }
+    const TensorShapeProto& shape =
+        properties.GetOutputProperties(node->name())[0].shape();
+    PartialTensorShape shp(shape);
+    if (!shp.IsFullyDefined()) {
+      VLOG(1) << "Shape not fully known for " << node->name();
+      continue;
+    }
 
     // Compute a topological ordering for the node fanin.
     std::unordered_map<NodeDef*, int> topo_order;
@@ -608,8 +619,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
       }
     }
 
-    const TensorShapeProto& shape =
-        properties.GetOutputProperties(node->name())[0].shape();
     DataType dtype = node->attr().at("T").type();
     const string& device = node->device();
 
@@ -1223,7 +1232,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   bool updated_graph = true;
   for (int i = 0; i < 25 && updated_graph; ++i) {
     updated_graph = false;
-    if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
+    if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
+         optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
          optimization_level_ == RewriterConfig::HEURISTICS) &&
         cluster != nullptr) {
       updated_graph |= SchedulingPass(cluster, &optimized_item);
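
To illustrate the new gating in SchedulingPass, here is a small standalone sketch of the fully-defined-shape check; the [-1, 128] shape is hypothetical:

#include <iostream>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

int main() {
  // Hypothetical output shape with one unknown dimension: [-1, 128].
  tensorflow::TensorShapeProto proto;
  proto.add_dim()->set_size(-1);
  proto.add_dim()->set_size(128);
  tensorflow::PartialTensorShape shape(proto);
  // IsFullyDefined() is false here, so the pass would now skip this node
  // instead of trying to rewrite it.
  std::cout << shape.IsFullyDefined() << std::endl;  // prints 0
  return 0;
}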
tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6d93f74..ab7e05d 100644
@@ -101,7 +101,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       optimizers.push_back(
           std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
     }
-    if (cfg_.memory_optimization() > 1) {
+    if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
       if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
         optimizers.push_back(std::unique_ptr<GraphOptimizer>(
             // Use the default target node name prefix "gradients/"
@@ -136,7 +136,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   bool already_optimized = false;
   for (const auto& optimizer : optimizers) {
     if (!already_optimized) {
-      auto status = optimizer->Optimize(cluster, item, optimized_graph);
+      Status status = optimizer->Optimize(cluster, item, optimized_graph);
       string result;
       if (!status.ok()) {
         VLOG(1) << "Not able to apply optimizer " << optimizer->name()
@@ -152,7 +152,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
               << " return status: " << result;
     } else {
       GrapplerItem optimized_item(item, std::move(*optimized_graph));
-      auto status =
+      Status status =
           optimizer->Optimize(cluster, optimized_item, optimized_graph);
       string result;
       if (!status.ok()) {
@@ -205,7 +205,8 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
          cfg.constant_folding() != RewriterConfig::OFF ||
          cfg.dependency_optimization() != RewriterConfig::OFF ||
          cfg.arithmetic_optimization() != RewriterConfig::OFF ||
-         cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 ||
+         cfg.auto_parallel().enable() ||
+         cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
          !cfg.optimizers().empty();
 }
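
A quick sketch of why the gate changed: with the MemOptType values DEFAULT_MEM_OPT = 0 and NO_MEM_OPT = 1, the old "memory_optimization() > 1" test treated the default as disabled, while the new test only excludes the explicit opt-out. This standalone snippet assumes nothing beyond the rewriter_config proto:

#include <iostream>
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

int main() {
  tensorflow::RewriterConfig cfg;  // memory_optimization defaults to DEFAULT_MEM_OPT (0)
  bool old_gate = cfg.memory_optimization() > 1;                                        // false
  bool new_gate = cfg.memory_optimization() != tensorflow::RewriterConfig::NO_MEM_OPT;  // true
  std::cout << old_gate << " " << new_gate << std::endl;  // prints "0 1"
  return 0;
}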
 
tensorflow/core/grappler/optimizers/model_pruner.cc
index ece9df0..f52a2ab 100644
@@ -67,7 +67,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
       // let's be conservative and preserve the graph as is.
       return errors::InvalidArgument("Invalid input graph.");
     }
-    // Try to keep the nodes ordored somewhat topologically since this helps
+    // Try to keep the nodes ordered somewhat topologically since this helps
     // further optimizations perform better.
     for (int i = keep.size() - 1; i >= 0; --i) {
       *runnable_item.graph.add_node() = *keep[i];
tensorflow/core/protobuf/rewriter_config.proto
index dddadce..77667e4 100644
@@ -41,7 +41,7 @@ message RewriterConfig {
   bool disable_model_pruning = 2;
 
   enum MemOptType {
-    // The default setting (currently disabled)
+    // The default setting (SCHEDULING_HEURISTICS only)
     DEFAULT_MEM_OPT = 0;
     // Disabled in the meta-optimizer.
     NO_MEM_OPT = 1;