Prepare nodes that will be allocated using ScopedAllocator.
authorAyush Dubey <ayushd@google.com>
Mon, 30 Apr 2018 17:36:00 +0000 (10:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 17:39:01 +0000 (10:39 -0700)
This includes changes to Executor that (1) set scope_id on nodes that are
decorated with _scoped_allocator attribute, (2) mark such nodes to never
forward input.

PiperOrigin-RevId: 194807086

tensorflow/core/common_runtime/executor.cc
tensorflow/core/graph/graph.cc
tensorflow/core/graph/graph.h

index 0c461a9..e389eb9 100644 (file)
@@ -322,6 +322,7 @@ class GraphView {
 
   void Initialize(const Graph* g);
   Status SetAllocAttrs(const Graph* g, const Device* device);
+  void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
 
   NodeItem* node(size_t id) const {
     DCHECK_GE(id, 0);
@@ -566,11 +567,46 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
     DCHECK_EQ(item->input_type(i), n->input_type(i));
   }
 
-  uint8* output_types = item->output_type_base();
-  for (int i = 0; i < num_outputs; i++) {
-    output_types[i] = static_cast<uint8>(n->output_type(i));
-    DCHECK_EQ(item->output_type(i), n->output_type(i));
+  // Check ScopedAllocatorAttrs and forward_from.  Also assign output_types.
+  {
+    std::vector<int> forward_input;
+    Status fwd_status =
+        GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
+    std::vector<int> scoped_allocator_attrs;
+    Status sa_status =
+        GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
+
+    int* forward_from = item->forward_from_base();
+    uint8* output_types = item->output_type_base();
+    for (int i = 0; i < num_outputs; ++i) {
+      output_types[i] = static_cast<uint8>(n->output_type(i));
+      DCHECK_EQ(item->output_type(i), n->output_type(i));
+
+      forward_from[i] = OpKernelContext::Params::kNoReservation;
+      if (sa_status.ok()) {
+        for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
+          if (scoped_allocator_attrs[j] == i) {
+            // This output slot must be explicitly allocated from a
+            // ScopedAllocator.
+            forward_from[i] = OpKernelContext::Params::kNeverForward;
+            DCHECK_EQ(output_attrs[i].scope_id, 0);
+            output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
+          }
+        }
+      }
+      if (fwd_status.ok() && forward_from[i] == -1) {
+        DCHECK_EQ(forward_input.size() % 2, 0);
+        for (int j = 0; j < forward_input.size(); j += 2) {
+          if (forward_input[j + 1] == i) {
+            DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
+            forward_from[i] = forward_input[j];
+            break;
+          }
+        }
+      }
+    }
   }
+
   return ptr;
 }
 
@@ -696,22 +732,85 @@ Status ExecutorImpl::Initialize() {
   return gview_.SetAllocAttrs(graph_.get(), params_.device);
 }
 
+// If a Node has been marked to use a ScopedAllocator x for output i, then
+// sc_attr will contain the subsequence (i, x) at an even offset.  This function
+// extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
+// only allow one ScopedAllocator use per Node.
+bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
+                                int output_index,
+                                AllocatorAttributes* alloc_attr) {
+  DCHECK_LE(2, sc_attr.size());
+  for (int i = 0; i < sc_attr.size(); i += 2) {
+    if (sc_attr[i] == output_index) {
+      CHECK_EQ(alloc_attr->scope_id, 0);
+      alloc_attr->scope_id = sc_attr[i + 1];
+      return true;
+    }
+  }
+  return false;
+}
+
+void GraphView::SetScopedAllocatorAttrs(
+    const std::vector<const Node*>& sa_nodes) {
+  for (const Node* sa : sa_nodes) {
+    NodeItem* sa_item = node(sa->id());
+    AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
+    // Control edges out of the ScopedAllocator should be use instances, but may
+    // include a few other nodes.
+    for (const auto& e : sa->out_edges()) {
+      if (!e->IsControlEdge()) {
+        continue;
+      }
+      Node* use_node = e->dst();
+      NodeItem* item = node(use_node->id());
+      AllocatorAttributes* use_attrs = item->output_attr_base();
+      std::vector<int> scoped_allocator_attrs;
+      Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
+                             &scoped_allocator_attrs);
+      if (!s.ok()) {
+        VLOG(2) << "Failed to find expected ScopedAllocator attr on "
+                << use_node->name();
+        continue;
+      }
+      // There should be exactly one output using ScopedAllocation.
+      for (const auto& e : use_node->out_edges()) {
+        if (!e->IsControlEdge()) {
+          AllocatorAttributes attr;
+          if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
+                                         e->src_output(), &attr)) {
+            // Set the scope_id on this use instance node.
+            (use_attrs + e->src_output())->Merge(attr);
+            // Propagate the other attributes of this node back to the SA node.
+            attr = *(use_attrs + e->src_output());
+            attr.scope_id = 0;
+            sa_attrs->Merge(attr);
+          }
+        }
+      }
+    }
+  }
+}
+
 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
   Status s;
   DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
 
