From: Jacques Pienaar Date: Tue, 13 Feb 2018 01:15:51 +0000 (-0800) Subject: Rollforward switch group identification with fixes. X-Git-Tag: upstream/v1.7.0~31^2~764 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4b9ef6c8e07dea7d18f552fa4955c3176646f95d;p=platform%2Fupstream%2Ftensorflow.git Rollforward switch group identification with fixes. Fixed computing the switch depth: with the erroneous switch depth incorrect clusters could be formed. Change the way the switch depth is determined (the switch depth is now on the output side, so a switch always has a switch depth one higher than all its inputs), add further checking during execution. PiperOrigin-RevId: 185461054 --- diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index bf30410..f816979 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, Status FunctionalizeLoop(Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph); + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); // Split loop-varying Enter nodes with multiple successors. If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -470,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); @@ -551,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph); + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); return Status::OK(); } @@ -584,11 +586,11 @@ class FunctionalizeCond { explicit CondArgNode(Node* input) : input(input) {} string ToString() const { return strings::StrCat("input=", input->name(), - " switches=", NodesToString(switch_nodes)); + " switches=", NodesToString(switches)); } Node* input; - std::vector switch_nodes; + std::vector switches; }; using CondArgNodes = std::vector; @@ -602,15 +604,22 @@ class FunctionalizeCond { int count; }; - struct PredicateSwitches { - explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} + // Group of switch nodes that will be part of the same XlaIf. + struct SwitchCluster { + explicit SwitchCluster(Node* predicate) : predicate(predicate) {} + string ToString() const { + return strings::StrCat(name, " predicate=", predicate->name(), + " switches=", NodesToString(switches)); + } + string name; Node* predicate; std::vector switches; }; - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : library_(library), graph_(graph) {} + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + bool dump_graphs) + : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} // Perform the actual cond functionalization. Iterate over groups of switch // nodes (linked by common predicate), from innermost to outermost, and @@ -621,27 +630,25 @@ class FunctionalizeCond { // frontier (the nodes where the cond ends). StatusOr, std::unordered_set>> - DetermineBranchMapAndFrontier(const std::vector& switches); + DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); // Returns XlaIf node created from subgraph of merge and switch nodes. This // encapsulates the process of extracting the bodies needed for the then and // else branch, creates a XlaIf node, removing the nodes of the branches from // the graph and replacing the merge node with a XlaIf. StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& switches); // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& merge_nodes); // Extracts a function body corresponding to the given input edge of the merge // node. Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body); @@ -652,9 +659,9 @@ class FunctionalizeCond { // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Returns the switches of graph_ (along with grouping predicates) in - // postorder. Dead switch nodes are skipped and removed from the graph. - std::vector DeterminePredicateSwitchOrder(); + // Returns the switch clusters of graph_ in postorder. Dead switch nodes are + // skipped and removed from the graph. + StatusOr> DeterminePredicateSwitchOrder(); // Update the state for destination based on the state of source and the node // being updated. @@ -677,6 +684,7 @@ class FunctionalizeCond { FunctionLibraryDefinition* library_; Graph* graph_; + bool dump_graphs_; }; bool IsDeadSwitch(const Node* node) { @@ -724,10 +732,13 @@ Status FunctionalizeCond::ValidateFrontier( ") in both Else and Then branch should be in Both."); } } - if (pending[kBoth].empty() && pending[kThenBranch].empty() && - pending[kElseBranch].empty()) { - return errors::Internal("Unexpected empty frontier for switch nodes"); - } + // An empty frontier indicates a dead switch. Above we attempt to remove dead + // switch nodes, but not all are removed so don't treat it as an error yet. + // TODO(jpienaar): Find out why dead switch nodes remain. + // if (pending[kBoth].empty() && pending[kThenBranch].empty() && + // pending[kElseBranch].empty()) { + // return errors::Internal("Unexpected empty frontier for switch nodes"); + // } return Status::OK(); } @@ -754,33 +765,191 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, return Status::OK(); } -std::vector +StatusOr> FunctionalizeCond::DeterminePredicateSwitchOrder() { + struct Cluster { + bool operator==(const Cluster& other) const { + return representative == other.representative; + } + int representative = -1; + }; + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Identify dead switches; + // * Initialize the cluster's representative; + std::vector> clusters(graph_->num_node_ids()); std::vector dead_switches; std::vector switch_order; - DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + std::vector rev_topo_sorted_nodes; + DFS(*graph_, nullptr, [&](Node* n) { + clusters[n->id()].Get().representative = n->id(); if (IsSwitch(n)) { if (IsDeadSwitch(n)) { dead_switches.push_back(n); } else { + rev_topo_sorted_nodes.push_back(n); switch_order.push_back(n); } + } else if (n->IsOp()) { + // Exclude src and sink nodes from further consideration. + rev_topo_sorted_nodes.push_back(n); } }); + std::vector switch_clusters; + // Return early if there are no switches in the graph. + if (switch_order.empty()) { + return switch_clusters; + } + // Remove all dead switch nodes. for (Node* n : dead_switches) { VLOG(2) << "Removing dead switch: " << n->DebugString(); graph_->RemoveNode(n); } - std::vector predicate_switch_order; - if (switch_order.empty()) { - return predicate_switch_order; + // Identify switch nodes that are part of the same control flow context by + // considering the operands of operations: an operation is part of the same + // control context as its operands unless the operation is a switch. Control + // dependencies are considered part of the same control flow context if the + // switch depth is the same (see comment below). + + // entry_cluster records the input cluster to a switch node. This is used when + // merging with a merge node where the dst's cluster is merged with the entry + // cluster of the merge node's cluster (which corresponds to a switch cluster + // and so has an entry cluster). + std::unordered_map*> entry_cluster; + + // Returns the output cluster of a node. Where the output cluster is cluster + // where the output of the node is used. For non-merge nodes this is simply + // the cluster they are part of, while for merge nodes it is the entry cluster + // of the cluster they are part of (this will correspond to the entry node of + // a switch node that dominates the merge). + auto find_output_cluster = [&](Node* n) { + UnionFind* cluster = &clusters[n->id()]; + if (!IsMerge(n)) return cluster; + auto it = entry_cluster.find(clusters[n->id()].Get().representative); + // If the cluster is not found in the entry_cluster map then an + // instruction not dominated by a switch node has been merged into the + // cluster of the merge. This indicates a failure of the clustering. + CHECK(it != entry_cluster.end()) + << "Unable to find entry for n=" << n->id() << " (" + << cluster->Get().representative << ")"; + return it->second; + }; + + // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. + std::vector switch_depth(graph_->num_node_ids()); + for (auto it = rev_topo_sorted_nodes.rbegin(); + it != rev_topo_sorted_nodes.rend(); ++it) { + Node* n = *it; + + // Compute switch depth. + int new_switch_depth = 0; + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + new_switch_depth = std::max( + new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); + } + switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); + + // Only merge the input operands of a switch. The switch's clustering itself + // is determined by the interaction of the switch's outputs. + if (IsSwitch(n)) { + Node* input; + TF_CHECK_OK(n->input_node(0, &input)); + entry_cluster[n->id()] = &clusters[input->id()]; + UnionFind* cluster = find_output_cluster(input); + int cluster_depth = switch_depth[cluster->Get().representative]; + // Merge the inputs of the switch node with one another. This results in + // predicates and control input residing in the same cluster. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + UnionFind* src_cluster = find_output_cluster(src); + int src_cluster_depth = switch_depth[src_cluster->Get().representative]; + if (cluster_depth != src_cluster_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Switch ('", + n->name(), "') has operands ('", input->name(), "' and '", + src->name(), "') that have different switch depths (", + cluster_depth, " != ", src_cluster_depth, ")"); + } + cluster->Merge(src_cluster); + } + continue; + } + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (!src->IsOp()) continue; + UnionFind* cluster = find_output_cluster(src); + // Merge a node with its data operands and with its control operands if + // the src and dst are in the same ControlContext. The ControlContext is + // not explicitly available here, and instead the switch depth is used as + // a proxy here. Due to the invariant that control edges can only be from + // a containing scope to an inner scope or from the inner scope to its + // containing scope (for exit nodes), the switch depth will only match if + // the src and dst are in the same ControlContext. Control edges between + // ControlContexts are handled during the extraction. + int src_id = cluster->Get().representative; + int src_depth = switch_depth[src_id]; + if (!e->IsControlEdge() || new_switch_depth == src_depth) { + if (src_depth != new_switch_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Operand ('", + src->name(), "') and operator ('", n->name(), + "') have different switch depths (", src_depth, + " != ", new_switch_depth, ")"); + } + cluster->Merge(&clusters[n->id()]); + } + } } + if (dump_graphs_) { + // Mark the switch cluster each node is part of. + for (Node* n : graph_->nodes()) { + n->ClearAttr("_XlaFunctionalizeSwitchGroup"); + n->AddAttr("_XlaFunctionalizeSwitchGroup", + clusters[n->id()].Get().representative); + } + LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " + << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, + library_); + } + + // Verify all the nodes of a cluster are at the same depth. + std::unordered_map> cluster_to_depth_node; + for (Node* n : graph_->nodes()) { + int depth = switch_depth[n->id()]; + int cluster_rep = clusters[n->id()].Get().representative; + auto it = cluster_to_depth_node.find(cluster_rep); + if (it == cluster_to_depth_node.end()) { + cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); + } else { + if (it->second.first != depth) { + return errors::Internal( + "Illegal clustering created, mismatch in depths:", "\n\t", + n->DebugString(), "(", clusters[n->id()].Get().representative, + ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), + "(", clusters[n->id()].Get().representative, ") at depth ", + it->second.first); + } + } + } + + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second.representative)); + } + }; + // Merge Switch nodes with common predicate. - std::unordered_map predicate_index; + std::unordered_map, int, Hash> predicate_index; // The nodes in switch_order are in reverse topological order, but the // clustered switches need not be (i.e., when considered as a cluster one // element of a cluster may be later in the topological order than another @@ -789,13 +958,19 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { Node* pred; TF_CHECK_OK((*it)->input_node(1, &pred)); - if (predicate_index.find(pred) == predicate_index.end()) { - predicate_index[pred] = predicate_switch_order.size(); - predicate_switch_order.emplace_back(pred); + auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); + if (predicate_index.find(repr) == predicate_index.end()) { + predicate_index[repr] = switch_clusters.size(); + switch_clusters.emplace_back(pred); + // Generate a name by concatenating with the cluster representative as + // there could be multiple switch clusters with the same predicate. + switch_clusters[predicate_index[repr]].name = + strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); } - predicate_switch_order[predicate_index[pred]].switches.push_back(*it); + switch_clusters[predicate_index[repr]].switches.push_back(*it); } - return predicate_switch_order; + + return switch_clusters; } StatusOr> @@ -843,10 +1018,10 @@ StatusOr< std::pair, std::unordered_set>> FunctionalizeCond::DetermineBranchMapAndFrontier( - const std::vector& switches) { + const SwitchCluster& switch_cluster) { std::unordered_map branch_map; std::unordered_set frontier; - std::vector stack = switches; + std::vector stack = switch_cluster.switches; std::vector visited(graph_->num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); @@ -888,7 +1063,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } } - if (VLOG_IS_ON(2)) { + if (dump_graphs_) { for (const auto& kv : branch_map) { // Append attribute to the graph if running with logging to make the // changes clearer in the visualization. @@ -900,8 +1075,8 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } Status FunctionalizeCond::FunctionalizeInternal() { - std::vector predicate_switch_order = - DeterminePredicateSwitchOrder(); + TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, + DeterminePredicateSwitchOrder()); // Iterate from innermost set of clustered switches to outermost, replacing // matching switch->merge subgraphs with single XlaIf nodes. @@ -914,10 +1089,12 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::unordered_map branch_map; std::unordered_set frontier; TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps.switches)); + DetermineBranchMapAndFrontier(ps)); - VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, + library_); TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); // Sort the merge and switch nodes using NodeCmp. The switch-nodes are @@ -934,7 +1111,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { input_index[in] = cond_arg_nodes.size(); cond_arg_nodes.emplace_back(in); } - cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node); + cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); } std::vector merge_nodes(frontier.begin(), frontier.end()); std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); @@ -943,9 +1120,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { EnsureDominanceAndReturnNonDominatedControlNodes( branch_map, ps.switches)); - TF_ASSIGN_OR_RETURN( - Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate)); + TF_ASSIGN_OR_RETURN(Node * if_node, + ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); for (Node* old : old_control_nodes) { graph_->AddControlEdge(old, if_node); } @@ -954,25 +1130,26 @@ Status FunctionalizeCond::FunctionalizeInternal() { graph_->RemoveNode(del_kv.first); } for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switch_nodes) { + for (Node* node : kv.switches) { graph_->RemoveNode(node); } } - VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, + library_); } return Status::OK(); } StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " - << NodesToString(switch_nodes); + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(2) << "Build if op for " << switch_cluster.name; NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); + NodeDefBuilder builder(switch_cluster.name, "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -982,12 +1159,9 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, + merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): " - << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", branch[i]), *body); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); @@ -999,7 +1173,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( DataTypeVector in_arg_types; for (auto& kv : cond_arg_nodes) { bool inserted = false; - for (const Node* arg : kv.switch_nodes) { + for (const Node* arg : kv.switches) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1026,10 +1200,11 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(predicate->assigned_device_name()); + builder.Device(switch_cluster.predicate->assigned_device_name()); // Conditional should be the first input ... builder.Input( - NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); + NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, + switch_cluster.predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1039,7 +1214,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( } Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body) { VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " @@ -1049,7 +1224,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, int arg_count = 0; for (auto& kv : cond_arg_nodes) { Node* arg_node = nullptr; - for (const auto* arg : kv.switch_nodes) { + for (const auto* arg : kv.switches) { DataType dtype = arg->input_type(0); if (arg_node == nullptr) { TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); @@ -1073,8 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, node_map.at(in->id()) = body->CopyNode(in); } - if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == - switch_nodes.end()) { + if (std::find(switches.begin(), switches.end(), in) == switches.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1096,7 +1270,7 @@ Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, graph_->AddEdge(predicate, 0, if_node, index++); for (auto& kv : cond_arg_nodes) { bool inserted = false; - for (const Node* arg : kv.switch_nodes) { + for (const Node* arg : kv.switches) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1139,16 +1313,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, } StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> " + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " << NodesToString(merge_nodes); // Extract bodies and builds a If operator. TF_ASSIGN_OR_RETURN( Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node)); + BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); + TF_RETURN_IF_ERROR( + AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return if_node; @@ -1157,18 +1332,19 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library); + FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); return fc.FunctionalizeInternal(); } } // namespace -// Transformation that converts Tensorflow's graph control flow constructs into +// Transformation that converts TensorFlow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + << dump_graph::DumpGraphToFile("functionalize_initial", *graph, + library); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. @@ -1180,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph, for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; - VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name << " frame: " << (cf.frame ? cf.frame->name() : "---") << " parent_frame: " << (cf.parent_frame ? cf.parent_frame->name() : "---"); @@ -1248,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph, TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph); + << dump_graph::DumpGraphToFile("functionalize_final", *graph, + library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 71f12a1..bc7276c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -38,10 +38,11 @@ namespace { // Returns the names of the "then" and "else" functions for the XlaIf node in a // graph. -Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, - NameAttrList* else_fn) { +Status FindIfThenAndElse(const GraphDef& graph, string* op_name, + NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "XlaIf") { + *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); *then_fn = *result; @@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) { GraphDef graph_def; graph.ToGraphDef(&graph_def); + string op_name; NameAttrList then_fn; NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); InstantiationResultForTest else_result; TF_EXPECT_OK( InstantiateFunctionForTest(else_fn.name(), library, &else_result)); @@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 1418d95..058a1f2 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -134,7 +134,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; - tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output()); } OpKernelContext op_context(¶ms, n->num_outputs());