Status FunctionalizeLoop(Graph* graph, Frame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
- << dump_graph::DumpGraphToFile("functionalize_before", *graph);
+ << dump_graph::DumpGraphToFile("functionalize_before", *graph,
+ library);
// Split loop-varying Enter nodes with multiple successors. If the same
// Tensor is fed as input to multiple loop arguments, we may end up with a
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
VLOG(2) << "Frame " << frame->name << " condition: "
- << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
+ << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
<< " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
static std::atomic<int64> sequence_num(0LL);
frame->parent->nodes.insert(while_node);
VLOG(2) << "Frame " << frame->name << " after: "
- << dump_graph::DumpGraphToFile("functionalize_after", *graph);
+ << dump_graph::DumpGraphToFile("functionalize_after", *graph,
+ library);
return Status::OK();
}
explicit CondArgNode(Node* input) : input(input) {}
string ToString() const {
return strings::StrCat("input=", input->name(),
- " switches=", NodesToString(switch_nodes));
+ " switches=", NodesToString(switches));
}
Node* input;
- std::vector<Node*> switch_nodes;
+ std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;
int count;
};
- struct PredicateSwitches {
- explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
+ // Group of switch nodes that will be part of the same XlaIf.
+ struct SwitchCluster {
+ explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
+ string ToString() const {
+ return strings::StrCat(name, " predicate=", predicate->name(),
+ " switches=", NodesToString(switches));
+ }
+ string name;
Node* predicate;
std::vector<Node*> switches;
};
- FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
- : library_(library), graph_(graph) {}
+ FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
+ bool dump_graphs)
+ : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
// Perform the actual cond functionalization. Iterate over groups of switch
// nodes (linked by common predicate), from innermost to outermost, and
// frontier (the nodes where the cond ends).
StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
std::unordered_set<Node*>>>
- DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
+ DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
// Returns XlaIf node created from subgraph of merge and switch nodes. This
// encapsulates the process of extracting the bodies needed for the then and
// else branch, creates a XlaIf node, removing the nodes of the branches from
// the graph and replacing the merge node with a XlaIf.
StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes,
- Node* predicate);
+ const SwitchCluster& switch_cluster,
+ const std::vector<Node*>& switches);
// Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes,
- Node* predicate);
+ const SwitchCluster& switch_cluster,
+ const std::vector<Node*>& merge_nodes);
// Extracts a function body corresponding to the given input edge of the merge
// node.
Status ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switch_nodes,
+ const std::vector<Node*>& switches,
const std::vector<Node*>& merge_nodes, int input_edge,
Graph* body);
// Adds all output edges from the `if_node`.
Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
- // Returns the switches of graph_ (along with grouping predicates) in
- // postorder. Dead switch nodes are skipped and removed from the graph.
- std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
+ // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
+ // skipped and removed from the graph.
+ StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
// Update the state for destination based on the state of source and the node
// being updated.
FunctionLibraryDefinition* library_;
Graph* graph_;
+ bool dump_graphs_;
};
bool IsDeadSwitch(const Node* node) {
") in both Else and Then branch should be in Both.");
}
}
- if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
- pending[kElseBranch].empty()) {
- return errors::Internal("Unexpected empty frontier for switch nodes");
- }
+ // An empty frontier indicates a dead switch. Above we attempt to remove dead
+ // switch nodes, but not all are removed so don't treat it as an error yet.
+ // TODO(jpienaar): Find out why dead switch nodes remain.
+ // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
+ // pending[kElseBranch].empty()) {
+ // return errors::Internal("Unexpected empty frontier for switch nodes");
+ // }
return Status::OK();
}
return Status::OK();
}
-std::vector<FunctionalizeCond::PredicateSwitches>
+StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
FunctionalizeCond::DeterminePredicateSwitchOrder() {
+ struct Cluster {
+ bool operator==(const Cluster& other) const {
+ return representative == other.representative;
+ }
+ int representative = -1;
+ };
+
+ // Perform a DFS over the graph and
+ // * Determine the reverse topological order of the nodes (there should be no
+ // cycles at this point so the post-order numbering corresponds to the
+ // reverse topological sorting);
+ // * Identify dead switches;
+ // * Initialize the cluster's representative;
+ std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
std::vector<Node*> dead_switches;
std::vector<Node*> switch_order;
- DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
+ std::vector<Node*> rev_topo_sorted_nodes;
+ DFS(*graph_, nullptr, [&](Node* n) {
+ clusters[n->id()].Get().representative = n->id();
if (IsSwitch(n)) {
if (IsDeadSwitch(n)) {
dead_switches.push_back(n);
} else {
+ rev_topo_sorted_nodes.push_back(n);
switch_order.push_back(n);
}
+ } else if (n->IsOp()) {
+ // Exclude src and sink nodes from further consideration.
+ rev_topo_sorted_nodes.push_back(n);
}
});
+ std::vector<SwitchCluster> switch_clusters;
+ // Return early if there are no switches in the graph.
+ if (switch_order.empty()) {
+ return switch_clusters;
+ }
+
// Remove all dead switch nodes.
for (Node* n : dead_switches) {
VLOG(2) << "Removing dead switch: " << n->DebugString();
graph_->RemoveNode(n);
}
- std::vector<PredicateSwitches> predicate_switch_order;
- if (switch_order.empty()) {
- return predicate_switch_order;
+ // Identify switch nodes that are part of the same control flow context by
+ // considering the operands of operations: an operation is part of the same
+ // control context as its operands unless the operation is a switch. Control
+ // dependencies are considered part of the same control flow context if the
+ // switch depth is the same (see comment below).
+
+ // entry_cluster records the input cluster to a switch node. This is used when
+ // merging with a merge node where the dst's cluster is merged with the entry
+ // cluster of the merge node's cluster (which corresponds to a switch cluster
+ // and so has an entry cluster).
+ std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
+
+ // Returns the output cluster of a node. Where the output cluster is cluster
+ // where the output of the node is used. For non-merge nodes this is simply
+ // the cluster they are part of, while for merge nodes it is the entry cluster
+ // of the cluster they are part of (this will correspond to the entry node of
+ // a switch node that dominates the merge).
+ auto find_output_cluster = [&](Node* n) {
+ UnionFind<Cluster>* cluster = &clusters[n->id()];
+ if (!IsMerge(n)) return cluster;
+ auto it = entry_cluster.find(clusters[n->id()].Get().representative);
+ // If the cluster is not found in the entry_cluster map then an
+ // instruction not dominated by a switch node has been merged into the
+ // cluster of the merge. This indicates a failure of the clustering.
+ CHECK(it != entry_cluster.end())
+ << "Unable to find entry for n=" << n->id() << " ("
+ << cluster->Get().representative << ")";
+ return it->second;
+ };
+
+ // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
+ std::vector<int> switch_depth(graph_->num_node_ids());
+ for (auto it = rev_topo_sorted_nodes.rbegin();
+ it != rev_topo_sorted_nodes.rend(); ++it) {
+ Node* n = *it;
+
+ // Compute switch depth.
+ int new_switch_depth = 0;
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ new_switch_depth = std::max(
+ new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
+ }
+ switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
+
+ // Only merge the input operands of a switch. The switch's clustering itself
+ // is determined by the interaction of the switch's outputs.
+ if (IsSwitch(n)) {
+ Node* input;
+ TF_CHECK_OK(n->input_node(0, &input));
+ entry_cluster[n->id()] = &clusters[input->id()];
+ UnionFind<Cluster>* cluster = find_output_cluster(input);
+ int cluster_depth = switch_depth[cluster->Get().representative];
+ // Merge the inputs of the switch node with one another. This results in
+ // predicates and control input residing in the same cluster.
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ UnionFind<Cluster>* src_cluster = find_output_cluster(src);
+ int src_cluster_depth = switch_depth[src_cluster->Get().representative];
+ if (cluster_depth != src_cluster_depth) {
+ return errors::InvalidArgument(
+ "Unable to functionalize control flow in graph: Switch ('",
+ n->name(), "') has operands ('", input->name(), "' and '",
+ src->name(), "') that have different switch depths (",
+ cluster_depth, " != ", src_cluster_depth, ")");
+ }
+ cluster->Merge(src_cluster);
+ }
+ continue;
+ }
+
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ if (!src->IsOp()) continue;
+ UnionFind<Cluster>* cluster = find_output_cluster(src);
+ // Merge a node with its data operands and with its control operands if
+ // the src and dst are in the same ControlContext. The ControlContext is
+ // not explicitly available here, and instead the switch depth is used as
+ // a proxy here. Due to the invariant that control edges can only be from
+ // a containing scope to an inner scope or from the inner scope to its
+ // containing scope (for exit nodes), the switch depth will only match if
+ // the src and dst are in the same ControlContext. Control edges between
+ // ControlContexts are handled during the extraction.
+ int src_id = cluster->Get().representative;
+ int src_depth = switch_depth[src_id];
+ if (!e->IsControlEdge() || new_switch_depth == src_depth) {
+ if (src_depth != new_switch_depth) {
+ return errors::InvalidArgument(
+ "Unable to functionalize control flow in graph: Operand ('",
+ src->name(), "') and operator ('", n->name(),
+ "') have different switch depths (", src_depth,
+ " != ", new_switch_depth, ")");
+ }
+ cluster->Merge(&clusters[n->id()]);
+ }
+ }
}
+ if (dump_graphs_) {
+ // Mark the switch cluster each node is part of.
+ for (Node* n : graph_->nodes()) {
+ n->ClearAttr("_XlaFunctionalizeSwitchGroup");
+ n->AddAttr("_XlaFunctionalizeSwitchGroup",
+ clusters[n->id()].Get().representative);
+ }
+ LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
+ << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
+ library_);
+ }
+
+ // Verify all the nodes of a cluster are at the same depth.
+ std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
+ for (Node* n : graph_->nodes()) {
+ int depth = switch_depth[n->id()];
+ int cluster_rep = clusters[n->id()].Get().representative;
+ auto it = cluster_to_depth_node.find(cluster_rep);
+ if (it == cluster_to_depth_node.end()) {
+ cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
+ } else {
+ if (it->second.first != depth) {
+ return errors::Internal(
+ "Illegal clustering created, mismatch in depths:", "\n\t",
+ n->DebugString(), "(", clusters[n->id()].Get().representative,
+ ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
+ "(", clusters[n->id()].Get().representative, ") at depth ",
+ it->second.first);
+ }
+ }
+ }
+
+ struct Hash {
+ size_t operator()(const std::pair<Node*, Cluster>& item) const {
+ return Hash64Combine(hash<Node*>()(item.first),
+ std::hash<int>()(item.second.representative));
+ }
+ };
+
// Merge Switch nodes with common predicate.
- std::unordered_map<Node*, int> predicate_index;
+ std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
// The nodes in switch_order are in reverse topological order, but the
// clustered switches need not be (i.e., when considered as a cluster one
// element of a cluster may be later in the topological order than another
for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
Node* pred;
TF_CHECK_OK((*it)->input_node(1, &pred));
- if (predicate_index.find(pred) == predicate_index.end()) {
- predicate_index[pred] = predicate_switch_order.size();
- predicate_switch_order.emplace_back(pred);
+ auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
+ if (predicate_index.find(repr) == predicate_index.end()) {
+ predicate_index[repr] = switch_clusters.size();
+ switch_clusters.emplace_back(pred);
+ // Generate a name by concatenating with the cluster representative as
+ // there could be multiple switch clusters with the same predicate.
+ switch_clusters[predicate_index[repr]].name =
+ strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
}
- predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
+ switch_clusters[predicate_index[repr]].switches.push_back(*it);
}
- return predicate_switch_order;
+
+ return switch_clusters;
}
StatusOr<std::vector<Node*>>
std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
std::unordered_set<Node*>>>
FunctionalizeCond::DetermineBranchMapAndFrontier(
- const std::vector<Node*>& switches) {
+ const SwitchCluster& switch_cluster) {
std::unordered_map<Node*, ForwardFlowNode> branch_map;
std::unordered_set<Node*> frontier;
- std::vector<Node*> stack = switches;
+ std::vector<Node*> stack = switch_cluster.switches;
std::vector<bool> visited(graph_->num_node_ids(), false);
while (!stack.empty()) {
Node* n = stack.back();
}
}
- if (VLOG_IS_ON(2)) {
+ if (dump_graphs_) {
for (const auto& kv : branch_map) {
// Append attribute to the graph if running with logging to make the
// changes clearer in the visualization.
}
Status FunctionalizeCond::FunctionalizeInternal() {
- std::vector<PredicateSwitches> predicate_switch_order =
- DeterminePredicateSwitchOrder();
+ TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
+ DeterminePredicateSwitchOrder());
// Iterate from innermost set of clustered switches to outermost, replacing
// matching switch->merge subgraphs with single XlaIf nodes.
std::unordered_map<Node*, ForwardFlowNode> branch_map;
std::unordered_set<Node*> frontier;
TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
- DetermineBranchMapAndFrontier(ps.switches));
+ DetermineBranchMapAndFrontier(ps));
- VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
+ if (dump_graphs_)
+ LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
+ << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
+ library_);
TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
// Sort the merge and switch nodes using NodeCmp. The switch-nodes are
input_index[in] = cond_arg_nodes.size();
cond_arg_nodes.emplace_back(in);
}
- cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node);
+ cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
}
std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
EnsureDominanceAndReturnNonDominatedControlNodes(
branch_map, ps.switches));
- TF_ASSIGN_OR_RETURN(
- Node * if_node,
- ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate));
+ TF_ASSIGN_OR_RETURN(Node * if_node,
+ ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
for (Node* old : old_control_nodes) {
graph_->AddControlEdge(old, if_node);
}
graph_->RemoveNode(del_kv.first);
}
for (auto& kv : cond_arg_nodes) {
- for (Node* node : kv.switch_nodes) {
+ for (Node* node : kv.switches) {
graph_->RemoveNode(node);
}
}
- VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_ac", *graph_);
+ if (dump_graphs_)
+ LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
+ << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
+ library_);
}
return Status::OK();
}
StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
- const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes, Node* predicate) {
- VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
- << NodesToString(switch_nodes);
+ const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
+ const std::vector<Node*>& merge_nodes) {
+ VLOG(2) << "Build if op for " << switch_cluster.name;
NodeDef if_def;
// Create a new If node using the name of the merge node.
- NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf");
+ NodeDefBuilder builder(switch_cluster.name, "XlaIf");
string branch[] = {"else_branch", "then_branch"};
for (int i = 0; i < 2; ++i) {
static std::atomic<int64> sequence_num(0LL);
body_name.set_name(
strings::StrCat("_functionalize_if_", branch[i], "_", id));
auto body = xla::MakeUnique<Graph>(graph_->op_registry());
- TF_RETURN_IF_ERROR(
- ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get()));
+ TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
+ merge_nodes, i, body.get()));
VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
- VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): "
- << dump_graph::DumpGraphToFile(
- strings::StrCat("functionalize_", branch[i]), *body);
FunctionDef body_fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
DataTypeVector in_arg_types;
for (auto& kv : cond_arg_nodes) {
bool inserted = false;
- for (const Node* arg : kv.switch_nodes) {
+ for (const Node* arg : kv.switches) {
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
builder.Attr("Tout", out_type);
builder.Attr("Tcond", DT_BOOL);
- builder.Device(predicate->assigned_device_name());
+ builder.Device(switch_cluster.predicate->assigned_device_name());
// Conditional should be the first input ...
builder.Input(
- NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0)));
+ NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
+ switch_cluster.predicate->output_type(0)));
// ... followed by the other inputs.
builder.Input(inputs);
}
Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switch_nodes,
+ const std::vector<Node*>& switches,
const std::vector<Node*>& merge_nodes,
int input_edge, Graph* body) {
VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
int arg_count = 0;
for (auto& kv : cond_arg_nodes) {
Node* arg_node = nullptr;
- for (const auto* arg : kv.switch_nodes) {
+ for (const auto* arg : kv.switches) {
DataType dtype = arg->input_type(0);
if (arg_node == nullptr) {
TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
node_map.at(in->id()) = body->CopyNode(in);
}
- if (std::find(switch_nodes.begin(), switch_nodes.end(), in) ==
- switch_nodes.end()) {
+ if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
node_map.at(node->id()), 0);
} else {
graph_->AddEdge(predicate, 0, if_node, index++);
for (auto& kv : cond_arg_nodes) {
bool inserted = false;
- for (const Node* arg : kv.switch_nodes) {
+ for (const Node* arg : kv.switches) {
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
}
StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
- const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes, Node* predicate) {
- VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
+ const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
+ const std::vector<Node*>& merge_nodes) {
+ VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
<< NodesToString(merge_nodes);
// Extract bodies and builds a If operator.
TF_ASSIGN_OR_RETURN(
Node * if_node,
- BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate));
- TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node));
+ BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
+ TF_RETURN_IF_ERROR(
+ AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
return if_node;
Status FunctionalizeCond::Functionalize(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(1) << "FunctionalizeCond::Functionalize";
- FunctionalizeCond fc(graph, library);
+ FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
return fc.FunctionalizeInternal();
}
} // namespace
-// Transformation that converts Tensorflow's graph control flow constructs into
+// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
- << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
+ << dump_graph::DumpGraphToFile("functionalize_initial", *graph,
+ library);
// Note: BuildControlFlowInfo() requires that the graph's source node is
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
- VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
+ VLOG(2) << "node: " << node->name() << " (" << node->id()
+ << ") frame_name: " << cf.frame_name
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
<< " parent_frame: "
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
VLOG(2) << "FunctionalizeControlFlow (final): "
- << dump_graph::DumpGraphToFile("functionalize_final", *graph);
+ << dump_graph::DumpGraphToFile("functionalize_final", *graph,
+ library);
return Status::OK();
}