+  std::vector<const Node*> scoped_allocator_instances;
   for (const Node* n : g->nodes()) {
     NodeItem* item = node(n->id());
     AllocatorAttributes* attrs = item->output_attr_base();
+    if (IsScopedAllocator(n)) {
+      scoped_allocator_instances.push_back(n);
+    }
 
     // Examine the out edges of each node looking for special use
     // cases that may affect memory allocation attributes.
-    for (auto e : n->out_edges()) {
+    for (const auto& e : n->out_edges()) {
       if (!e->IsControlEdge()) {
         AllocatorAttributes attr;
         s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
         if (!s.ok()) return s;
-        if (attr.value != 0) {
+        if (attr.value != 0 || attr.scope_id != 0) {
           attrs[e->src_output()].Merge(attr);
         }
       }
@@ -728,6 +827,7 @@ Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
       }
     }
   }
+  SetScopedAllocatorAttrs(scoped_allocator_instances);
   return s;
 }
 
@@ -1614,7 +1714,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
       params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
       params.is_input_dead = is_input_dead;
       params.output_attr_array = item.output_attrs();
-      params.forward_from_array = nullptr;  // later: item.forward_from();
+      params.forward_from_array = item.forward_from();
 
       if (item.kernel_is_async) {
         // Asynchronous computes.
index fb8a6c3..eeb6c60 100644 (file)
@@ -79,6 +79,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"Size", NC_METADATA},
         {"Shape", NC_METADATA},
         {"Rank", NC_METADATA},
+        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
     });
 
 #undef REF_CLASS
index f7ca7d0..83a69e6 100644 (file)
@@ -34,8 +34,8 @@ limitations under the License.
 // between output O of layer A and input I of layer B using
 // "input index" and "output index" labels per edge.
 
-#ifndef TENSORFLOW_GRAPH_GRAPH_H_
-#define TENSORFLOW_GRAPH_GRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_H_
 
 #include <functional>
 #include <string>
@@ -162,6 +162,7 @@ class Node {
   }
   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
+  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
 
   bool IsMetadata() const { return class_ == NC_METADATA; }
 
@@ -233,6 +234,7 @@ class Node {
     NC_GET_SESSION_TENSOR,
     NC_DELETE_SESSION_TENSOR,
     NC_METADATA,
+    NC_SCOPED_ALLOCATOR,
     NC_OTHER  // Not a special kind of node
   };
 
@@ -696,6 +698,8 @@ inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
 // (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
 
+inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
+
 inline bool IsHostMemoryPreserving(const Node* node) {
   return IsIdentity(node) || IsControlFlow(node);
 }
@@ -827,4 +831,4 @@ inline const string& Node::assigned_device_name() const {
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_GRAPH_GRAPH_H_
+#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_