void Initialize(const Graph* g);
Status SetAllocAttrs(const Graph* g, const Device* device);
+ void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
NodeItem* node(size_t id) const {
DCHECK_GE(id, 0);
DCHECK_EQ(item->input_type(i), n->input_type(i));
}
- uint8* output_types = item->output_type_base();
- for (int i = 0; i < num_outputs; i++) {
- output_types[i] = static_cast<uint8>(n->output_type(i));
- DCHECK_EQ(item->output_type(i), n->output_type(i));
+ // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
+ {
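+ // "_forward_input" holds (input_index, output_index) pairs and
+ // "_scoped_allocator" holds (output_index, scope_id) pairs; either attr
+ // may be absent, in which case the corresponding Status is non-OK.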
+ std::vector<int> forward_input;
+ Status fwd_status =
+ GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
+ std::vector<int> scoped_allocator_attrs;
+ Status sa_status =
+ GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
+
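+ // forward_from[i] follows OpKernelContext::Params conventions:
+ // kNoReservation means no forwarding constraint on output i, kNeverForward
+ // means output i must be allocated (here, from a ScopedAllocator) rather
+ // than forwarded, and a non-negative value names the input to forward from.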
+ int* forward_from = item->forward_from_base();
+ uint8* output_types = item->output_type_base();
+ for (int i = 0; i < num_outputs; ++i) {
+ output_types[i] = static_cast<uint8>(n->output_type(i));
+ DCHECK_EQ(item->output_type(i), n->output_type(i));
+
+ forward_from[i] = OpKernelContext::Params::kNoReservation;
+ if (sa_status.ok()) {
+ for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
+ if (scoped_allocator_attrs[j] == i) {
+ // This output slot must be explicitly allocated from a
+ // ScopedAllocator.
+ forward_from[i] = OpKernelContext::Params::kNeverForward;
+ DCHECK_EQ(output_attrs[i].scope_id, 0);
+ output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
+ }
+ }
+ }
+ if (fwd_status.ok() && forward_from[i] == OpKernelContext::Params::kNoReservation) {
+ DCHECK_EQ(forward_input.size() % 2, 0);
+ for (int j = 0; j < forward_input.size(); j += 2) {
+ if (forward_input[j + 1] == i) {
+ DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
+ forward_from[i] = forward_input[j];
+ break;
+ }
+ }
+ }
+ }
}
+
return ptr;
}
return gview_.SetAllocAttrs(graph_.get(), params_.device);
}
+// If a Node has been marked to use a ScopedAllocator x for output i, then
+// sc_attr will contain the subsequence (i, x) at an even offset. This function
+// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
+// only allow one ScopedAllocator use per Node.
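+// For example, sc_attr == {0, 51} records that output 0 should be allocated
+// from the ScopedAllocator whose scope_id is 51 (an illustrative value).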
+bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
+ int output_index,
+ AllocatorAttributes* alloc_attr) {
+ DCHECK_LE(2, sc_attr.size());
+ for (int i = 0; i < sc_attr.size(); i += 2) {
+ if (sc_attr[i] == output_index) {
+ CHECK_EQ(alloc_attr->scope_id, 0);
+ alloc_attr->scope_id = sc_attr[i + 1];
+ return true;
+ }
+ }
+ return false;
+}
+
+void GraphView::SetScopedAllocatorAttrs(
+ const std::vector<const Node*>& sa_nodes) {
+ for (const Node* sa : sa_nodes) {
+ NodeItem* sa_item = node(sa->id());
+ AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
+ // The destinations of control edges out of the ScopedAllocator should be
+ // its use instances, though a few other nodes may also be included.
+ for (const auto& e : sa->out_edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+ Node* use_node = e->dst();
+ NodeItem* item = node(use_node->id());
+ AllocatorAttributes* use_attrs = item->output_attr_base();
+ std::vector<int> scoped_allocator_attrs;
+ Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
+ &scoped_allocator_attrs);
+ if (!s.ok()) {
+ VLOG(2) << "Failed to find expected ScopedAllocator attr on "
+ << use_node->name();
+ continue;
+ }
+ // There should be exactly one output using ScopedAllocation.
+ for (const auto& use_edge : use_node->out_edges()) {
+ if (!use_edge->IsControlEdge()) {
+ AllocatorAttributes attr;
+ if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
+ use_edge->src_output(), &attr)) {
+ // Set the scope_id on this use instance node.
+ (use_attrs + use_edge->src_output())->Merge(attr);
+ // Propagate the other attributes of this node back to the SA node.
+ attr = *(use_attrs + use_edge->src_output());
+ attr.scope_id = 0;
+ sa_attrs->Merge(attr);
+ }
+ }
+ }
+ }
+ }
+}
+
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
Status s;
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
+ std::vector<const Node*> scoped_allocator_instances;
for (const Node* n : g->nodes()) {
NodeItem* item = node(n->id());
AllocatorAttributes* attrs = item->output_attr_base();
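+ // Collect ScopedAllocator nodes; their output attrs are completed by
+ // SetScopedAllocatorAttrs() below, once per-edge attrs have been inferred.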
+ if (IsScopedAllocator(n)) {
+ scoped_allocator_instances.push_back(n);
+ }
// Examine the out edges of each node looking for special use
// cases that may affect memory allocation attributes.
- for (auto e : n->out_edges()) {
+ for (const auto& e : n->out_edges()) {
if (!e->IsControlEdge()) {
AllocatorAttributes attr;
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
if (!s.ok()) return s;
- if (attr.value != 0) {
+ if (attr.value != 0 || attr.scope_id != 0) {
attrs[e->src_output()].Merge(attr);
}
}
}
}
}
+ SetScopedAllocatorAttrs(scoped_allocator_instances);
return s;
}
params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
params.is_input_dead = is_input_dead;
params.output_attr_array = item.output_attrs();
- params.forward_from_array = nullptr; // later: item.forward_from();
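+ // Expose the per-output forwarding reservations computed in
+ // GraphView::InitializeNode to the kernel via OpKernelContext::Params.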
+ params.forward_from_array = item.forward_from();
if (item.kernel_is_async) {
// Asynchronous computes.
// between output O of layer A and input I of layer B using
// "input index" and "output index" labels per edge.
-#ifndef TENSORFLOW_GRAPH_GRAPH_H_
-#define TENSORFLOW_GRAPH_GRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_H_
#include <functional>
#include <string>
}
bool IsHostSend() const { return class_ == NC_HOST_SEND; }
bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
+ bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
bool IsMetadata() const { return class_ == NC_METADATA; }
NC_GET_SESSION_TENSOR,
NC_DELETE_SESSION_TENSOR,
NC_METADATA,
+ NC_SCOPED_ALLOCATOR,
NC_OTHER // Not a special kind of node
};
// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
+inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
+
inline bool IsHostMemoryPreserving(const Node* node) {
return IsIdentity(node) || IsControlFlow(node);
}
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_