Rollforward switch group identification with fixes.
authorJacques Pienaar <jpienaar@google.com>
Tue, 13 Feb 2018 01:15:51 +0000 (17:15 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 01:19:48 +0000 (17:19 -0800)
Fixed computing the switch depth: with the erroneous switch depth incorrect
clusters could be formed. Change the way the switch depth is determined (the
switch depth is now on the output side, so a switch always has a switch depth
one higher than all its inputs), add further checking during execution.

PiperOrigin-RevId: 185461054

tensorflow/compiler/tf2xla/functionalize_control_flow.cc
tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
tensorflow/compiler/tf2xla/graph_compiler.cc

index bf30410..f816979 100644 (file)
@@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
 Status FunctionalizeLoop(Graph* graph, Frame* frame,
                          FunctionLibraryDefinition* library) {
   VLOG(2) << "Frame " << frame->name << " before: "
-          << dump_graph::DumpGraphToFile("functionalize_before", *graph);
+          << dump_graph::DumpGraphToFile("functionalize_before", *graph,
+                                         library);
 
   // Split loop-varying Enter nodes with multiple successors. If the same
   // Tensor is fed as input to multiple loop arguments, we may end up with a
@@ -470,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
 
   VLOG(2) << "Frame " << frame->name << " condition: "
-          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
+          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
           << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
 
   static std::atomic<int64> sequence_num(0LL);
@@ -551,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
   frame->parent->nodes.insert(while_node);
 
   VLOG(2) << "Frame " << frame->name << " after: "
-          << dump_graph::DumpGraphToFile("functionalize_after", *graph);
+          << dump_graph::DumpGraphToFile("functionalize_after", *graph,
+                                         library);
 
   return Status::OK();
 }
@@ -584,11 +586,11 @@ class FunctionalizeCond {
     explicit CondArgNode(Node* input) : input(input) {}
     string ToString() const {
       return strings::StrCat("input=", input->name(),
-                             " switches=", NodesToString(switch_nodes));
+                             " switches=", NodesToString(switches));
     }
 
     Node* input;
-    std::vector<Node*> switch_nodes;
+    std::vector<Node*> switches;
   };
   using CondArgNodes = std::vector<CondArgNode>;
 
@@ -602,15 +604,22 @@ class FunctionalizeCond {
     int count;
   };
 
-  struct PredicateSwitches {
-    explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
+  // Group of switch nodes that will be part of the same XlaIf.
+  struct SwitchCluster {
+    explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
+    string ToString() const {
+      return strings::StrCat(name, " predicate=", predicate->name(),
+                             " switches=", NodesToString(switches));
+    }
 
+    string name;
     Node* predicate;
     std::vector<Node*> switches;
   };
 
-  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
-      : library_(library), graph_(graph) {}
+  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
+                    bool dump_graphs)
+      : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
 
   // Perform the actual cond functionalization. Iterate over groups of switch
   // nodes (linked by common predicate), from innermost to outermost, and
@@ -621,27 +630,25 @@ class FunctionalizeCond {
   // frontier (the nodes where the cond ends).
   StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
                      std::unordered_set<Node*>>>
-  DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
+  DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
 
   // Returns XlaIf node created from subgraph of merge and switch nodes. This
   // encapsulates the process of extracting the bodies needed for the then and
   // else branch, creates a XlaIf node, removing the nodes of the branches from
   // the graph and replacing the merge node with a XlaIf.
   StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
-                                 const std::vector<Node*>& switch_nodes,
-                                 const std::vector<Node*>& merge_nodes,
-                                 Node* predicate);
+                                 const SwitchCluster& switch_cluster,
+                                 const std::vector<Node*>& switches);
 
   // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
   StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
-                                     const std::vector<Node*>& switch_nodes,
-                                     const std::vector<Node*>& merge_nodes,
-                                     Node* predicate);
+                                     const SwitchCluster& switch_cluster,
+                                     const std::vector<Node*>& merge_nodes);
 
   // Extracts a function body corresponding to the given input edge of the merge
   // node.
   Status ExtractBody(const CondArgNodes& cond_arg_nodes,
-                     const std::vector<Node*>& switch_nodes,
+                     const std::vector<Node*>& switches,
                      const std::vector<Node*>& merge_nodes, int input_edge,
                      Graph* body);
 
