#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include <algorithm>
+#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
namespace grappler {
namespace {
-std::vector<int> GetStackPushNodesToConvert(
- const SimpleGraphView& graph_view,
- const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
- VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
- const std::unordered_set<string> op_types_to_traverse(
- {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
- "Identity", "RefIdentity"});
- std::vector<int> nodes_to_convert;
- std::set<int> fanout;
- graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
- for (int fanout_idx : fanout) {
- const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
- VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
- if (IsStackPushOp(fanout_node)) {
- nodes_to_convert.push_back(fanout_idx);
- } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
- op_types_to_traverse.find(fanout_node.op()) !=
- op_types_to_traverse.end()) {
- continue;
- } else if (!IsStackPopOp(fanout_node) ||
- (!graph_view.outputs(fanout_idx).empty() ||
- nodes_to_preserve.find(fanout_node.name()) !=
- nodes_to_preserve.end())) {
- // The node is either a stack pop with consumers or something unexpected
- // so we leave the graph alone.
- nodes_to_convert.clear();
- break;
- }
- }
- return nodes_to_convert;
-}
+// Optimizer stage performing loop-invariant node motion (LINM): nodes inside
+// a while-loop frame whose values do not change across iterations are moved
+// out of the frame so that they execute only once.
+class LoopInvariantNodeMotionOptimizer {
+ public:
+  explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
+      : optimized_graph_(optimized_graph) {}
+  virtual ~LoopInvariantNodeMotionOptimizer() = default;
+  // Identifies loop frames in optimized_graph_ and hoists invariant nodes
+  // in place.
+  Status Optimize();
-Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
-  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
-  const GraphDef& graph = item.graph;
-  *optimized_graph = graph;
-  NodeMap node_map(optimized_graph);
-  SimpleGraphView graph_view;
-  TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
-  for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
-    if (IsStackOp(graph.node(node_idx))) {
-      for (int push_node_idx : GetStackPushNodesToConvert(
-               graph_view, nodes_to_preserve, node_idx)) {
-        // We found push nodes without corresponding pops. Convert them to
-        // Identity passing the data through and add a control dependency from
-        // the op supplying the stack handle.
-        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
-        VLOG(1) << "Converting " << push_node_idx << " : "
-                << push_node->DebugString();
-        if (push_node->attr().count("swap_memory") != 0) {
-          push_node->mutable_attr()->erase("swap_memory");
-        }
-        push_node->set_op("Identity");
-        push_node->mutable_input()->SwapElements(0, 1);
-        const string ctrl_dep = ConstantFolding::AddControlDependency(
-            push_node->input(1), optimized_graph, &node_map);
-        push_node->set_input(1, ctrl_dep);
-        VLOG(1) << "After converting: " << push_node->DebugString();
-      }
-    }
-  }
-  return Status::OK();
-}
+ private:
+  // Starting from `node`, collects its consumers that are invariant with
+  // respect to the enclosing frame into invariant_nodes_.
+  Status FindInvariantNodes(NodeDef* node);
+  // Removes from invariant_nodes_ candidates that turn out not to be movable.
+  Status RevertInvariantNodes();
+  // Moves every collected invariant node out of frame `frame_id`.
+  Status MoveInvariantNodes(const int frame_id);
+  // Handlers for the three categories of invariant nodes encountered while
+  // moving them: generic ops, Const nodes, and Enter nodes.
+  Status HandleInvariantNode(NodeDef* node, const int num_outputs,
+                             const int frame_id);
+  Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
+  Status HandleInvariantEnter(NodeDef* node, const int num_outputs);
-}  // namespace
+  GraphDef* optimized_graph_;  // Not owned.
+  std::unique_ptr<NodeMap> node_map_;
+  // Candidate invariant nodes mapped to their consumer counts.
+  std::map<NodeDef*, int> invariant_nodes_;
+  std::set<int> empty_set_;
+  // TODO(rmlarsen): Use vector instead of map, since frames ids are dense.
+  std::map<int, std::set<int>> frame_children_;
+  std::map<int, int> frame_parent_;
+  // Frame id -> the LoopCond node governing that frame.
+  std::map<int, const NodeDef*> loop_cond_;
+  // Frame id -> the invariant Enter nodes feeding that frame.
+  std::map<int, std::vector<NodeDef*>> invariant_enters_;
+  int new_enter_id_;
+};
-Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
- const int num_outputs) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
+ NodeDef* node, const int num_outputs) {
auto consumers = node_map_->GetOutputs(node->name());
std::vector<string> enter_control_inputs;
string enter_input;
return Status::OK();
}
-Status LoopOptimizer::LINMHandleConst(NodeDef* node,
- const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
+ const int num_outputs,
+ const int frame_id) {
NodeDef* const_node;
if (num_outputs == 0) {
// all successor nodes are invariant
int parent_id = parent_it->second;
auto loop_cond_it = loop_cond_.find(parent_id);
if (loop_cond_it == loop_cond_.end()) {
- return errors::InvalidArgument(
- "Frame ", frame_id, " doesn't have a LoopCond node");
+ return errors::InvalidArgument("Frame ", frame_id,
+ " doesn't have a LoopCond node");
}
auto& loop_cond_name = loop_cond_it->second->name();
NodeDef* switch_node = nullptr;
}
}
if (!switch_node) {
- return errors::InvalidArgument(
- "LoopCond node of Frame ", frame_id,
- " doesn't connect to any Switch node");
+ return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
+ " doesn't connect to any Switch node");
}
string switch_output = StrCat(switch_node->name(), ":1");
const string ctrl_dep = ConstantFolding::AddControlDependency(
return Status::OK();
}
-Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
- const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
+ NodeDef* node, const int num_outputs, const int frame_id) {
// have to remove control inputs to the invariant node from the same frame
// when moving this node out of this frame
for (int i = 0; i < node->input_size(); ++i) {
DataTypeVector output_types;
OpRegistryInterface* op_registry = OpRegistry::Global();
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(
- op_registry->LookUp(node->op(), &op_reg_data));
- TF_RETURN_IF_ERROR(
- InOutTypesForNode(*node, op_reg_data->op_def,
- &input_types, &output_types));
+ TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
+ &output_types));
auto consumers = node_map_->GetOutputs(node->name());
string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
- int piterations = invariant_enters_[frame_id][0]
- ->attr().at("parallel_iterations").i();
+ int piterations =
+ invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
for (auto* consumer : consumers) {
if (!invariant_nodes_.count(consumer)) {
for (int i = 0; i < consumer->input_size(); ++i) {
return Status::OK();
}
-Status LoopOptimizer::MoveInvariantNodes(const int frame_id) {
-  for (auto iter = invariant_nodes_.begin();
-       iter != invariant_nodes_.end(); ++iter) {
+// Dispatches every node collected in invariant_nodes_ to the handler for its
+// category (Enter, Const, or generic op) to hoist it out of frame `frame_id`.
+Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
+    const int frame_id) {
+  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
+       ++iter) {
     auto* invariant_node = iter->first;
     const int num_outputs = iter->second;
     if (IsEnter(*invariant_node)) {
-      TF_RETURN_IF_ERROR(
-          LINMHandleInvariantEnter(invariant_node, num_outputs));
+      TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
     } else if (IsConstant(*invariant_node)) {
-      TF_RETURN_IF_ERROR(
-          LINMHandleConst(invariant_node, num_outputs, frame_id));
+      TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
     } else {
       TF_RETURN_IF_ERROR(
-          LINMHandleInvariantNode(invariant_node, num_outputs, frame_id));
+          HandleInvariantNode(invariant_node, num_outputs, frame_id));
     }
   }
   return Status::OK();
 }
-Status LoopOptimizer::RevertInvariantNodes() {
+Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
std::deque<const NodeDef*> reverted_nodes;
- for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
+ for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
bool erased = false;
const auto* node = iter->first;
if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
auto* producer = node_map_->GetNode(input);
auto iter = invariant_nodes_.find(producer);
if (iter != invariant_nodes_.end()) {
- if (IsControlInput(input) &&
- !IsConstant(*producer) && !IsEnter(*producer)) {
+ if (IsControlInput(input) && !IsConstant(*producer) &&
+ !IsEnter(*producer)) {
reverted_nodes.push_back(producer);
invariant_nodes_.erase(iter);
} else {
return Status::OK();
}
-Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
+Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) {
auto consumers = node_map_->GetOutputs(node->name());
invariant_nodes_.insert(std::make_pair(node, consumers.size()));
for (auto* consumer : consumers) {
- if (invariant_nodes_.count(consumer) ||
- ModifiesFrameInfo(*consumer)) {
+ if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
continue;
}
bool is_invariant = true;
return Status::OK();
}
-Status LoopOptimizer::LoopInvariantNodeMotion() {
+Status LoopInvariantNodeMotionOptimizer::Optimize() {
+ node_map_.reset(new NodeMap(optimized_graph_));
+ FrameMap frame_map;
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
+ &frame_map, &num_frames));
std::deque<int> worklist;
- for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) {
+ for (auto iter = frame_map.begin(); iter != frame_map.end(); ++iter) {
auto* node = iter->first;
auto& frame_ids = iter->second;
if (frame_ids.size() >= 3) {
return Status::OK();
}
-Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+// Returns the indices of StackPush nodes in the fanout of the Stack node at
+// `stack_node_idx` whose pushed values are never observed: the DFS found no
+// StackPop with outputs and no preserved or unrecognized node. Returns an
+// empty vector (meaning: leave the graph alone) otherwise.
+std::vector<int> GetStackPushNodesToConvert(
+    const SimpleGraphView& graph_view,
+    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
+  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
+  // Ops the DFS traverses through when following the stack handle's fanout.
+  const std::unordered_set<string> op_types_to_traverse(
+      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
+       "Identity", "RefIdentity"});
+  std::vector<int> nodes_to_convert;
+  std::set<int> fanout;
+  graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
+  for (int fanout_idx : fanout) {
+    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
+    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
+    if (IsStackPushOp(fanout_node)) {
+      nodes_to_convert.push_back(fanout_idx);
+    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
+               op_types_to_traverse.find(fanout_node.op()) !=
+                   op_types_to_traverse.end()) {
+      continue;
+    } else if (!IsStackPopOp(fanout_node) ||
+               (!graph_view.outputs(fanout_idx).empty() ||
+                nodes_to_preserve.find(fanout_node.name()) !=
+                    nodes_to_preserve.end())) {
+      // The node is either a stack pop with consumers or something unexpected
+      // so we leave the graph alone.
+      nodes_to_convert.clear();
+      break;
+    }
+  }
+  return nodes_to_convert;
+}
+
+// Copies item.graph into *optimized_graph, then rewrites StackPush nodes
+// whose values are never popped into Identity nodes that pass the data
+// through, with a control dependency on the op supplying the stack handle.
+Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
+  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
+  const GraphDef& graph = item.graph;
+  *optimized_graph = graph;
+  NodeMap node_map(optimized_graph);
+  SimpleGraphView graph_view;
+  TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
+  for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
+    if (IsStackOp(graph.node(node_idx))) {
+      for (int push_node_idx : GetStackPushNodesToConvert(
+               graph_view, nodes_to_preserve, node_idx)) {
+        // We found push nodes without corresponding pops. Convert them to
+        // Identity passing the data through and add a control dependency from
+        // the op supplying the stack handle.
+        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
+        VLOG(1) << "Converting " << push_node_idx << " : "
+                << push_node->DebugString();
+        // Drop the StackPush-specific swap_memory attr before the op change.
+        if (push_node->attr().count("swap_memory") != 0) {
+          push_node->mutable_attr()->erase("swap_memory");
+        }
+        push_node->set_op("Identity");
+        // After the swap, input 0 is the data tensor and input 1 is the stack
+        // handle, which is then demoted to a control dependency.
+        push_node->mutable_input()->SwapElements(0, 1);
+        const string ctrl_dep = ConstantFolding::AddControlDependency(
+            push_node->input(1), optimized_graph, &node_map);
+        push_node->set_input(1, ctrl_dep);
+        VLOG(1) << "After converting: " << push_node->DebugString();
+      }
+    }
+  }
+  return Status::OK();
+}
- TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
+} // namespace
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- optimized_graph_ = optimized_graph;
- // Set up helper data structures.
- node_map_.reset(new NodeMap(optimized_graph_));
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
- TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ // Set up helper data structures.
+ if (options_.enable_loop_invariant_node_motion) {
+ LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
+ TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
+ }
+ if (options_.enable_stack_push_removal) {
+ TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
}
return Status::OK();
namespace tensorflow {
namespace grappler {
-namespace {
class LoopOptimizerTest : public GrapplerTest {
protected:
attributes.emplace_back("T", type);
AddNode(name, op, inputs, attributes, graph);
}
+
+ void DisableAllStages(LoopOptimizer* optimizer) {
+ LoopOptimizer::LoopOptimizerOptions options;
+ options.enable_loop_invariant_node_motion = false;
+ options.enable_stack_push_removal = false;
+ optimizer->options_ = options;
+ }
+
+ void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_loop_invariant_node_motion = true;
+ }
+
+ void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_stack_push_removal = true;
+ }
};
TEST_F(LoopOptimizerTest, Basic) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
CHECK(fake_input.NextItem(&item));
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
item.fetch.push_back("pop4");
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
}
}
-} // namespace
} // namespace grappler
} // namespace tensorflow