From 09e529ff5adb916e40481563698dee72e8a15162 Mon Sep 17 00:00:00 2001
From: Ayush Dubey
Date: Mon, 30 Apr 2018 10:36:00 -0700
Subject: [PATCH] Prepare nodes that will be allocated using ScopedAllocator.

This includes changes to Executor that (1) set scope_id on nodes that are
decorated with _scoped_allocator attribute, (2) mark such nodes to never
forward input.

PiperOrigin-RevId: 194807086
---
 tensorflow/core/common_runtime/executor.cc | 114 +++++++++++++++++++++++++++--
 tensorflow/core/graph/graph.cc             |   1 +
 tensorflow/core/graph/graph.h              |  10 ++-
 3 files changed, 115 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 0c461a9..e389eb9 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -322,6 +322,7 @@ class GraphView {
 
   void Initialize(const Graph* g);
   Status SetAllocAttrs(const Graph* g, const Device* device);
+  void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
 
   NodeItem* node(size_t id) const {
     DCHECK_GE(id, 0);
@@ -566,11 +567,46 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
     DCHECK_EQ(item->input_type(i), n->input_type(i));
   }
 
-  uint8* output_types = item->output_type_base();
-  for (int i = 0; i < num_outputs; i++) {
-    output_types[i] = static_cast<uint8>(n->output_type(i));
-    DCHECK_EQ(item->output_type(i), n->output_type(i));
+  // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
+  {
+    std::vector<int32> forward_input;
+    Status fwd_status =
+        GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
+    std::vector<int32> scoped_allocator_attrs;
+    Status sa_status =
+        GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
+
+    int* forward_from = item->forward_from_base();
+    uint8* output_types = item->output_type_base();
+    for (int i = 0; i < num_outputs; ++i) {
+      output_types[i] = static_cast<uint8>(n->output_type(i));
+      DCHECK_EQ(item->output_type(i), n->output_type(i));
+
+      forward_from[i] = OpKernelContext::Params::kNoReservation;
+      if (sa_status.ok()) {
+        for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
+          if (scoped_allocator_attrs[j] == i) {
+            // This output slot must be explicitly allocated from a
+            // ScopedAllocator.
+            forward_from[i] = OpKernelContext::Params::kNeverForward;
+            DCHECK_EQ(output_attrs[i].scope_id, 0);
+            output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
+          }
+        }
+      }
+      if (fwd_status.ok() && forward_from[i] == -1) {
+        DCHECK_EQ(forward_input.size() % 2, 0);
+        for (int j = 0; j < forward_input.size(); j += 2) {
+          if (forward_input[j + 1] == i) {
+            DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
+            forward_from[i] = forward_input[j];
+            break;
+          }
+        }
+      }
+    }
   }
+
   return ptr;
 }
@@ -696,22 +732,85 @@ Status ExecutorImpl::Initialize() {
   return gview_.SetAllocAttrs(graph_.get(), params_.device);
 }
 
+// If a Node has been marked to use a ScopedAllocator x for output i, then
+// sc_attr will contain the subsequence (i, x) at an even offset. This function
+// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
+// only allow one ScopedAllocator use per Node.
+bool ExtractScopedAllocatorAttr(const std::vector<int32>& sc_attr,
+                                int output_index,
+                                AllocatorAttributes* alloc_attr) {
+  DCHECK_LE(2, sc_attr.size());
+  for (int i = 0; i < sc_attr.size(); i += 2) {
+    if (sc_attr[i] == output_index) {
+      CHECK_EQ(alloc_attr->scope_id, 0);
+      alloc_attr->scope_id = sc_attr[i + 1];
+      return true;
+    }
+  }
+  return false;
+}
+
+void GraphView::SetScopedAllocatorAttrs(
+    const std::vector<const Node*>& sa_nodes) {
+  for (const Node* sa : sa_nodes) {
+    NodeItem* sa_item = node(sa->id());
+    AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
+    // Control edges out of the ScopedAllocator should be use instances, but may
+    // include a few other nodes.
+    for (const auto& e : sa->out_edges()) {
+      if (!e->IsControlEdge()) {
+        continue;
+      }
+      Node* use_node = e->dst();
+      NodeItem* item = node(use_node->id());
+      AllocatorAttributes* use_attrs = item->output_attr_base();
+      std::vector<int32> scoped_allocator_attrs;
+      Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
+                             &scoped_allocator_attrs);
+      if (!s.ok()) {
+        VLOG(2) << "Failed to find expected ScopedAllocator attr on "
+                << use_node->name();
+        continue;
+      }
+      // There should be exactly one output using ScopedAllocation.
+      for (const auto& e : use_node->out_edges()) {
+        if (!e->IsControlEdge()) {
+          AllocatorAttributes attr;
+          if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
+                                         e->src_output(), &attr)) {
+            // Set the scope_id on this use instance node.
+            (use_attrs + e->src_output())->Merge(attr);
+            // Propagate the other attributes of this node back to the SA node.
+            attr = *(use_attrs + e->src_output());
+            attr.scope_id = 0;
+            sa_attrs->Merge(attr);
+          }
+        }
+      }
+    }
+  }
+}
+
 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
   Status s;
   DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
 
+  std::vector<const Node*> scoped_allocator_instances;
   for (const Node* n : g->nodes()) {
     NodeItem* item = node(n->id());
     AllocatorAttributes* attrs = item->output_attr_base();
+    if (IsScopedAllocator(n)) {
+      scoped_allocator_instances.push_back(n);
+    }
 
     // Examine the out edges of each node looking for special use
     // cases that may affect memory allocation attributes.
-    for (auto e : n->out_edges()) {
+    for (const auto& e : n->out_edges()) {
       if (!e->IsControlEdge()) {
         AllocatorAttributes attr;
         s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
         if (!s.ok()) return s;
-        if (attr.value != 0) {
+        if (attr.value != 0 || attr.scope_id != 0) {
           attrs[e->src_output()].Merge(attr);
         }
       }
@@ -728,6 +827,7 @@ Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
       }
     }
   }
+  SetScopedAllocatorAttrs(scoped_allocator_instances);
   return s;
 }
 
@@ -1614,7 +1714,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
     params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
     params.is_input_dead = is_input_dead;
     params.output_attr_array = item.output_attrs();
-    params.forward_from_array = nullptr;  // later: item.forward_from();
+    params.forward_from_array = item.forward_from();
 
     if (item.kernel_is_async) {
       // Asynchronous computes.
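For reference, the executor changes above read the "_scoped_allocator" node attribute as a flat list of (output_index, scope_id) pairs. The following is a minimal standalone sketch, not part of the patch, showing how that pair layout can be decoded; the helper name, the int32 alias, and the example values are illustrative assumptions rather than TensorFlow code.

// Standalone sketch: decodes a "_scoped_allocator" attribute laid out as
// flat (output_index, scope_id) pairs, the same layout walked by
// ExtractScopedAllocatorAttr in the diff above.
#include <cstdint>
#include <iostream>
#include <vector>

using int32 = std::int32_t;

// Returns the scope_id recorded for `output_index`, or 0 if that output has
// no ScopedAllocator reservation.
int32 ScopeIdForOutput(const std::vector<int32>& sc_attr, int output_index) {
  for (std::size_t i = 0; i + 1 < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) return sc_attr[i + 1];
  }
  return 0;
}

int main() {
  // Example: output 1 of some node is backed by ScopedAllocator slot 42.
  const std::vector<int32> sc_attr = {1, 42};
  std::cout << ScopeIdForOutput(sc_attr, 0) << "\n";  // prints 0
  std::cout << ScopeIdForOutput(sc_attr, 1) << "\n";  // prints 42
  return 0;
}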
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index fb8a6c3..eeb6c60 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -79,6 +79,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"Size", NC_METADATA},
         {"Shape", NC_METADATA},
         {"Rank", NC_METADATA},
+        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
     });
 
 #undef REF_CLASS
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index f7ca7d0..83a69e6 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -34,8 +34,8 @@ limitations under the License.
 // between output O of layer A and input I of layer B using
 // "input index" and "output index" labels per edge.
 
-#ifndef TENSORFLOW_GRAPH_GRAPH_H_
-#define TENSORFLOW_GRAPH_GRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_H_
 
 #include <functional>
 #include <memory>
@@ -162,6 +162,7 @@ class Node {
   }
   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
+  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
 
   bool IsMetadata() const { return class_ == NC_METADATA; }
 
@@ -233,6 +234,7 @@ class Node {
     NC_GET_SESSION_TENSOR,
     NC_DELETE_SESSION_TENSOR,
     NC_METADATA,
+    NC_SCOPED_ALLOCATOR,
     NC_OTHER  // Not a special kind of node
   };
 
@@ -696,6 +698,8 @@ inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
 
+inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
+
 inline bool IsHostMemoryPreserving(const Node* node) {
   return IsIdentity(node) || IsControlFlow(node);
 }
@@ -827,4 +831,4 @@ inline const string& Node::assigned_device_name() const {
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_GRAPH_GRAPH_H_
+#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_
-- 
2.7.4
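The other attribute consumed in GraphView::InitializeNode, "_forward_input", is likewise a flat list of (input_index, output_index) pairs, and any output listed in "_scoped_allocator" is marked so the kernel never forwards an input into it. Below is a minimal standalone sketch of that bookkeeping; the helper name is invented, and the sentinel values mirror what OpKernelContext::Params appears to use (kNoReservation = -1, kNeverForward = -2), which should be treated as an assumption rather than a statement about the TensorFlow API.

// Standalone sketch: fills a per-output forward_from array from a
// "_forward_input" pair list, giving precedence to ScopedAllocator-backed
// outputs, in the spirit of the loop added to GraphView::InitializeNode.
#include <cstdint>
#include <iostream>
#include <vector>

using int32 = std::int32_t;

constexpr int kNoReservation = -1;  // no decision yet; forwarding is allowed
constexpr int kNeverForward = -2;   // output must come from a ScopedAllocator

std::vector<int> BuildForwardFrom(int num_outputs,
                                  const std::vector<int32>& forward_input,
                                  const std::vector<int32>& scoped_outputs) {
  std::vector<int> forward_from(num_outputs, kNoReservation);
  // Outputs backed by a ScopedAllocator are never forwarded from an input.
  for (int32 out : scoped_outputs) forward_from[out] = kNeverForward;
  // Each (input_index, output_index) pair requests forwarding for one output.
  for (std::size_t j = 0; j + 1 < forward_input.size(); j += 2) {
    const int32 in = forward_input[j];
    const int32 out = forward_input[j + 1];
    if (forward_from[out] == kNoReservation) forward_from[out] = in;
  }
  return forward_from;
}

int main() {
  // Node with 3 outputs: output 0 forwards from input 2, output 1 is
  // ScopedAllocator-backed, output 2 is left unreserved.
  const auto ff = BuildForwardFrom(3, /*forward_input=*/{2, 0},
                                   /*scoped_outputs=*/{1});
  for (int v : ff) std::cout << v << " ";  // prints: 2 -2 -1
  std::cout << "\n";
  return 0;
}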