@@ -652,9 +659,9 @@ class FunctionalizeCond {
   // Adds all output edges from the `if_node`.
   Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
 
-  // Returns the switches of graph_ (along with grouping predicates) in
-  // postorder. Dead switch nodes are skipped and removed from the graph.
-  std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
+  // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
+  // skipped and removed from the graph.
+  StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
 
   // Update the state for destination based on the state of source and the node
   // being updated.
@@ -677,6 +684,7 @@ class FunctionalizeCond {
 
   FunctionLibraryDefinition* library_;
   Graph* graph_;
+  bool dump_graphs_;
 };
 
 bool IsDeadSwitch(const Node* node) {
@@ -724,10 +732,13 @@ Status FunctionalizeCond::ValidateFrontier(
           ") in both Else and Then branch should be in Both.");
     }
   }
-  if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
-      pending[kElseBranch].empty()) {
-    return errors::Internal("Unexpected empty frontier for switch nodes");
-  }
+  // An empty frontier indicates a dead switch. Above we attempt to remove dead
+  // switch nodes, but not all are removed so don't treat it as an error yet.
+  // TODO(jpienaar): Find out why dead switch nodes remain.
+  // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
+  //     pending[kElseBranch].empty()) {
+  //   return errors::Internal("Unexpected empty frontier for switch nodes");
+  // }
   return Status::OK();
 }
 
@@ -754,33 +765,191 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
   return Status::OK();
 }
 
