Ignore nodes that are going to be swapped when computing max memory usage
authorBenoit Steiner <bsteiner@google.com>
Tue, 9 Jan 2018 02:15:51 +0000 (18:15 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 9 Jan 2018 02:19:40 +0000 (18:19 -0800)
PiperOrigin-RevId: 181248577

tensorflow/core/grappler/costs/graph_memory.cc
tensorflow/core/grappler/costs/graph_memory_test.cc
tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc

index 3168758c8bdb2c01a38d11d9f45252d9f9a78ad0..3604de392f803b8b2eb65e796848c2c3ec6a90e5 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_memory.h"
 #include <list>
 #include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/framework/tensor_description.pb.h"
@@ -163,6 +164,8 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
 
   NodeMap node_map(&item_.graph);
   for (const auto& dev_stats : timeline.dev_stats()) {
+    const string& device_name = dev_stats.device();
+    const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:"));
     std::list<LiveTensor>& device_tensors =
         live_tensors_per_device[dev_stats.device()];
     for (const auto& node_stats : dev_stats.node_stats()) {
@@ -194,7 +197,24 @@ void GraphMemory::InferFromTrace(const StepStats& timeline) {
         // graph (e.g _Send/_Recv nodes).
         continue;
       }
-      for (const string& input : node->input()) {
+      std::unordered_set<int> swapped_inputs;
+      if (is_gpu) {
+        auto it = node->attr().find("_swap_to_host");
+        if (it != node->attr().end()) {
+          const AttrValue& val = it->second;
+          for (int port_id : val.list().i()) {
+            swapped_inputs.insert(port_id);
+          }
+        }
+      }
+      for (int i = 0; i < node->input_size(); ++i) {
+        if (swapped_inputs.find(i) != swapped_inputs.end()) {
+          // The memory of swapped inputs will be released as early as possible:
+          // therefore ignore this input when determining the deallocation time
+          // of the tensor.
+          continue;
+        }
+        const string& input = node->input(i);
         int position;
         string input_node = ParseNodeName(input, &position);
         if (position < 0) {
index 6f3522b068bdb74eb98d3e6071d4d4b2e21c9ff6..95170ba49b77ef1be629cfa77bc4a333d2315e4f 100644 (file)
@@ -134,6 +134,62 @@ TEST_F(GraphMemoryTest, MultiDevice) {
   EXPECT_EQ(gpu_expected, gpu_tensors);
 }
 
+TEST_F(GraphMemoryTest, GpuSwapping) {
+  TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+  item.feed.clear();
+
+  {
+    // Estimate the max memory usage for the graph.
+    GraphMemory memory(item);
+    Status s = memory.InferStatically(devices_);
+    TF_CHECK_OK(s);
+
+    const GraphMemory::MemoryUsage& gpu_mem =
+        memory.GetPeakMemoryUsage("/GPU:0");
+    EXPECT_EQ(20971520, gpu_mem.used_memory);
+    std::set<string> gpu_tensors;
+    for (const auto& t : gpu_mem.live_tensors) {
+      gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+    }
+    std::set<string> gpu_expected;
+    gpu_expected.insert("Square:0");
+    gpu_expected.insert("Square_1:0");
+    gpu_expected.insert("AddN:0");
+    gpu_expected.insert("AddN_1:0");
+    gpu_expected.insert("AddN_2:0");
+    EXPECT_EQ(gpu_expected, gpu_tensors);
+  }
+
+  {
+    // Swap the first input to node AddN_1: its fanin (the square nodes) should
+    // not appear in the max cut anymore.
+    for (auto& node : *item.graph.mutable_node()) {
+      if (node.name() == "AddN_1") {
+        (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
+      }
+    }
+    GraphMemory memory(item);
+    Status s = memory.InferStatically(devices_);
+    TF_CHECK_OK(s);
+    const GraphMemory::MemoryUsage& new_gpu_mem =
+        memory.GetPeakMemoryUsage("/GPU:0");
+    EXPECT_EQ(20971520, new_gpu_mem.used_memory);
+    std::set<string> new_gpu_tensors;
+    for (const auto& t : new_gpu_mem.live_tensors) {
+      new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+    }
+    std::set<string> new_gpu_expected;
+    new_gpu_expected.insert("AddN:0");
+    new_gpu_expected.insert("AddN_1:0");
+    new_gpu_expected.insert("AddN_2:0");
+    new_gpu_expected.insert("AddN_3:0");
+    new_gpu_expected.insert("AddN_4:0");
+    EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
+  }
+}
+
 TEST_F(GraphMemoryTest, CtrlDependencies) {
   // Build a simple graph with a control dependency.
   Scope s = Scope::NewRootScope();
index 6d25556770d13058ba65045eff787b12c0ca12de..ec54bd5c7598a5acb5bf653bb2902f6c3aba38f6 100644 (file)
@@ -31,8 +31,6 @@ namespace {
 GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
                         bool use_multiple_devices, bool insert_queue,
                         const std::vector<string>& device_names) {
-  CHECK_GE(device_names.size(), width);
-
   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
     std::vector<Output> this_stage;
     for (int j = 0; j < width; j++) {
       if (last_stage.size() == 1) {
-        Output unary_op =
-            Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
-                   last_stage[0]);
+        Output unary_op = Square(
+            s.WithDevice(
+                device_names[use_multiple_devices ? j % device_names.size()
+                                                  : 0]),
+            last_stage[0]);
         this_stage.push_back(unary_op);
       } else {
         Output combine =
-            AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
+            AddN(s.WithDevice(
+                     device_names[use_multiple_devices ? j % device_names.size()
+                                                       : 0]),
                  last_stage);
         this_stage.push_back(combine);
       }