Refactor pruning code to support custom node rewrites for feeds and fetches.
authorDerek Murray <mrry@google.com>
Wed, 21 Mar 2018 15:45:19 +0000 (08:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 15:48:04 +0000 (08:48 -0700)
PiperOrigin-RevId: 189913309

tensorflow/core/common_runtime/graph_execution_state.cc
tensorflow/core/common_runtime/graph_execution_state.h
tensorflow/core/graph/subgraph.cc
tensorflow/core/graph/subgraph.h

index f5e3d78..2f17af2 100644 (file)
@@ -237,6 +237,42 @@ void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
   }
 }
 
+Status GraphExecutionState::PruneGraph(
+    const BuildGraphOptions& options, Graph* graph,
+    subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
+  std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
+  feed_rewrites.reserve(options.callable_options.feed_size());
+  std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
+  fetch_rewrites.reserve(options.callable_options.fetch_size());
+  const DeviceAttributes* device_info =
+      &device_set_->client_device()->attributes();
+  if (options.use_function_convention) {
+    for (int i = 0; i < options.callable_options.feed_size(); ++i) {
+      feed_rewrites.emplace_back(new subgraph::ArgFeedRewrite(
+          &options.callable_options.feed(i), device_info, i));
+    }
+    for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
+      fetch_rewrites.emplace_back(new subgraph::RetvalFetchRewrite(
+          &options.callable_options.fetch(i), device_info, i));
+    }
+  } else {
+    for (const string& feed : options.callable_options.feed()) {
+      feed_rewrites.emplace_back(
+          new subgraph::RecvFeedRewrite(&feed, device_info));
+    }
+    for (const string& fetch : options.callable_options.fetch()) {
+      fetch_rewrites.emplace_back(
+          new subgraph::SendFetchRewrite(&fetch, device_info));
+    }
+  }
+  std::vector<string> target_node_names(
+      options.callable_options.target().begin(),
+      options.callable_options.target().end());
+  return subgraph::RewriteGraphForExecution(graph, feed_rewrites,
+                                            fetch_rewrites, target_node_names,
+                                            out_rewrite_metadata);
+}
+
 Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
   const GraphDef* graph_def = &original_graph_def_;
 
@@ -251,10 +287,8 @@ Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
       session_options_->config.graph_options().place_pruned_graph()) {
     // Rewrite the graph before placement.
     rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
-    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
-        new_graph.get(), options.callable_options,
-        device_set_->client_device()->attributes(),
-        options.use_function_convention, rewrite_metadata_.get()));
+    TF_RETURN_IF_ERROR(
+        PruneGraph(options, new_graph.get(), rewrite_metadata_.get()));
   }
 
   // Save stateful placements before placing.
