Relax the stringent memory allocator constraints in AssignOp if a Grappler graph...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 1 May 2018 20:34:39 +0000 (13:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 20:37:38 +0000 (13:37 -0700)
PiperOrigin-RevId: 194988134

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/memory_optimizer.cc
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils.h
tensorflow/core/kernels/assign_op.h
tensorflow/core/kernels/resource_variable_ops.cc

index 839b0bb..bf6d4c0 100644 (file)
@@ -54,6 +54,10 @@ bool IsApproximateEqual(const NodeDef& node) {
 
 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
 
+bool IsAssign(const NodeDef& node) {
+  return node.op() == "Assign" || node.op() == "AssignVariableOp";
+}
+
 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
 
 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
index bd8d3a4..3dddf3f 100644 (file)
@@ -30,6 +30,7 @@ bool IsAnyDiv(const NodeDef& node);
 bool IsApproximateEqual(const NodeDef& node);
 bool IsAvgPoolGrad(const NodeDef& node);
 bool IsAssert(const NodeDef& node);
+bool IsAssign(const NodeDef& node);
 bool IsAtan2(const NodeDef& node);
 bool IsBetainc(const NodeDef& node);
 bool IsBiasAdd(const NodeDef& node);
index c1fee0e..7c6468b 100644 (file)
@@ -1219,6 +1219,80 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
   return updated_graph;
 }
 
+// TODO(rmlarsen): Add distributed TF test.
+Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
+  std::unordered_set<string> devices;
+  std::vector<int> assign_nodes;
+  bool found_send = false;
+  for (int i = 0; i < optimized_graph->node_size(); ++i) {
+    const NodeDef& node = optimized_graph->node(i);
+    devices.insert(node.device());
+    if (IsAssign(node)) {
+      assign_nodes.push_back(i);
+    }
+    if (IsSend(node)) {
+      found_send = true;
+      break;
+    }
+  }
+  if (!found_send && devices.size() == 1) {
+    for (int assign_idx : assign_nodes) {
+      // Set an attribute telling AssignOp to ignore allocator constraints.
+      NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
+      (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
+          .set_b(true);
+    }
+    return Status::OK();
+  }
+
+  std::unordered_set<int> optimized_nodes;
+  SimpleGraphView graph_view;
+  TF_RETURN_IF_ERROR(graph_view.Initialize(*optimized_graph));
+  for (int i : assign_nodes) {
+    if (optimized_nodes.find(i) == optimized_nodes.end()) {
+      const NodeDef& node = optimized_graph->node(i);
+      optimized_nodes.insert(i);
+      std::vector<int> assign_nodes_in_fanout;
+      assign_nodes_in_fanout.push_back(i);
+      std::set<int> transitive_fanout;
+      graph_view.DepthFirstSearch(std::unordered_set<string>{}, i,
+                                  &transitive_fanout);
+      const string& assign_device = node.device();
+      bool relax_constraint = true;
+      // If all nodes in the transitive fanout are on the same device as the
+      // assign node, there is no need to allocate the output in pinned memory.
+      for (int fanout : transitive_fanout) {
+        const NodeDef& fanout_node = optimized_graph->node(fanout);
+        if (relax_constraint &&
+            (fanout_node.device() != assign_device || IsSend(fanout_node))) {
+          relax_constraint = false;
+        }
+        if (optimized_nodes.find(fanout) == optimized_nodes.end() &&
+            IsAssign(fanout_node)) {
+          assign_nodes_in_fanout.push_back(fanout);
+        }
+      }
+
+      for (int assign_idx : assign_nodes_in_fanout) {
+        if (relax_constraint) {
+          // If all devices match in fanout of node(i) then, by transitivity,
+          // they must also match in the fanout of other assign nodes
+          // node(assign_idx) in the fanout, so we can process them here,
+          // and save computing their transitive fanout later.
+          optimized_nodes.insert(assign_idx);
+
+          // Set an attribute telling AssignOp to ignore allocator constraints.
+          NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
+          (*assign_node
+                ->mutable_attr())["_grappler_relax_allocator_constraints"]
+              .set_b(true);
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
@@ -1251,6 +1325,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
   }
 
+  TF_RETURN_IF_ERROR(RelaxAllocatorConstraints(&optimized_item.graph));
+
   optimized_graph->Swap(&optimized_item.graph);
   return Status::OK();
 }
index a1f8080..a3f0e07 100644 (file)
@@ -440,6 +440,140 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
   }
 }
 
+class RelaxAllocatorConstraintsTest : public GrapplerTest {};
+
+TEST_F(RelaxAllocatorConstraintsTest, SameDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+                               -3.14f, {128, 128});
+  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+                                  {128, 128}, DT_FLOAT);
+  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+                              variable, constant);
+  Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/cpu:0"), assign);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto node = output.node(2);
+  EXPECT_EQ("assign", node.name());
+  EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
+  EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
+
+  item.fetch = {"exp"};
+  item.init_ops = {"variable"};
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, DifferentDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+                               -3.14f, {128, 128});
+  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+                                  {128, 128}, DT_FLOAT);
+  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+                              variable, constant);
+  // exp runs on a different device, so we cannot relax the allocation
+  // constraints on assign.
+  Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/gpu:0"), assign);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto node = output.node(2);
+  EXPECT_EQ("assign", node.name());
+  EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+#if GOOGLE_CUDA
+  item.fetch = {"exp"};
+  item.init_ops = {"variable"};
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+#endif
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, SendNode) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+                               -3.14f, {128, 128});
+  Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+                                  {128, 128}, DT_FLOAT);
+  Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+                              variable, constant);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  NodeDef* send = item.graph.add_node();
+  // Add a send node to the graph in the fanout of "assign".
+  send->set_name("send");
+  send->set_op("_Send");
+  send->add_input("assign");
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto node = output.node(2);
+  EXPECT_EQ("assign", node.name());
+  EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, AssignNodeInFanout) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output constant0 = ops::Const(s.WithOpName("constant0").WithDevice("/cpu:0"),
+                                -42.0f, {128, 128});
+  Output variable0 = ops::Variable(
+      s.WithOpName("variable0").WithDevice("/cpu:0"), {128, 128}, DT_FLOAT);
+  Output assign0 = ops::Assign(s.WithOpName("assign0").WithDevice("/cpu:0"),
+                               variable0, constant0);
+  // The rest of the graph is on a second device, so we can relax the
+  // constraint for assign1, but not for assign0.
+  Output exp1 = ops::Exp(s.WithOpName("exp1").WithDevice("/gpu:0"), assign0);
+  Output variable1 = ops::Variable(
+      s.WithOpName("variable1").WithDevice("/gpu:0"), {128, 128}, DT_FLOAT);
+  Output assign1 = ops::Assign(s.WithOpName("assign1").WithDevice("/gpu:0"),
+                               variable1, exp1);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto node = output.node(3);
+  EXPECT_EQ("assign0", node.name());
+  EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+
+  node = output.node(5);
+  EXPECT_EQ("assign1", node.name());
+  EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
+  EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
+
+#if GOOGLE_CUDA
+  item.fetch = {"assign0", "assign1"};
+  item.init_ops = {"exp1", "variable1"};
+  auto tensors_expected = EvaluateFetchNodes(item);
+  GrapplerItem optimized(item, std::move(output));
+  auto tensors = EvaluateFetchNodes(optimized);
+  for (int i = 0; i < tensors_expected.size(); ++i) {
+    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+  }
+#endif
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
index 6db6d71..c8e63f9 100644 (file)
@@ -435,7 +435,8 @@ void SimpleGraphView::DepthFirstSearch(
     std::set<int>* nodes_found) const {
   nodes_found->clear();
   const string& op_type = graph_->node(root_node).op();
-  if (op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
+  if (!op_types_to_traverse.empty() &&
+      op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
     return;
   }
   std::vector<int> stack;
@@ -446,7 +447,8 @@ void SimpleGraphView::DepthFirstSearch(
     stack.pop_back();
     nodes_found->insert(node_idx);
     const string& op_type = graph_->node(node_idx).op();
-    if (op_types_to_traverse.find(op_type) != op_types_to_traverse.end()) {
+    if (op_types_to_traverse.empty() ||
+        op_types_to_traverse.find(op_type) != op_types_to_traverse.end()) {
       for (auto output_idx : this->outputs(node_idx)) {
         if (nodes_found->find(output_idx) == nodes_found->end()) {
           stack.push_back(output_idx);
index 15f6b36..9776e99 100644 (file)
@@ -251,6 +251,7 @@ class SimpleGraphView {
   // visited in nodes_found. If a node has an op in `op_types_to_traverse`, the
   // walk continues to its children. It is assumed that *graph_ was not modified
   // after the call to Initialize().
+  // If `op_types_to_traverse` is empty the DFS will traverse any node type.
   void DepthFirstSearch(const std::unordered_set<string>& op_types_to_traverse,
                         int node_idx, std::set<int>* nodes_found) const;
 
index 19b38f9..a450b1d 100644 (file)
@@ -36,6 +36,12 @@ class AssignOp : public OpKernel {
                    context->GetAttr("validate_shape", &validate_shape_));
     OP_REQUIRES(context, IsRefType(context->input_type(0)),
                 errors::InvalidArgument("lhs input needs to be a ref type"));
+    if (!context
+             ->GetAttr("_grappler_relax_allocator_constraints",
+                       &relax_constraints_)
+             .ok()) {
+      relax_constraints_ = false;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -44,48 +50,37 @@ class AssignOp : public OpKernel {
     // We always return the input ref.
     context->forward_ref_input_to_ref_output(0, 0);
 
-    // We can't always know how this value will be used downstream,
-    // so make conservative assumptions in specifying constraints on
-    // the memory allocation attributes.
-    // TODO(rmlarsen): These conservative constraints make buffer
-    // forwarding unlikely to happen very often. Try to use graph analysis
-    // (possibly the InferAllocAttr pass in the executer) to improve the
-    // situation.
+    // We can't always know how this value will be used downstream, so make
+    // conservative assumptions in specifying constraints on the memory
+    // allocation attributes, unless the Grappler graph analysis determined that
+    // it was safe not to.
     AllocatorAttributes attr;
-    attr.set_gpu_compatible(true);
-    attr.set_nic_compatible(true);
+    if (!relax_constraints_) {
+      attr.set_gpu_compatible(true);
+      attr.set_nic_compatible(true);
+    }
 
     {
       mutex_lock l(*context->input_ref_mutex(0));
       const Tensor& old_lhs = context->mutable_input(0, /* lock_held */ true);
       const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
       if (validate_shape_) {
-        OP_REQUIRES(
-            context, same_shape,
-            errors::InvalidArgument(
-                "Assign requires shapes of both tensors to match. lhs shape= ",
-                old_lhs.shape().DebugString(),
-                " rhs shape= ", rhs.shape().DebugString()));
+        OP_REQUIRES(context, same_shape,
+                    errors::InvalidArgument(
+                        "Assign requires shapes of both tensors to match. "
+                        "lhs shape= ",
+                        old_lhs.shape().DebugString(),
+                        " rhs shape= ", rhs.shape().DebugString()));
       }
 
       // In the code below we try to minimize the amount of memory allocation
       // and copying by trying the following two shortcuts:
-      // 1. If we can reuse the rhs buffer we avoid both a memory allocation
-      //   and copying.
-      // 2. If the lhs is initialized and has the same number of elements as the
-      //    rhs we can avoid a memory allocation.
-
-      // 1. Try to reuse the rhs.
-      std::unique_ptr<Tensor> input_alias = context->forward_input(
-          1, OpKernelContext::Params::kNoReservation /*output_index*/,
-          rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr);
-      if (input_alias != nullptr) {
-        // Transfer ownership to the ref.
-        context->replace_ref_input(0, *input_alias, /* lock_held */ true);
-        return;
-      }
+      // 1. If the lhs is initialized and has the same number of elements as
+      //    the rhs we can avoid a memory allocation.
+      // 2. If we can reuse the rhs buffer we avoid both a memory allocation
+      //    and copying.
 
-      // 2. Try to copy into an existing buffer.
+      // 1. Try to copy into an existing buffer.
       if (old_lhs.IsInitialized() &&
           old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
         // The existing lhs tensor has already been initialized and the right
@@ -95,15 +90,26 @@ class AssignOp : public OpKernel {
           reshaped_old_lhs = old_lhs;
         } else {
           CHECK(reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()));
-          context->replace_ref_input(0, reshaped_old_lhs, /* lock_held */ true);
+          context->replace_ref_input(0, reshaped_old_lhs,
+                                     /* lock_held */ true);
         }
         if (use_exclusive_lock_) {
           Copy(context, &reshaped_old_lhs, rhs);
           return;
         }
       } else {
-        // Create a new persistent tensor whose shape matches the right hand
-        // side, hand off to lhs and copy the rhs into it.
+        // 2. Try to reuse the rhs.
+        std::unique_ptr<Tensor> input_alias = context->forward_input(
+            1, OpKernelContext::Params::kNoReservation /*output_index*/,
+            rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr);
+        if (input_alias != nullptr) {
+          // Update the ref to point to the new buffer.
+          context->replace_ref_input(0, *input_alias, /* lock_held */ true);
+          return;
+        }
+
+        // Otherwise, create a new persistent tensor whose shape matches the
+        // right hand side, hand off to lhs and copy the rhs into it.
         PersistentTensor copy;
         Tensor* copyTensor = nullptr;
         OP_REQUIRES_OK(
@@ -132,6 +138,7 @@ class AssignOp : public OpKernel {
 
   bool use_exclusive_lock_;
   bool validate_shape_;
+  bool relax_constraints_;
 };
 
 }  // end namespace tensorflow
index 916869f..a8bcc7f 100644 (file)
@@ -211,6 +211,11 @@ class AssignVariableOp : public OpKernel {
  public:
   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+    if (!c->GetAttr("_grappler_relax_allocator_constraints",
+                    &relax_constraints_)
+             .ok()) {
+      relax_constraints_ = false;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -228,8 +233,10 @@ class AssignVariableOp : public OpKernel {
               PersistentTensor unused;
               Tensor* tmp;
               AllocatorAttributes attr;
-              attr.set_gpu_compatible(true);
-              attr.set_nic_compatible(true);
+              if (!relax_constraints_) {
+                attr.set_gpu_compatible(true);
+                attr.set_nic_compatible(true);
+              }
               TF_RETURN_IF_ERROR(context->allocate_persistent(
                   dtype_, context->input(1).shape(), &unused, &tmp, attr));
               *(*ptr)->tensor() = *tmp;
@@ -245,8 +252,10 @@ class AssignVariableOp : public OpKernel {
 
     const Tensor& value = context->input(1);
     AllocatorAttributes attr;
-    attr.set_gpu_compatible(true);
-    attr.set_nic_compatible(true);
+    if (!relax_constraints_) {
+      attr.set_gpu_compatible(true);
+      attr.set_nic_compatible(true);
+    }
 
     // Copying is unnecessary if we are the last user of the value
     // tensor, we can just adopt the input tensor's buffer instead.
@@ -277,6 +286,7 @@ class AssignVariableOp : public OpKernel {
 
  private:
   DataType dtype_;
+  bool relax_constraints_;
 };
 
 template <typename Device>