Improved detection of swappable nodes
authorBenoit Steiner <bsteiner@google.com>
Fri, 19 Jan 2018 17:12:34 +0000 (09:12 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 19 Jan 2018 17:16:40 +0000 (09:16 -0800)
PiperOrigin-RevId: 182542749

tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

index 72791cbf6fc2763795e37ecb1b1ece90678cf98d..8418abd80f84a7675dd34414dc582fb31089672b 100644 (file)
@@ -775,8 +775,10 @@ static const NodeDef* FindSwapInTrigger(
   return nullptr;
 }
 
-static bool IsSwappable(GraphView::OutputPort output) {
+static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
   const NodeDef& node = *output.node;
+  // There is no point in swapping out persistent tensors, since the tensor will
+  // continue to use memory.
   if (IsPersistent(node)) {
     return false;
   }
@@ -785,13 +787,29 @@ static bool IsSwappable(GraphView::OutputPort output) {
   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
     return false;
   }
-
   DataType dtype;
   if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
     return false;
   }
+  // References can only refer to persistent memory: therefore the node isn't
+  // swappable.
+  if (IsRefType(dtype)) {
+    return false;
+  }
 
-  return !IsRefType(dtype);
+  if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
+    // If placed on the same device, these nodes are just forwarding references
+    // to their input. Therefore they are swappable iff their fanin is swappable
+    // or it resides on a different device.
+    GraphView::InputPort input;
+    input.node = output.node;
+    input.port_id = 0;
+    GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+    if (fanin.node->device() == node.device()) {
+      return IsSwappable(graph, fanin);
+    }
+  }
+  return true;
 }
 
 static NodeDef* FindSwapOutTrigger(
@@ -898,10 +916,9 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
         // Don't bother with small tensors.
         continue;
       }
-      // Don't try to swap out persistent data
       GraphView::OutputPort port =
           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
-      if (!IsSwappable(port)) {
+      if (!IsSwappable(graph, port)) {
         continue;
       }
       Costs::NanoSeconds execution_time(-1);
index f507178bcedd14c41d09fdf282706afe7db62177..185ac6040c4ce85ca5e7f8eadbe41b05fbe339df 100644 (file)
@@ -271,14 +271,15 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
 
 TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+  Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
                            {128, 128, 8}, DT_FLOAT);
-  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
-  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
-  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+  Output a = ops::Identity(s.WithOpName("a").WithDevice("/gpu:0"), v);
+  Output b = ops::Square(s.WithOpName("b").WithDevice("/gpu:0"), v);
+  Output c = ops::Sqrt(s.WithOpName("c").WithDevice("/gpu:0"), a);
+  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), b);
   Output axis = ops::Const(s.WithOpName("axis"), 0);
   Output e =
-      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {a, b, c, d}, axis);
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -300,21 +301,22 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
       for (int64 input_id : val.list().i()) {
         inputs_to_swap.insert(input_id);
       }
-      EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+      EXPECT_EQ(std::set<int>({1, 2, 3}), inputs_to_swap);
     }
   }
 }
 
 TEST_F(MemoryOptimizerTest, UnswappableInputs) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+  Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
                            {128, 128, 8}, DT_FLOAT);
+  Output a = ops::Square(s.WithOpName("a").WithDevice("/gpu:0"), v);
   Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
   Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
   Output index = ops::Const(s.WithOpName("index"), {0});
   Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
   Output d =
-      ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), a, indices, c);
+      ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), v, indices, c);
   Output axis = ops::Const(s.WithOpName("axis"), 0);
   Output e =
       ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);