@@ -404,12 +438,7 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
   subgraph::RewriteGraphMetadata rewrite_metadata;
   if (session_options_ == nullptr ||
       !session_options_->config.graph_options().place_pruned_graph()) {
-    // Extract the subset of the graph that needs to be run, adding feed/fetch
-    // ops as needed.
-    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
-        ng.get(), options.callable_options,
-        device_set_->client_device()->attributes(),
-        options.use_function_convention, &rewrite_metadata));
+    TF_RETURN_IF_ERROR(PruneGraph(options, ng.get(), &rewrite_metadata));
   } else {
     // This GraphExecutionState represents a graph that was
     // pruned when this was constructed, so we copy the metadata from
index 2312e1a..2154ef5 100644 (file)
@@ -177,6 +177,11 @@ class GraphExecutionState {
   void SaveStatefulNodes(Graph* graph);
   void RestoreStatefulNodes(Graph* graph);
 
+  // Extract the subset of the graph that needs to be run, adding feed/fetch
+  // ops as needed.
+  Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
+                    subgraph::RewriteGraphMetadata* out_rewrite_metadata);
+
   Status OptimizeGraph(const BuildGraphOptions& options,
                        std::unique_ptr<Graph>* optimized_graph);
 
index ca93d04..193cf88 100644 (file)
@@ -28,13 +28,13 @@ limitations under the License.
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
+namespace subgraph {
 
 // ----------------------------------------------------------------------------
 // Subgraph construction-related routines
@@ -44,6 +44,8 @@ namespace tensorflow {
 
 namespace {
 
+typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
+
 // Rewrite graph by replacing the output tensors specified in
 // "fed_outputs" with special feed nodes for each specified output
 // tensor, and removing any nodes that are now disconnected from the
@@ -53,59 +55,33 @@ namespace {
 // Return true on success.  On error, return false and sets *error to
 // an appropriate error message (and *g is left in an indeterminate
 // state).
-static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
-                         const gtl::ArraySlice<string>& fed_outputs,
-                         bool use_function_convention,
-                         subgraph::NameIndex* name_index,
-                         DataTypeVector* out_feed_types) {
+Status FeedInputs(
+    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+    NameIndex* name_index, DataTypeVector* out_feed_types) {
   out_feed_types->clear();
-  out_feed_types->reserve(fed_outputs.size());
-  for (size_t i = 0; i < fed_outputs.size(); ++i) {
-    const string& t = fed_outputs[i];
+  out_feed_types->reserve(feed_rewrites.size());
+  for (size_t i = 0; i < feed_rewrites.size(); ++i) {
+    const string& t = feed_rewrites[i]->endpoint_name();
     TensorId id(ParseTensorName(t));
 
     auto iter = name_index->find(id.first);
     if (iter == name_index->end()) {
       return errors::NotFound("FeedInputs: unable to find feed output ", t);
     }
-    const Node* n = iter->second;
+    Node* n = iter->second;
     DCHECK_EQ(n->name(), id.first);
     if (id.second >= n->num_outputs()) {
       return errors::InvalidArgument(
           "FeedInputs: ", t, " should have output index < ", n->num_outputs());
     }
 
-    Node* recv_node;
-
-    if (!use_function_convention) {
-      TF_RETURN_IF_ERROR(
-          NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
-                      "_Recv")
-              .Attr("tensor_type", BaseType(n->output_type(id.second)))
-              .Attr("tensor_name", t)
-              .Attr("send_device", device_info.name())
-              .Attr("recv_device", device_info.name())
-              .Attr("send_device_incarnation",
-                    static_cast<int64>(device_info.incarnation()))
-              .Attr("client_terminated", true)
-              .Finalize(g, &recv_node));
-    } else {
-      // NOTE(mrry): We must include the index as part of the node
-      // name, because _Arg is a "stateful" kernel and therefore
-      // its name must uniquely identify a kernel instance across all
-      // graphs in the same session.
-      TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
-                                                     id.second, "_", i),
-                                     "_Arg")
-                             .Attr("T", BaseType(n->output_type(id.second)))
-                             .Attr("index", static_cast<int32>(i))
-                             .Finalize(g, &recv_node));
-    }
-    recv_node->set_assigned_device_name(device_info.name());
+    Node* feed_node;
+    TF_RETURN_IF_ERROR(
+        feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));
 
     // Update name_index
-    (*name_index)[recv_node->name()] = recv_node;
-    g->AddControlEdge(g->source_node(), recv_node);
+    (*name_index)[feed_node->name()] = feed_node;
+    g->AddControlEdge(g->source_node(), feed_node);
 
     // Look through edges coming out of "n" for edges whose src_output() index
     // matches "output_index".  If found, replace the edges with a connection
@@ -119,7 +95,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
                   n->type_string() == "PlaceholderV2")) {
         // When feeding a Placeholder node, any outgoing control edges
         // will be replaced with a control edge from the replacement
-        // recv_node.
+        // feed_node.
         // TODO(josh11b,mrry): Come up with a more elegant way of addressing
         // the general version of this problem.
         to_remove.emplace_back(e);
@@ -128,10 +104,10 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
 
     for (const Edge* e : to_remove) {
       if (e->src_output() == id.second) {
-        g->AddEdge(recv_node, 0, e->dst(), e->dst_input());
+        g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
       } else {
         CHECK_EQ(Graph::kControlSlot, e->src_output());
-        g->AddControlEdge(recv_node, e->dst());
+        g->AddControlEdge(feed_node, e->dst());
       }
       g->RemoveEdge(e);
     }
@@ -140,9 +116,61 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
   return Status::OK();
 }
 
-static bool AddNodeToTargets(const string& node_or_tensor_name,
-                             const subgraph::NameIndex& name_index,
-                             std::unordered_set<const Node*>* targets) {
+Status FetchOutputs(
+    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
+    NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
+    DataTypeVector* out_fetch_types) {
+  out_fetch_nodes->clear();
+  out_fetch_nodes->reserve(fetch_rewrites.size());
+  for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
+    const string& t = fetch_rewrites[i]->endpoint_name();
+
+    // Parse t into node_name and output_index.
+    TensorId id(ParseTensorName(t));
+
+    // Find node in graph with that name.
+    auto iter = name_index->find(id.first);
+    if (iter == name_index->end()) {
+      return errors::NotFound("FetchOutputs node ", t, ": not found");
+    }
+    Node* n = iter->second;
+    DCHECK_EQ(n->name(), id.first);
+    VLOG(2) << "Found fetch node for " << t;
+
+    // Validate output_index
+    if (n->num_outputs() == 0) {
+      return errors::InvalidArgument(
+          "Tried to fetch data for '", t,
+          "', which produces no output.  To run to a node but not fetch any "
+          "data, pass '",
+          t,
+          "' as an argument to the 'target_node_names' argument of the "
+          "Session::Run API.");
+    } else if (id.second >= n->num_outputs()) {
+      return errors::InvalidArgument("FetchOutputs ", t,
+                                     ": output index too large, must be < ",
+                                     n->num_outputs());
+    }
+
+    // Create the fetch Node and connect it up
+    Node* fetch_node;
+    TF_RETURN_IF_ERROR(
+        fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));
+
+    // Update the index.
+    (*name_index)[fetch_node->name()] = fetch_node;
+
+    g->AddControlEdge(fetch_node, g->sink_node());
+    out_fetch_nodes->push_back(fetch_node);
+    out_fetch_types->push_back(BaseType(n->output_type(id.second)));
+  }
+
+  return Status::OK();
+}
+
+bool AddNodeToTargets(const string& node_or_tensor_name,
+                      const NameIndex& name_index,
+                      std::unordered_set<const Node*>* targets) {
   TensorId id = ParseTensorName(node_or_tensor_name);
   auto iter = name_index.find(id.first);
   if (iter == name_index.end()) {
@@ -154,9 +182,9 @@ static bool AddNodeToTargets(const string& node_or_tensor_name,
   return true;
 }
 
-static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
-                              const std::vector<Node*>& fetch_nodes,
-                              const gtl::ArraySlice<string>& target_nodes) {
+Status PruneForTargets(Graph* g, const NameIndex& name_index,
+                       const std::vector<Node*>& fetch_nodes,
+                       const gtl::ArraySlice<string>& target_nodes) {
   string not_found;
   std::unordered_set<const Node*> targets;
   for (Node* n : fetch_nodes) {
@@ -183,108 +211,149 @@ static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
 
 }  // namespace
 
-namespace subgraph {
+Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+                               Node** out_node) {
+  // NOTE(mrry): We must include the index as part of the node
+  // name, because _Arg is a "stateful" kernel and therefore
+  // its name must uniquely identify a kernel instance across all
+  // graphs in the same session.
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
+                                  feed_tensor.index, "_", arg_index_),
+                  "_Arg")
+          .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
+          .Attr("index", arg_index_)
+          .Finalize(g, out_node));
+  (*out_node)->set_assigned_device_name(device_info().name());
+  return Status::OK();
+}
 
-Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
-                    const gtl::ArraySlice<string>& fetch_outputs,
-                    bool use_function_convention, NameIndex* name_index,
-                    std::vector<Node*>* out_fetch_nodes,
-                    DataTypeVector* out_fetch_types) {
-  out_fetch_nodes->clear();
-  out_fetch_nodes->reserve(fetch_outputs.size());
-  for (size_t i = 0; i < fetch_outputs.size(); ++i) {
-    const string& t = fetch_outputs[i];
+Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+                                Node** out_node) {
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
+                                  feed_tensor.index),
+                  "_Recv")
+          .Attr("tensor_type",
+                BaseType(feed_tensor.node->output_type(feed_tensor.index)))
+          .Attr("tensor_name", endpoint_name())
+          .Attr("send_device", device_info().name())
+          .Attr("recv_device", device_info().name())
+          .Attr("send_device_incarnation",
+                static_cast<int64>(device_info().incarnation()))
+          .Attr("client_terminated", true)
+          .Finalize(g, out_node));
+
+  (*out_node)->set_assigned_device_name(device_info().name());
+  return Status::OK();
+}
 
-    // Parse t into node_name and output_index.
-    TensorId id(ParseTensorName(t));
+Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+                                   Node** out_node) {
+  // NOTE(mrry): We must include the index as part of the node
+  // name, because _Retval is a "stateful" kernel and therefore
+  // its name must uniquely identify a kernel instance across all
+  // graphs in the same session.
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
+                                  fetch_tensor.index, "_", retval_index_),
+                  "_Retval")
+          .Input(fetch_tensor.node, fetch_tensor.index)
+          .Attr("T",
+                BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
+          .Attr("index", retval_index_)
+          .Finalize(g, out_node));
+  (*out_node)->set_assigned_device_name(device_info().name());
+  return Status::OK();
+}
 