-std::vector<FunctionalizeCond::PredicateSwitches>
+StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
 FunctionalizeCond::DeterminePredicateSwitchOrder() {
+  struct Cluster {
+    bool operator==(const Cluster& other) const {
+      return representative == other.representative;
+    }
+    int representative = -1;
+  };
+
+  // Perform a DFS over the graph and
+  // * Determine the reverse topological order of the nodes (there should be no
+  //   cycles at this point so the post-order numbering corresponds to the
+  //   reverse topological sorting);
+  // * Identify dead switches;
+  // * Initialize the cluster's representative;
+  std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
   std::vector<Node*> dead_switches;
   std::vector<Node*> switch_order;
-  DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
+  std::vector<Node*> rev_topo_sorted_nodes;
+  DFS(*graph_, nullptr, [&](Node* n) {
+    clusters[n->id()].Get().representative = n->id();
     if (IsSwitch(n)) {
       if (IsDeadSwitch(n)) {
         dead_switches.push_back(n);
       } else {
+        rev_topo_sorted_nodes.push_back(n);
         switch_order.push_back(n);
       }
+    } else if (n->IsOp()) {
+      // Exclude src and sink nodes from further consideration.
+      rev_topo_sorted_nodes.push_back(n);
     }
   });
 
+  std::vector<SwitchCluster> switch_clusters;
+  // Return early if there are no switches in the graph.
+  if (switch_order.empty()) {
+    return switch_clusters;
+  }
+
   // Remove all dead switch nodes.
   for (Node* n : dead_switches) {
     VLOG(2) << "Removing dead switch: " << n->DebugString();
     graph_->RemoveNode(n);
   }
 
-  std::vector<PredicateSwitches> predicate_switch_order;
-  if (switch_order.empty()) {
-    return predicate_switch_order;
+  // Identify switch nodes that are part of the same control flow context by
+  // considering the operands of operations: an operation is part of the same
+  // control context as its operands unless the operation is a switch. Control
+  // dependencies are considered part of the same control flow context if the
+  // switch depth is the same (see comment below).
+
+  // entry_cluster records the input cluster to a switch node. This is used when
+  // merging with a merge node where the dst's cluster is merged with the entry
+  // cluster of the merge node's cluster (which corresponds to a switch cluster
+  // and so has an entry cluster).
+  std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
+
+  // Returns the output cluster of a node. Where the output cluster is cluster
+  // where the output of the node is used. For non-merge nodes this is simply
+  // the cluster they are part of, while for merge nodes it is the entry cluster
+  // of the cluster they are part of (this will correspond to the entry node of
+  // a switch node that dominates the merge).
+  auto find_output_cluster = [&](Node* n) {
+    UnionFind<Cluster>* cluster = &clusters[n->id()];
+    if (!IsMerge(n)) return cluster;
+    auto it = entry_cluster.find(clusters[n->id()].Get().representative);
+    // If the cluster is not found in the entry_cluster map then an
+    // instruction not dominated by a switch node has been merged into the
+    // cluster of the merge. This indicates a failure of the clustering.
+    CHECK(it != entry_cluster.end())
+        << "Unable to find entry for n=" << n->id() << " ("
+        << cluster->Get().representative << ")";
+    return it->second;
+  };
+
+  // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
+  std::vector<int> switch_depth(graph_->num_node_ids());
+  for (auto it = rev_topo_sorted_nodes.rbegin();
+       it != rev_topo_sorted_nodes.rend(); ++it) {
+    Node* n = *it;
+
+    // Compute switch depth.
+    int new_switch_depth = 0;
+    for (const Edge* e : n->in_edges()) {
+      Node* src = e->src();
+      new_switch_depth = std::max(
+          new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
+    }
+    switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
+
+    // Only merge the input operands of a switch. The switch's clustering itself
+    // is determined by the interaction of the switch's outputs.
+    if (IsSwitch(n)) {
+      Node* input;
+      TF_CHECK_OK(n->input_node(0, &input));
+      entry_cluster[n->id()] = &clusters[input->id()];
+      UnionFind<Cluster>* cluster = find_output_cluster(input);
+      int cluster_depth = switch_depth[cluster->Get().representative];
+      // Merge the inputs of the switch node with one another. This results in
+      // predicates and control input residing in the same cluster.
+      for (const Edge* e : n->in_edges()) {
+        Node* src = e->src();
+        UnionFind<Cluster>* src_cluster = find_output_cluster(src);
+        int src_cluster_depth = switch_depth[src_cluster->Get().representative];
+        if (cluster_depth != src_cluster_depth) {
+          return errors::InvalidArgument(
+              "Unable to functionalize control flow in graph: Switch ('",
+              n->name(), "') has operands ('", input->name(), "' and '",
+              src->name(), "') that have different switch depths (",
+              cluster_depth, " != ", src_cluster_depth, ")");
+        }
+        cluster->Merge(src_cluster);
+      }
+      continue;
+    }
+
+    for (const Edge* e : n->in_edges()) {
+      Node* src = e->src();
+      if (!src->IsOp()) continue;
+      UnionFind<Cluster>* cluster = find_output_cluster(src);
+      // Merge a node with its data operands and with its control operands if
+      // the src and dst are in the same ControlContext. The ControlContext is
+      // not explicitly available here, and instead the switch depth is used as
+      // a proxy here. Due to the invariant that control edges can only be from
+      // a containing scope to an inner scope or from the inner scope to its
+      // containing scope (for exit nodes), the switch depth will only match if
+      // the src and dst are in the same ControlContext. Control edges between
+      // ControlContexts are handled during the extraction.
+      int src_id = cluster->Get().representative;
+      int src_depth = switch_depth[src_id];
+      if (!e->IsControlEdge() || new_switch_depth == src_depth) {
+        if (src_depth != new_switch_depth) {
+          return errors::InvalidArgument(
+              "Unable to functionalize control flow in graph: Operand ('",
+              src->name(), "') and operator ('", n->name(),
+              "') have different switch depths (", src_depth,
+              " != ", new_switch_depth, ")");
+        }
+        cluster->Merge(&clusters[n->id()]);
+      }
+    }
   }
 
+  if (dump_graphs_) {
+    // Mark the switch cluster each node is part of.
+    for (Node* n : graph_->nodes()) {
+      n->ClearAttr("_XlaFunctionalizeSwitchGroup");
+      n->AddAttr("_XlaFunctionalizeSwitchGroup",
+                 clusters[n->id()].Get().representative);
+    }
+    LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
+              << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
+                                             library_);
+  }
+
+  // Verify all the nodes of a cluster are at the same depth.
+  std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
+  for (Node* n : graph_->nodes()) {
+    int depth = switch_depth[n->id()];
+    int cluster_rep = clusters[n->id()].Get().representative;
+    auto it = cluster_to_depth_node.find(cluster_rep);
+    if (it == cluster_to_depth_node.end()) {
+      cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
+    } else {
+      if (it->second.first != depth) {
+        return errors::Internal(
+            "Illegal clustering created, mismatch in depths:", "\n\t",
+            n->DebugString(), "(", clusters[n->id()].Get().representative,
+            ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
+            "(", clusters[n->id()].Get().representative, ") at depth ",
+            it->second.first);
+      }
+    }
+  }
+
+  struct Hash {
+    size_t operator()(const std::pair<Node*, Cluster>& item) const {
+      return Hash64Combine(hash<Node*>()(item.first),
+                           std::hash<int>()(item.second.representative));
+    }
+  };
+
   // Merge Switch nodes with common predicate.
-  std::unordered_map<Node*, int> predicate_index;
+  std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
   // The nodes in switch_order are in reverse topological order, but the
   // clustered switches need not be (i.e., when considered as a cluster one
   // element of a cluster may be later in the topological order than another
@@ -789,13 +958,19 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
   for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
     Node* pred;
     TF_CHECK_OK((*it)->input_node(1, &pred));
-    if (predicate_index.find(pred) == predicate_index.end()) {
-      predicate_index[pred] = predicate_switch_order.size();
-      predicate_switch_order.emplace_back(pred);
+    auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
+    if (predicate_index.find(repr) == predicate_index.end()) {
+      predicate_index[repr] = switch_clusters.size();
+      switch_clusters.emplace_back(pred);
+      // Generate a name by concatenating with the cluster representative as
+      // there could be multiple switch clusters with the same predicate.
+      switch_clusters[predicate_index[repr]].name =
+          strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
     }
-    predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
+    switch_clusters[predicate_index[repr]].switches.push_back(*it);
   }
-  return predicate_switch_order;
+
+  return switch_clusters;
 }
 
 StatusOr<std::vector<Node*>>
@@ -843,10 +1018,10 @@ StatusOr<
     std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
               std::unordered_set<Node*>>>
 FunctionalizeCond::DetermineBranchMapAndFrontier(
-    const std::vector<Node*>& switches) {
+    const SwitchCluster& switch_cluster) {
   std::unordered_map<Node*, ForwardFlowNode> branch_map;
   std::unordered_set<Node*> frontier;
-  std::vector<Node*> stack = switches;
+  std::vector<Node*> stack = switch_cluster.switches;
   std::vector<bool> visited(graph_->num_node_ids(), false);
   while (!stack.empty()) {
     Node* n = stack.back();
@@ -888,7 +1063,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
     }
   }
 
-  if (VLOG_IS_ON(2)) {
+  if (dump_graphs_) {
     for (const auto& kv : branch_map) {
       // Append attribute to the graph if running with logging to make the
       // changes clearer in the visualization.
@@ -900,8 +1075,8 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
 }
 
 Status FunctionalizeCond::FunctionalizeInternal() {
-  std::vector<PredicateSwitches> predicate_switch_order =
-      DeterminePredicateSwitchOrder();
+  TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
+                      DeterminePredicateSwitchOrder());
 
   // Iterate from innermost set of clustered switches to outermost, replacing
   // matching switch->merge subgraphs with single XlaIf nodes.
@@ -914,10 +1089,12 @@ Status FunctionalizeCond::FunctionalizeInternal() {
     std::unordered_map<Node*, ForwardFlowNode> branch_map;
     std::unordered_set<Node*> frontier;
     TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
-                        DetermineBranchMapAndFrontier(ps.switches));
+                        DetermineBranchMapAndFrontier(ps));
 
-    VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
-            << dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
+    if (dump_graphs_)
+      LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
+                << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
+                                               library_);
     TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
 
     // Sort the merge and switch nodes using NodeCmp. The switch-nodes are
@@ -934,7 +1111,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
         input_index[in] = cond_arg_nodes.size();
         cond_arg_nodes.emplace_back(in);
       }
