#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
+namespace subgraph {
// ----------------------------------------------------------------------------
// Subgraph construction-related routines
namespace {
+typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
+
// Rewrite graph by replacing the output tensors specified in
// "fed_outputs" with special feed nodes for each specified output
// tensor, and removing any nodes that are now disconnected from the
// Return true on success. On error, return false and sets *error to
// an appropriate error message (and *g is left in an indeterminate
// state).
-static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
- const gtl::ArraySlice<string>& fed_outputs,
- bool use_function_convention,
- subgraph::NameIndex* name_index,
- DataTypeVector* out_feed_types) {
+Status FeedInputs(
+ Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+ NameIndex* name_index, DataTypeVector* out_feed_types) {
out_feed_types->clear();
- out_feed_types->reserve(fed_outputs.size());
- for (size_t i = 0; i < fed_outputs.size(); ++i) {
- const string& t = fed_outputs[i];
+ out_feed_types->reserve(feed_rewrites.size());
+ for (size_t i = 0; i < feed_rewrites.size(); ++i) {
+ const string& t = feed_rewrites[i]->endpoint_name();
TensorId id(ParseTensorName(t));
auto iter = name_index->find(id.first);
if (iter == name_index->end()) {
return errors::NotFound("FeedInputs: unable to find feed output ", t);
}
- const Node* n = iter->second;
+ Node* n = iter->second;
DCHECK_EQ(n->name(), id.first);
if (id.second >= n->num_outputs()) {
return errors::InvalidArgument(
"FeedInputs: ", t, " should have output index < ", n->num_outputs());
}
- Node* recv_node;
-
- if (!use_function_convention) {
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
- "_Recv")
- .Attr("tensor_type", BaseType(n->output_type(id.second)))
- .Attr("tensor_name", t)
- .Attr("send_device", device_info.name())
- .Attr("recv_device", device_info.name())
- .Attr("send_device_incarnation",
- static_cast<int64>(device_info.incarnation()))
- .Attr("client_terminated", true)
- .Finalize(g, &recv_node));
- } else {
- // NOTE(mrry): We must include the index as part of the node
- // name, because _Arg is a "stateful" kernel and therefore
- // its name must uniquely identify a kernel instance across all
- // graphs in the same session.
- TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
- id.second, "_", i),
- "_Arg")
- .Attr("T", BaseType(n->output_type(id.second)))
- .Attr("index", static_cast<int32>(i))
- .Finalize(g, &recv_node));
- }
- recv_node->set_assigned_device_name(device_info.name());
+ Node* feed_node;
+ TF_RETURN_IF_ERROR(
+ feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));
// Update name_index
- (*name_index)[recv_node->name()] = recv_node;
- g->AddControlEdge(g->source_node(), recv_node);
+ (*name_index)[feed_node->name()] = feed_node;
+ g->AddControlEdge(g->source_node(), feed_node);
// Look through edges coming out of "n" for edges whose src_output() index
// matches "output_index". If found, replace the edges with a connection
n->type_string() == "PlaceholderV2")) {
// When feeding a Placeholder node, any outgoing control edges
// will be replaced with a control edge from the replacement
- // recv_node.
+ // feed_node.
// TODO(josh11b,mrry): Come up with a more elegant way of addressing
// the general version of this problem.
to_remove.emplace_back(e);
for (const Edge* e : to_remove) {
if (e->src_output() == id.second) {
- g->AddEdge(recv_node, 0, e->dst(), e->dst_input());
+ g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
} else {
CHECK_EQ(Graph::kControlSlot, e->src_output());
- g->AddControlEdge(recv_node, e->dst());
+ g->AddControlEdge(feed_node, e->dst());
}
g->RemoveEdge(e);
}
return Status::OK();
}
-static bool AddNodeToTargets(const string& node_or_tensor_name,
- const subgraph::NameIndex& name_index,
- std::unordered_set<const Node*>* targets) {
+Status FetchOutputs(
+ Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
+ NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
+ DataTypeVector* out_fetch_types) {
+ out_fetch_nodes->clear();
+ out_fetch_nodes->reserve(fetch_rewrites.size());
+ for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
+ const string& t = fetch_rewrites[i]->endpoint_name();
+
+ // Parse t into node_name and output_index.
+ TensorId id(ParseTensorName(t));
+
+ // Find node in graph with that name.
+ auto iter = name_index->find(id.first);
+ if (iter == name_index->end()) {
+ return errors::NotFound("FetchOutputs node ", t, ": not found");
+ }
+ Node* n = iter->second;
+ DCHECK_EQ(n->name(), id.first);
+ VLOG(2) << "Found fetch node for " << t;
+
+ // Validate output_index
+ if (n->num_outputs() == 0) {
+ return errors::InvalidArgument(
+ "Tried to fetch data for '", t,
+ "', which produces no output. To run to a node but not fetch any "
+ "data, pass '",
+ t,
+ "' as an argument to the 'target_node_names' argument of the "
+ "Session::Run API.");
+ } else if (id.second >= n->num_outputs()) {
+ return errors::InvalidArgument("FetchOutputs ", t,
+ ": output index too large, must be < ",
+ n->num_outputs());
+ }
+
+ // Create the fetch Node and connect it up
+ Node* fetch_node;
+ TF_RETURN_IF_ERROR(
+ fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));
+
+ // Update the index.
+ (*name_index)[fetch_node->name()] = fetch_node;
+
+ g->AddControlEdge(fetch_node, g->sink_node());
+ out_fetch_nodes->push_back(fetch_node);
+ out_fetch_types->push_back(BaseType(n->output_type(id.second)));
+ }
+
+ return Status::OK();
+}
+
+bool AddNodeToTargets(const string& node_or_tensor_name,
+ const NameIndex& name_index,
+ std::unordered_set<const Node*>* targets) {
TensorId id = ParseTensorName(node_or_tensor_name);
auto iter = name_index.find(id.first);
if (iter == name_index.end()) {
return true;
}
-static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
- const std::vector<Node*>& fetch_nodes,
- const gtl::ArraySlice<string>& target_nodes) {
+Status PruneForTargets(Graph* g, const NameIndex& name_index,
+ const std::vector<Node*>& fetch_nodes,
+ const gtl::ArraySlice<string>& target_nodes) {
string not_found;
std::unordered_set<const Node*> targets;
for (Node* n : fetch_nodes) {
} // namespace
-namespace subgraph {
+Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+ Node** out_node) {
+ // NOTE(mrry): We must include the index as part of the node
+ // name, because _Arg is a "stateful" kernel and therefore
+ // its name must uniquely identify a kernel instance across all
+ // graphs in the same session.
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
+ feed_tensor.index, "_", arg_index_),
+ "_Arg")
+ .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
+ .Attr("index", arg_index_)
+ .Finalize(g, out_node));
+ (*out_node)->set_assigned_device_name(device_info().name());
+ return Status::OK();
+}
-Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
- const gtl::ArraySlice<string>& fetch_outputs,
- bool use_function_convention, NameIndex* name_index,
- std::vector<Node*>* out_fetch_nodes,
- DataTypeVector* out_fetch_types) {
- out_fetch_nodes->clear();
- out_fetch_nodes->reserve(fetch_outputs.size());
- for (size_t i = 0; i < fetch_outputs.size(); ++i) {
- const string& t = fetch_outputs[i];
+Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+ Node** out_node) {
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
+ feed_tensor.index),
+ "_Recv")
+ .Attr("tensor_type",
+ BaseType(feed_tensor.node->output_type(feed_tensor.index)))
+ .Attr("tensor_name", endpoint_name())
+ .Attr("send_device", device_info().name())
+ .Attr("recv_device", device_info().name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info().incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, out_node));
+
+ (*out_node)->set_assigned_device_name(device_info().name());
+ return Status::OK();
+}
- // Parse t into node_name and output_index.
- TensorId id(ParseTensorName(t));
+Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+ Node** out_node) {
+ // NOTE(mrry): We must include the index as part of the node
+ // name, because _Retval is a "stateful" kernel and therefore
+ // its name must uniquely identify a kernel instance across all
+ // graphs in the same session.
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
+ fetch_tensor.index, "_", retval_index_),
+ "_Retval")
+ .Input(fetch_tensor.node, fetch_tensor.index)
+ .Attr("T",
+ BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
+ .Attr("index", retval_index_)
+ .Finalize(g, out_node));
+ (*out_node)->set_assigned_device_name(device_info().name());
+ return Status::OK();
+}
- // Find node in graph with that name.
- auto iter = name_index->find(id.first);
- if (iter == name_index->end()) {
- return errors::NotFound("FetchOutputs node ", t, ": not found");
- }
- Node* n = iter->second;
- DCHECK_EQ(n->name(), id.first);
- VLOG(2) << "Found fetch node for " << t;
+Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+ Node** out_node) {
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
+ fetch_tensor.index),
+ "_Send")
+ .Input(fetch_tensor.node, fetch_tensor.index)
+ .Attr("tensor_name", endpoint_name())
+ .Attr("send_device", device_info().name())
+ .Attr("recv_device", device_info().name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info().incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, out_node));
+ (*out_node)->set_assigned_device_name(device_info().name());
+ return Status::OK();
+}
- // Validate output_index
- if (n->num_outputs() == 0) {
- return errors::InvalidArgument(
- "Tried to fetch data for '", t,
- "', which produces no output. To run to a node but not fetch any "
- "data, pass '",
- t,
- "' as an argument to the 'target_node_names' argument of the "
- "Session::Run API.");
- } else if (id.second >= n->num_outputs()) {
- return errors::InvalidArgument("FetchOutputs ", t,
- ": output index too large, must be < ",
- n->num_outputs());
+Status RewriteGraphForExecution(
+ Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ const gtl::ArraySlice<string>& target_node_names,
+ const DeviceAttributes& device_info, bool use_function_convention,
+ RewriteGraphMetadata* out_metadata) {
+ std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
+ feed_rewrites.reserve(fed_outputs.size());
+ if (use_function_convention) {
+ for (size_t i = 0; i < fed_outputs.size(); ++i) {
+ feed_rewrites.emplace_back(new ArgFeedRewrite(
+ &fed_outputs[i], &device_info, static_cast<int32>(i)));
}
-
- // Create the fetch Node and connect it up
- Node* send_node;
- if (!use_function_convention) {
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
- "_Send")
- .Input(n, id.second)
- .Attr("tensor_name", t)
- .Attr("send_device", device_info.name())
- .Attr("recv_device", device_info.name())
- .Attr("send_device_incarnation",
- static_cast<int64>(device_info.incarnation()))
- .Attr("client_terminated", true)
- .Finalize(g, &send_node));
- } else {
- // NOTE(mrry): We must include the index as part of the node
- // name, because _Retval is a "stateful" kernel and therefore
- // its name must uniquely identify a kernel instance across all
- // graphs in the same session.
- TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
- id.second, "_", i),
- "_Retval")
- .Input(n, id.second)
- .Attr("T", BaseType(n->output_type(id.second)))
- .Attr("index", static_cast<int32>(i))
- .Finalize(g, &send_node));
+ } else {
+ for (const string& fed_output : fed_outputs) {
+ feed_rewrites.emplace_back(
+ new RecvFeedRewrite(&fed_output, &device_info));
}
- send_node->set_assigned_device_name(device_info.name());
-
- // Update the index.
- (*name_index)[send_node->name()] = send_node;
+ }
- g->AddControlEdge(send_node, g->sink_node());
- out_fetch_nodes->push_back(send_node);
- out_fetch_types->push_back(BaseType(n->output_type(id.second)));
+ std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
+ fetch_rewrites.reserve(fetch_outputs.size());
+ if (use_function_convention) {
+ for (size_t i = 0; i < fetch_outputs.size(); ++i) {
+ fetch_rewrites.emplace_back(new RetvalFetchRewrite(
+ &fetch_outputs[i], &device_info, static_cast<int32>(i)));
+ }
+ } else {
+ for (const string& fetch_output : fetch_outputs) {
+ fetch_rewrites.emplace_back(
+ new SendFetchRewrite(&fetch_output, &device_info));
+ }
}
- return Status::OK();
+ return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
+ target_node_names, out_metadata);
+}
+
+namespace {
+template <typename StringContainer>
+std::vector<string> ConvertToVector(StringContainer field) {
+ return std::vector<string>(field.begin(), field.end());
}
+} // namespace
Status RewriteGraphForExecution(
- Graph* g, const gtl::ArraySlice<string>& fed_outputs,
- const gtl::ArraySlice<string>& fetch_outputs,
+ Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+ const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
const gtl::ArraySlice<string>& target_node_names,
- const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata) {
- if (fetch_outputs.empty() && target_node_names.empty()) {
+ if (fetch_rewrites.empty() && target_node_names.empty()) {
return errors::InvalidArgument(
"Must specify at least one target to fetch or execute.");
}
std::unordered_set<string> endpoints;
- for (const string& endpoint_name : fed_outputs) {
- auto result = endpoints.insert(endpoint_name);
+ for (const auto& feed_rewrite : feed_rewrites) {
+ auto result = endpoints.insert(feed_rewrite->endpoint_name());
if (!result.second) {
- return errors::InvalidArgument("Endpoint \"", endpoint_name,
+ return errors::InvalidArgument("Endpoint \"",
+ feed_rewrite->endpoint_name(),
"\" fed more than once.");
}
}
- for (const auto& fetch : fetch_outputs) {
- if (endpoints.count(fetch) > 0) {
- return errors::InvalidArgument(fetch, " is both fed and fetched.");
+ for (const auto& fetch_rewrite : fetch_rewrites) {
+ if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
+ return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
+ " is both fed and fetched.");
}
}
}
// Add the feeds. This may replace nodes in the graph, including the nodes
- // currently listed in "fetch_nodes". We pass "name_index" so the index is
+  // currently referenced by "fetch_rewrites". We pass "name_index" so the index is
// kept up to date.
- if (!fed_outputs.empty()) {
- TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
- use_function_convention, &name_index,
- &out_metadata->feed_types));
+ if (!feed_rewrites.empty()) {
+ TF_RETURN_IF_ERROR(
+ FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
}
// Add the fetch nodes, also updating "name_index".
std::vector<Node*> fetch_nodes;
- if (!fetch_outputs.empty()) {
- TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
- use_function_convention, &name_index,
+ if (!fetch_rewrites.empty()) {
+ TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
&fetch_nodes, &out_metadata->fetch_types));
}
return Status::OK();
}
-namespace {
-template <typename StringContainer>
-std::vector<string> ConvertToVector(StringContainer field) {
- return std::vector<string>(field.begin(), field.end());
-}
-} // namespace
-
-Status RewriteGraphForExecution(Graph* g,
- const CallableOptions& callable_options,
- const DeviceAttributes& device_info,
- bool use_function_convention,
- RewriteGraphMetadata* out_metadata) {
- return RewriteGraphForExecution(g, ConvertToVector(callable_options.feed()),
- ConvertToVector(callable_options.fetch()),
- ConvertToVector(callable_options.target()),
- device_info, use_function_convention,
- out_metadata);
-}
-
} // namespace subgraph
} // namespace tensorflow
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/protobuf/config.pb.h"
DataTypeVector fetch_types;
};
+// Describes the action to take on a particular tensor endpoint (identified by
+// a "<node_name>:<output_index>" pair) when pruning the graph.
+//
+// The `AddNode()` method must be overridden to describe this action. The method
+// will be invoked once during `RewriteGraphForExecution()` with the tensor
+// endpoint named by `endpoint_name`, and it may either create a single new
+// node, or fail with an error if the resulting graph would be invalid.
+class PruneRewrite {
+ public:
+ // `endpoint_name` and `device_info` must outlive this object.
+ PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info)
+ : endpoint_name_(endpoint_name), device_info_(device_info) {}
+ virtual ~PruneRewrite() {}
+
+ // Creates a new node whose output replaces the given `tensor` in graph `g`.
+ // The node will be assigned to the device named in `device_info`.
+ virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor,
+ Node** out_node) = 0;
+
+ // Returns the name of the tensor to which this rewrite applies.
+ const string& endpoint_name() { return *endpoint_name_; }
+
+ protected:
+ // The device on which the new node will be created.
+ const DeviceAttributes& device_info() { return *device_info_; }
+
+ private:
+ const string* const endpoint_name_; // Not owned.
+ const DeviceAttributes* const device_info_; // Not owned.
+};
+
// Rewrite the graph structure of "*g" to deal with feeding node
// outputs, fetching node outputs, and only running a subset of the
// graph. "fed_outputs" and "fetch_outputs" are both lists of
// In the resulting graph "*g", output edges in "fed_outputs" have
// been redirected to special "_recv" nodes introduced into the graph.
// If these fed nodes are not needed in order to compute the effects
-// of the nodes in "targets_nodes" and "fetch_outputs", then these may
+// of the nodes in "target_node_names" and "fetch_outputs", then these may
// be omitted from the graph.
//
// In the resulting graph "*g", additional "_send" nodes are connected
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata);
-Status RewriteGraphForExecution(Graph* g,
- const CallableOptions& callable_options,
- const DeviceAttributes& device_info,
- bool use_function_convention,
- RewriteGraphMetadata* out_metadata);
-
-typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
-
-// Augment "*g" by adding special "fetch" nodes that connect to the
-// tensor outputs specified in "fetch_outputs" to retrieve the output
-// of the tensors. The new nodes added are set up to execute on
-// "client_device_name", and are returned in "*fetch_nodes".
-//
-// Return OK on success. On error, return false and sets *error to
-// an appropriate error message (and *g is left in an indeterminate
-// state).
-Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
- const gtl::ArraySlice<string>& fetch_outputs,
- NameIndex* name_index, std::vector<Node*>* fetch_nodes);
+
+// A more general version of the above function that supports
+// customizable rewriting actions for each fed and fetched tensor.
+Status RewriteGraphForExecution(
+ Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+ const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
+ const gtl::ArraySlice<string>& target_node_names,
+ RewriteGraphMetadata* out_metadata);
+
+/////////////////////////////////////////////////////////
+// Custom rewrite actions for fed and fetched tensors. //
+/////////////////////////////////////////////////////////
+
+// A rewrite action that adds an _Arg node for a fed tensor.
+class ArgFeedRewrite : public PruneRewrite {
+ public:
+ ArgFeedRewrite(const string* endpoint_name,
+ const DeviceAttributes* device_info, int32 arg_index)
+ : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {}
+ Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+ Node** out_node) override;
+
+ private:
+ const int32 arg_index_;
+};
+
+// A rewrite action that adds a client-terminated _Recv node for a fed tensor.
+class RecvFeedRewrite : public PruneRewrite {
+ public:
+ using PruneRewrite::PruneRewrite;
+ Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+ Node** out_node) override;
+};
+
+// A rewrite action that adds a _Retval node for a fetched tensor.
+class RetvalFetchRewrite : public PruneRewrite {
+ public:
+ RetvalFetchRewrite(const string* endpoint_name,
+ const DeviceAttributes* device_info, int32 retval_index)
+ : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {}
+ Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+ Node** out_node) override;
+
+ private:
+ const int32 retval_index_;
+};
+
+// A rewrite action that adds a client-terminated _Send node for a
+// fetched tensor.
+class SendFetchRewrite : public PruneRewrite {
+ public:
+ using PruneRewrite::PruneRewrite;
+ Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+ Node** out_node) override;
+};
} // namespace subgraph
} // namespace tensorflow