-    // Find node in graph with that name.
-    auto iter = name_index->find(id.first);
-    if (iter == name_index->end()) {
-      return errors::NotFound("FetchOutputs node ", t, ": not found");
-    }
-    Node* n = iter->second;
-    DCHECK_EQ(n->name(), id.first);
-    VLOG(2) << "Found fetch node for " << t;
+Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+                                 Node** out_node) {
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
+                                  fetch_tensor.index),
+                  "_Send")
+          .Input(fetch_tensor.node, fetch_tensor.index)
+          .Attr("tensor_name", endpoint_name())
+          .Attr("send_device", device_info().name())
+          .Attr("recv_device", device_info().name())
+          .Attr("send_device_incarnation",
+                static_cast<int64>(device_info().incarnation()))
+          .Attr("client_terminated", true)
+          .Finalize(g, out_node));
+  (*out_node)->set_assigned_device_name(device_info().name());
+  return Status::OK();
+}
 
-    // Validate output_index
-    if (n->num_outputs() == 0) {
-      return errors::InvalidArgument(
-          "Tried to fetch data for '", t,
-          "', which produces no output.  To run to a node but not fetch any "
-          "data, pass '",
-          t,
-          "' as an argument to the 'target_node_names' argument of the "
-          "Session::Run API.");
-    } else if (id.second >= n->num_outputs()) {
-      return errors::InvalidArgument("FetchOutputs ", t,
-                                     ": output index too large, must be < ",
-                                     n->num_outputs());
+Status RewriteGraphForExecution(
+    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+    const gtl::ArraySlice<string>& fetch_outputs,
+    const gtl::ArraySlice<string>& target_node_names,
+    const DeviceAttributes& device_info, bool use_function_convention,
+    RewriteGraphMetadata* out_metadata) {
+  std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
+  feed_rewrites.reserve(fed_outputs.size());
+  if (use_function_convention) {
+    for (size_t i = 0; i < fed_outputs.size(); ++i) {
+      feed_rewrites.emplace_back(new ArgFeedRewrite(
+          &fed_outputs[i], &device_info, static_cast<int32>(i)));
     }
-
-    // Create the fetch Node and connect it up
-    Node* send_node;
-    if (!use_function_convention) {
-      TF_RETURN_IF_ERROR(
-          NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
-                      "_Send")
-              .Input(n, id.second)
-              .Attr("tensor_name", t)
-              .Attr("send_device", device_info.name())
-              .Attr("recv_device", device_info.name())
-              .Attr("send_device_incarnation",
-                    static_cast<int64>(device_info.incarnation()))
-              .Attr("client_terminated", true)
-              .Finalize(g, &send_node));
-    } else {
-      // NOTE(mrry): We must include the index as part of the node
-      // name, because _Retval is a "stateful" kernel and therefore
-      // its name must uniquely identify a kernel instance across all
-      // graphs in the same session.
-      TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
-                                                     id.second, "_", i),
-                                     "_Retval")
-                             .Input(n, id.second)
-                             .Attr("T", BaseType(n->output_type(id.second)))
-                             .Attr("index", static_cast<int32>(i))
-                             .Finalize(g, &send_node));
+  } else {
+    for (const string& fed_output : fed_outputs) {
+      feed_rewrites.emplace_back(
+          new RecvFeedRewrite(&fed_output, &device_info));
     }
-    send_node->set_assigned_device_name(device_info.name());
-
-    // Update the index.
-    (*name_index)[send_node->name()] = send_node;
+  }
 
-    g->AddControlEdge(send_node, g->sink_node());
-    out_fetch_nodes->push_back(send_node);
-    out_fetch_types->push_back(BaseType(n->output_type(id.second)));
+  std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
+  fetch_rewrites.reserve(fetch_outputs.size());
+  if (use_function_convention) {
+    for (size_t i = 0; i < fetch_outputs.size(); ++i) {
+      fetch_rewrites.emplace_back(new RetvalFetchRewrite(
+          &fetch_outputs[i], &device_info, static_cast<int32>(i)));
+    }
+  } else {
+    for (const string& fetch_output : fetch_outputs) {
+      fetch_rewrites.emplace_back(
+          new SendFetchRewrite(&fetch_output, &device_info));
+    }
   }
 