-      cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node);
+      cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
     }
     std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
     std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
@@ -943,9 +1120,8 @@ Status FunctionalizeCond::FunctionalizeInternal() {
                         EnsureDominanceAndReturnNonDominatedControlNodes(
                             branch_map, ps.switches));
 
-    TF_ASSIGN_OR_RETURN(
-        Node * if_node,
-        ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate));
+    TF_ASSIGN_OR_RETURN(Node * if_node,
+                        ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
     for (Node* old : old_control_nodes) {
       graph_->AddControlEdge(old, if_node);
     }
@@ -954,25 +1130,26 @@ Status FunctionalizeCond::FunctionalizeInternal() {
       graph_->RemoveNode(del_kv.first);
     }
     for (auto& kv : cond_arg_nodes) {
-      for (Node* node : kv.switch_nodes) {
+      for (Node* node : kv.switches) {
         graph_->RemoveNode(node);
       }
     }
-    VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): "
-            << dump_graph::DumpGraphToFile("functionalize_ac", *graph_);
+    if (dump_graphs_)
+      LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
+                << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
+                                               library_);
   }
   return Status::OK();
 }
 
 StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
-    const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
-    const std::vector<Node*>& merge_nodes, Node* predicate) {
-  VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
-          << NodesToString(switch_nodes);
+    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
+    const std::vector<Node*>& merge_nodes) {
+  VLOG(2) << "Build if op for " << switch_cluster.name;
 
   NodeDef if_def;
   // Create a new If node using the name of the merge node.
-  NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf");
+  NodeDefBuilder builder(switch_cluster.name, "XlaIf");
   string branch[] = {"else_branch", "then_branch"};
   for (int i = 0; i < 2; ++i) {
     static std::atomic<int64> sequence_num(0LL);
@@ -982,12 +1159,9 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
     body_name.set_name(
         strings::StrCat("_functionalize_if_", branch[i], "_", id));
     auto body = xla::MakeUnique<Graph>(graph_->op_registry());
-    TF_RETURN_IF_ERROR(
-        ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get()));
+    TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
+                                   merge_nodes, i, body.get()));
     VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