-  return Status::OK();
+  return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
+                                  target_node_names, out_metadata);
+}
+
+namespace {
+template <typename StringContainer>
+std::vector<string> ConvertToVector(StringContainer field) {
+  return std::vector<string>(field.begin(), field.end());
 }
+}  // namespace
 
 Status RewriteGraphForExecution(
-    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
-    const gtl::ArraySlice<string>& fetch_outputs,
+    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+    const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
     const gtl::ArraySlice<string>& target_node_names,
-    const DeviceAttributes& device_info, bool use_function_convention,
     RewriteGraphMetadata* out_metadata) {
-  if (fetch_outputs.empty() && target_node_names.empty()) {
+  if (fetch_rewrites.empty() && target_node_names.empty()) {
     return errors::InvalidArgument(
         "Must specify at least one target to fetch or execute.");
   }
 
   std::unordered_set<string> endpoints;
-  for (const string& endpoint_name : fed_outputs) {
-    auto result = endpoints.insert(endpoint_name);
+  for (const auto& feed_rewrite : feed_rewrites) {
+    auto result = endpoints.insert(feed_rewrite->endpoint_name());
     if (!result.second) {
-      return errors::InvalidArgument("Endpoint \"", endpoint_name,
+      return errors::InvalidArgument("Endpoint \"",
+                                     feed_rewrite->endpoint_name(),
                                      "\" fed more than once.");
     }
   }
 
-  for (const auto& fetch : fetch_outputs) {
-    if (endpoints.count(fetch) > 0) {
-      return errors::InvalidArgument(fetch, " is both fed and fetched.");
+  for (const auto& fetch_rewrite : fetch_rewrites) {
+    if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
+      return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
+                                     " is both fed and fetched.");
     }
   }
 
@@ -297,19 +366,17 @@ Status RewriteGraphForExecution(
   }
 
   // Add the feeds.  This may replace nodes in the graph, including the nodes
-  // currently listed in "fetch_nodes".  We pass "name_index" so the index is
+  // currently listed in "fetch_rewrites".  We pass "name_index" so the index is
   // kept up to date.
-  if (!fed_outputs.empty()) {
-    TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
-                                  use_function_convention, &name_index,
-                                  &out_metadata->feed_types));
+  if (!feed_rewrites.empty()) {
+    TF_RETURN_IF_ERROR(
+        FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
   }
 
   // Add the fetch nodes, also updating "name_index".
   std::vector<Node*> fetch_nodes;
-  if (!fetch_outputs.empty()) {
-    TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
-                                    use_function_convention, &name_index,
+  if (!fetch_rewrites.empty()) {
+    TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
                                     &fetch_nodes, &out_metadata->fetch_types));
   }
 
@@ -323,25 +390,6 @@ Status RewriteGraphForExecution(
   return Status::OK();
 }
 
-namespace {
-template <typename StringContainer>
-std::vector<string> ConvertToVector(StringContainer field) {
-  return std::vector<string>(field.begin(), field.end());
-}
-}  // namespace
-
-Status RewriteGraphForExecution(Graph* g,
-                                const CallableOptions& callable_options,
-                                const DeviceAttributes& device_info,
-                                bool use_function_convention,
-                                RewriteGraphMetadata* out_metadata) {
-  return RewriteGraphForExecution(g, ConvertToVector(callable_options.feed()),
-                                  ConvertToVector(callable_options.fetch()),
-                                  ConvertToVector(callable_options.target()),
-                                  device_info, use_function_convention,
-                                  out_metadata);
-}
-
 }  // namespace subgraph
 
 }  // namespace tensorflow
index 0dc5958..ba35846 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/protobuf/config.pb.h"
@@ -39,6 +40,37 @@ struct RewriteGraphMetadata {
   DataTypeVector fetch_types;
 };
 
+// Describes the action to take on a particular tensor endpoint (described by
+// a "<node_name>:<output_index>" pair) when pruning the graph.
+//
+// The `AddNode()` method must be overridden to describe this action. The method
+// will be invoked once during `RewriteGraphForExecution()` with tensor endpoint
+// named by `endpoint_name`, and it may either create a single new node, or fail
+// with an error if the resulting graph would be invalid.
+class PruneRewrite {
+ public:
+  // `endpoint_name` and `device_info` must outlive this object.
+  PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info)
+      : endpoint_name_(endpoint_name), device_info_(device_info) {}
+  virtual ~PruneRewrite() {}
+
+  // Creates a new node whose output replaces the given `tensor` in graph `g`.
+  // The node will be assigned to the device named in `device_info`.
+  virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor,
+                         Node** out_node) = 0;
+
+  // Returns the name of the tensor to which this rewrite applies.
+  const string& endpoint_name() { return *endpoint_name_; }
+
+ protected:
+  // The device on which the new node will be created.
+  const DeviceAttributes& device_info() { return *device_info_; }
+
+ private:
+  const string* const endpoint_name_;          // Not owned.
+  const DeviceAttributes* const device_info_;  // Not owned.
+};
+
 // Rewrite the graph structure of "*g" to deal with feeding node
 // outputs, fetching node outputs, and only running a subset of the
 // graph.  "fed_outputs" and "fetch_outputs" are both lists of