-    VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): "
-            << dump_graph::DumpGraphToFile(
-                   strings::StrCat("functionalize_", branch[i]), *body);
     FunctionDef body_fdef;
     TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
     TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
@@ -999,7 +1173,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
   DataTypeVector in_arg_types;
   for (auto& kv : cond_arg_nodes) {
     bool inserted = false;
-    for (const Node* arg : kv.switch_nodes) {
+    for (const Node* arg : kv.switches) {
       const Edge* in_edge;
       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
       if (in_edge->IsControlEdge()) {
@@ -1026,10 +1200,11 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
   builder.Attr("Tout", out_type);
 
   builder.Attr("Tcond", DT_BOOL);
-  builder.Device(predicate->assigned_device_name());
+  builder.Device(switch_cluster.predicate->assigned_device_name());
   // Conditional should be the first input ...
   builder.Input(
-      NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0)));
+      NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
+                              switch_cluster.predicate->output_type(0)));
   // ... followed by the other inputs.
   builder.Input(inputs);
 
@@ -1039,7 +1214,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
 }
 
 Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
-                                      const std::vector<Node*>& switch_nodes,
+                                      const std::vector<Node*>& switches,
                                       const std::vector<Node*>& merge_nodes,
                                       int input_edge, Graph* body) {
   VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
@@ -1049,7 +1224,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
   int arg_count = 0;
   for (auto& kv : cond_arg_nodes) {
     Node* arg_node = nullptr;
-    for (const auto* arg : kv.switch_nodes) {
+    for (const auto* arg : kv.switches) {
       DataType dtype = arg->input_type(0);
       if (arg_node == nullptr) {
         TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
@@ -1073,8 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
       node_map.at(in->id()) = body->CopyNode(in);
     }
 
-    if (std::find(switch_nodes.begin(), switch_nodes.end(), in) ==
-        switch_nodes.end()) {
+    if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
       body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
                     node_map.at(node->id()), 0);
     } else {
@@ -1096,7 +1270,7 @@ Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
   graph_->AddEdge(predicate, 0, if_node, index++);
   for (auto& kv : cond_arg_nodes) {
     bool inserted = false;
-    for (const Node* arg : kv.switch_nodes) {
+    for (const Node* arg : kv.switches) {
       const Edge* in_edge;
       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
       if (in_edge->IsControlEdge()) {
@@ -1139,16 +1313,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
 }
 
 StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
-    const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
-    const std::vector<Node*>& merge_nodes, Node* predicate) {
-  VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
+    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
+    const std::vector<Node*>& merge_nodes) {
+  VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
           << NodesToString(merge_nodes);
 
   // Extract bodies and builds a If operator.
   TF_ASSIGN_OR_RETURN(
       Node * if_node,
-      BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate));
-  TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node));
+      BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
+  TF_RETURN_IF_ERROR(
+      AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
   TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
 
   return if_node;
@@ -1157,18 +1332,19 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
 Status FunctionalizeCond::Functionalize(Graph* graph,
                                         FunctionLibraryDefinition* library) {
   VLOG(1) << "FunctionalizeCond::Functionalize";
-  FunctionalizeCond fc(graph, library);
+  FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
   return fc.FunctionalizeInternal();
 }
 
 }  // namespace
 
-// Transformation that converts Tensorflow's graph control flow constructs into
+// Transformation that converts TensorFlow's graph control flow constructs into
 // functional equivalents.
 Status FunctionalizeControlFlow(Graph* graph,
                                 FunctionLibraryDefinition* library) {
   VLOG(2) << "FunctionalizeControlFlow (initial): "
-          << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
+          << dump_graph::DumpGraphToFile("functionalize_initial", *graph,
+                                         library);
   // Note: BuildControlFlowInfo() requires that the graph's source node is
   // connected to all source nodes in the graph. Many graphs violate this
   // invariant.
@@ -1180,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph,
   for (Node* node : graph->op_nodes()) {
     const ControlFlowInfo& cf = cf_info[node->id()];
 
-    VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
+    VLOG(2) << "node: " << node->name() << " (" << node->id()
+            << ") frame_name: " << cf.frame_name
             << " frame: " << (cf.frame ? cf.frame->name() : "---")
             << " parent_frame: "
             << (cf.parent_frame ? cf.parent_frame->name() : "---");
@@ -1248,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph,
   TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
 
   VLOG(2) << "FunctionalizeControlFlow (final): "
-          << dump_graph::DumpGraphToFile("functionalize_final", *graph);
+          << dump_graph::DumpGraphToFile("functionalize_final", *graph,
+                                         library);
   return Status::OK();
 }
 
index 71f12a1..bc7276c 100644 (file)
@@ -38,10 +38,11 @@ namespace {
 
 // Returns the names of the "then" and "else" functions for the XlaIf node in a
 // graph.
-Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn,
-                         NameAttrList* else_fn) {
+Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
+                         NameAttrList* then_fn, NameAttrList* else_fn) {
   for (const NodeDef& node : graph.node()) {
     if (node.op() == "XlaIf") {
+      *op_name = node.name();
       const NameAttrList* result;
       TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
       *then_fn = *result;
@@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) {
 
   GraphDef graph_def;
   graph.ToGraphDef(&graph_def);
+  string op_name;
   NameAttrList then_fn;
   NameAttrList else_fn;
-  TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn));
+  TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
   InstantiationResultForTest else_result;
   TF_EXPECT_OK(
       InstantiateFunctionForTest(else_fn.name(), library, &else_result));
@@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
-    auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less,
+    auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
                             std::initializer_list<Input>{less, y, x}, then_fn,
                             else_fn, {DT_INT32});
     GraphDef expected;
index 1418d95..058a1f2 100644 (file)
@@ -134,7 +134,7 @@ Status GraphCompiler::Compile() {
       TF_RET_CHECK(src->id() < output_registry.size());
       const NodeOutputs& src_outputs = output_registry[src->id()];
 
-      tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()];
+      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
     }
 
     OpKernelContext op_context(&params, n->num_outputs());