@@ -49,7 +81,7 @@ struct RewriteGraphMetadata {
 // In the resulting graph "*g", output edges in "fed_outputs" have
 // been redirected to special "_recv" nodes introduced into the graph.
 // If these fed nodes are not needed in order to compute the effects
-// of the nodes in "targets_nodes" and "fetch_outputs", then these may
+// of the nodes in "target_node_names" and "fetch_outputs", then these may
 // be omitted from the graph.
 //
 // In the resulting graph "*g", additional "_send" nodes are connected
@@ -71,25 +103,61 @@ Status RewriteGraphForExecution(
     const gtl::ArraySlice<string>& target_node_names,
     const DeviceAttributes& device_info, bool use_function_convention,
     RewriteGraphMetadata* out_metadata);
-Status RewriteGraphForExecution(Graph* g,
-                                const CallableOptions& callable_options,
-                                const DeviceAttributes& device_info,
-                                bool use_function_convention,
-                                RewriteGraphMetadata* out_metadata);
-
-typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
-
-// Augment "*g" by adding special "fetch" nodes that connect to the
-// tensor outputs specified in "fetch_outputs" to retrieve the output
-// of the tensors.  The new nodes added are set up to execute on
-// "client_device_name", and are returned in "*fetch_nodes".
-//
-// Return OK on success.  On error, return false and sets *error to
-// an appropriate error message (and *g is left in an indeterminate
-// state).
-Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
-                    const gtl::ArraySlice<string>& fetch_outputs,
-                    NameIndex* name_index, std::vector<Node*>* fetch_nodes);
+
+// A more general version of the above function that supports
+// customizable rewriting actions for each fed and fetched tensor.
+Status RewriteGraphForExecution(
+    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
+    const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
+    const gtl::ArraySlice<string>& target_node_names,
+    RewriteGraphMetadata* out_metadata);
+
+/////////////////////////////////////////////////////////
+// Custom rewrite actions for fed and fetched tensors. //
+/////////////////////////////////////////////////////////
+
+// A rewrite action that adds an _Arg node for a fed tensor.
+class ArgFeedRewrite : public PruneRewrite {
+ public:
+  ArgFeedRewrite(const string* endpoint_name,
+                 const DeviceAttributes* device_info, int32 arg_index)
+      : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {}
+  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+                 Node** out_node) override;
+
+ private:
+  const int32 arg_index_;
+};
+
+// A rewrite action that adds a client-terminated _Recv node for a fed tensor.
+class RecvFeedRewrite : public PruneRewrite {
+ public:
+  using PruneRewrite::PruneRewrite;
+  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+                 Node** out_node) override;
+};
+
+// A rewrite action that adds a _Retval node for a fetched tensor.
+class RetvalFetchRewrite : public PruneRewrite {
+ public:
+  RetvalFetchRewrite(const string* endpoint_name,
+                     const DeviceAttributes* device_info, int32 retval_index)
+      : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {}
+  Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+                 Node** out_node) override;
+
+ private:
+  const int32 retval_index_;
+};
+
+// A rewrite action that adds a client-terminated _Send node for a
+// fetched tensor.
+class SendFetchRewrite : public PruneRewrite {
+ public:
+  using PruneRewrite::PruneRewrite;
+  Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
+                 Node** out_node) override;
+};
 
 }  // namespace subgraph
 }  // namespace tensorflow