Refactor LoopOptimizer:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 23:00:41 +0000 (16:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 23:03:12 +0000 (16:03 -0700)
  * Put loop-invariant node motion in its own class.
  * Add granular control of which passes to run.
Swap order of LINM and stack push removal.

PiperOrigin-RevId: 191953537

tensorflow/core/grappler/optimizers/loop_optimizer.cc
tensorflow/core/grappler/optimizers/loop_optimizer.h
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc

index a063dc3..28ce2c7 100644 (file)
@@ -16,18 +16,17 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
 
 #include <algorithm>
+#include <deque>
 #include <limits>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <deque>
 
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -46,74 +45,36 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
-std::vector<int> GetStackPushNodesToConvert(
-    const SimpleGraphView& graph_view,
-    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
-  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
-  const std::unordered_set<string> op_types_to_traverse(
-      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
-       "Identity", "RefIdentity"});
-  std::vector<int> nodes_to_convert;
-  std::set<int> fanout;
-  graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
-  for (int fanout_idx : fanout) {
-    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
-    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
-    if (IsStackPushOp(fanout_node)) {
-      nodes_to_convert.push_back(fanout_idx);
-    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
-               op_types_to_traverse.find(fanout_node.op()) !=
-                   op_types_to_traverse.end()) {
-      continue;
-    } else if (!IsStackPopOp(fanout_node) ||
-               (!graph_view.outputs(fanout_idx).empty() ||
-                nodes_to_preserve.find(fanout_node.name()) !=
-                    nodes_to_preserve.end())) {
-      // The node is either a stack pop with consumers or something unexpected
-      // so we leave the graph alone.
-      nodes_to_convert.clear();
-      break;
-    }
-  }
-  return nodes_to_convert;
-}
+class LoopInvariantNodeMotionOptimizer {
+ public:
+  explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
+      : optimized_graph_(optimized_graph) {}
+  virtual ~LoopInvariantNodeMotionOptimizer() = default;
+  Status Optimize();
 
-Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
-  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
-  const GraphDef& graph = item.graph;
-  *optimized_graph = graph;
-  NodeMap node_map(optimized_graph);
-  SimpleGraphView graph_view;
-  TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
-  for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
-    if (IsStackOp(graph.node(node_idx))) {
-      for (int push_node_idx : GetStackPushNodesToConvert(
-               graph_view, nodes_to_preserve, node_idx)) {
-        // We found push nodes without corresponding pops. Convert them to
-        // Identity passing the data through and add a control dependency from
-        // the op supplying the stack handle.
-        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
-        VLOG(1) << "Converting " << push_node_idx << " : "
-                << push_node->DebugString();
-        if (push_node->attr().count("swap_memory") != 0) {
-          push_node->mutable_attr()->erase("swap_memory");
-        }
-        push_node->set_op("Identity");
-        push_node->mutable_input()->SwapElements(0, 1);
-        const string ctrl_dep = ConstantFolding::AddControlDependency(
-            push_node->input(1), optimized_graph, &node_map);
-        push_node->set_input(1, ctrl_dep);
-        VLOG(1) << "After converting: " << push_node->DebugString();
-      }
-    }
-  }
-  return Status::OK();
-}
+ private:
+  Status FindInvariantNodes(NodeDef* node);
+  Status RevertInvariantNodes();
+  Status MoveInvariantNodes(const int frame_id);
+  Status HandleInvariantNode(NodeDef* node, const int num_outputs,
+                             const int frame_id);
+  Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
+  Status HandleInvariantEnter(NodeDef* node, const int num_outputs);
 
-}  // namespace
+  GraphDef* optimized_graph_;  // Not owned.
+  std::unique_ptr<NodeMap> node_map_;
+  std::map<NodeDef*, int> invariant_nodes_;
+  std::set<int> empty_set_;
+  // TODO(rmlarsen): Use vector instead of map, since frames ids are dense.
+  std::map<int, std::set<int>> frame_children_;
+  std::map<int, int> frame_parent_;
+  std::map<int, const NodeDef*> loop_cond_;
+  std::map<int, std::vector<NodeDef*>> invariant_enters_;
+  int new_enter_id_;
+};
 
-Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
-                                               const int num_outputs) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
+    NodeDef* node, const int num_outputs) {
   auto consumers = node_map_->GetOutputs(node->name());
   std::vector<string> enter_control_inputs;
   string enter_input;
@@ -142,8 +103,9 @@ Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
   return Status::OK();
 }
 
-Status LoopOptimizer::LINMHandleConst(NodeDef* node,
-    const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
+                                                     const int num_outputs,
+                                                     const int frame_id) {
   NodeDef* const_node;
   if (num_outputs == 0) {
     // all successor nodes are invariant
@@ -185,8 +147,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
     int parent_id = parent_it->second;
     auto loop_cond_it = loop_cond_.find(parent_id);
     if (loop_cond_it == loop_cond_.end()) {
-      return errors::InvalidArgument(
-          "Frame ", frame_id, " doesn't have a LoopCond node");
+      return errors::InvalidArgument("Frame ", frame_id,
+                                     " doesn't have a LoopCond node");
     }
     auto& loop_cond_name = loop_cond_it->second->name();
     NodeDef* switch_node = nullptr;
@@ -197,9 +159,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
       }
     }
     if (!switch_node) {
-      return errors::InvalidArgument(
-          "LoopCond node of Frame ", frame_id,
-          " doesn't connect to any Switch node");
+      return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
+                                     " doesn't connect to any Switch node");
     }
     string switch_output = StrCat(switch_node->name(), ":1");
     const string ctrl_dep = ConstantFolding::AddControlDependency(
@@ -210,8 +171,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
   return Status::OK();
 }
 
-Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
-    const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
+    NodeDef* node, const int num_outputs, const int frame_id) {
   // have to remove control inputs to the invariant node from the same frame
   // when moving this node out of this frame
   for (int i = 0; i < node->input_size(); ++i) {
@@ -228,16 +189,14 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
   DataTypeVector output_types;
   OpRegistryInterface* op_registry = OpRegistry::Global();
   const OpRegistrationData* op_reg_data = nullptr;
-  TF_RETURN_IF_ERROR(
-      op_registry->LookUp(node->op(), &op_reg_data));
-  TF_RETURN_IF_ERROR(
-      InOutTypesForNode(*node, op_reg_data->op_def,
-                        &input_types, &output_types));
+  TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
+  TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
+                                       &output_types));
 
   auto consumers = node_map_->GetOutputs(node->name());
   string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
-  int piterations = invariant_enters_[frame_id][0]
-                    ->attr().at("parallel_iterations").i();
+  int piterations =
+      invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
   for (auto* consumer : consumers) {
     if (!invariant_nodes_.count(consumer)) {
       for (int i = 0; i < consumer->input_size(); ++i) {
@@ -281,28 +240,27 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
   return Status::OK();
 }
 
-Status LoopOptimizer::MoveInvariantNodes(const int frame_id) {
-  for (auto iter = invariant_nodes_.begin();
-       iter != invariant_nodes_.end(); ++iter) {
+Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
+    const int frame_id) {
+  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
+       ++iter) {
     auto* invariant_node = iter->first;
     const int num_outputs = iter->second;
     if (IsEnter(*invariant_node)) {
-      TF_RETURN_IF_ERROR(
-          LINMHandleInvariantEnter(invariant_node, num_outputs));
+      TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
     } else if (IsConstant(*invariant_node)) {
-      TF_RETURN_IF_ERROR(
-          LINMHandleConst(invariant_node, num_outputs, frame_id));
+      TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
     } else {
       TF_RETURN_IF_ERROR(
-          LINMHandleInvariantNode(invariant_node, num_outputs, frame_id));
+          HandleInvariantNode(invariant_node, num_outputs, frame_id));
     }
   }
   return Status::OK();
 }
 
-Status LoopOptimizer::RevertInvariantNodes() {
+Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
   std::deque<const NodeDef*> reverted_nodes;
-  for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
+  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
     bool erased = false;
     const auto* node = iter->first;
     if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
@@ -331,8 +289,8 @@ Status LoopOptimizer::RevertInvariantNodes() {
       auto* producer = node_map_->GetNode(input);
       auto iter = invariant_nodes_.find(producer);
       if (iter != invariant_nodes_.end()) {
-        if (IsControlInput(input) &&
-            !IsConstant(*producer) && !IsEnter(*producer)) {
+        if (IsControlInput(input) && !IsConstant(*producer) &&
+            !IsEnter(*producer)) {
           reverted_nodes.push_back(producer);
           invariant_nodes_.erase(iter);
         } else {
@@ -357,12 +315,11 @@ Status LoopOptimizer::RevertInvariantNodes() {
   return Status::OK();
 }
 
-Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
+Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) {
   auto consumers = node_map_->GetOutputs(node->name());
   invariant_nodes_.insert(std::make_pair(node, consumers.size()));
   for (auto* consumer : consumers) {
-    if (invariant_nodes_.count(consumer) ||
-        ModifiesFrameInfo(*consumer)) {
+    if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
       continue;
     }
     bool is_invariant = true;
@@ -399,9 +356,14 @@ Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
   return Status::OK();
 }
 
-Status LoopOptimizer::LoopInvariantNodeMotion() {
+Status LoopInvariantNodeMotionOptimizer::Optimize() {
+  node_map_.reset(new NodeMap(optimized_graph_));
+  FrameMap frame_map;
+  int num_frames;
+  TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
+                                               &frame_map, &num_frames));
   std::deque<int> worklist;
-  for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) {
+  for (auto iter = frame_map.begin(); iter != frame_map.end(); ++iter) {
     auto* node = iter->first;
     auto& frame_ids = iter->second;
     if (frame_ids.size() >= 3) {
@@ -467,19 +429,82 @@ Status LoopOptimizer::LoopInvariantNodeMotion() {
   return Status::OK();
 }
 
-Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
-                               GraphDef* optimized_graph) {
+std::vector<int> GetStackPushNodesToConvert(
+    const SimpleGraphView& graph_view,
+    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
+  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
+  const std::unordered_set<string> op_types_to_traverse(
+      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
+       "Identity", "RefIdentity"});
+  std::vector<int> nodes_to_convert;
+  std::set<int> fanout;
+  graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
+  for (int fanout_idx : fanout) {
+    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
+    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
+    if (IsStackPushOp(fanout_node)) {
+      nodes_to_convert.push_back(fanout_idx);
+    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
+               op_types_to_traverse.find(fanout_node.op()) !=
+                   op_types_to_traverse.end()) {
+      continue;
+    } else if (!IsStackPopOp(fanout_node) ||
+               (!graph_view.outputs(fanout_idx).empty() ||
+                nodes_to_preserve.find(fanout_node.name()) !=
+                    nodes_to_preserve.end())) {
+      // The node is either a stack pop with consumers or something unexpected
+      // so we leave the graph alone.
+      nodes_to_convert.clear();
+      break;
+    }
+  }
+  return nodes_to_convert;
+}
+
+Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
+  const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
+  const GraphDef& graph = item.graph;
+  *optimized_graph = graph;
+  NodeMap node_map(optimized_graph);
+  SimpleGraphView graph_view;
+  TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
+  for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
+    if (IsStackOp(graph.node(node_idx))) {
+      for (int push_node_idx : GetStackPushNodesToConvert(
+               graph_view, nodes_to_preserve, node_idx)) {
+        // We found push nodes without corresponding pops. Convert them to
+        // Identity passing the data through and add a control dependency from
+        // the op supplying the stack handle.
+        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
+        VLOG(1) << "Converting " << push_node_idx << " : "
+                << push_node->DebugString();
+        if (push_node->attr().count("swap_memory") != 0) {
+          push_node->mutable_attr()->erase("swap_memory");
+        }
+        push_node->set_op("Identity");
+        push_node->mutable_input()->SwapElements(0, 1);
+        const string ctrl_dep = ConstantFolding::AddControlDependency(
+            push_node->input(1), optimized_graph, &node_map);
+        push_node->set_input(1, ctrl_dep);
+        VLOG(1) << "After converting: " << push_node->DebugString();
+      }
+    }
+  }
+  return Status::OK();
+}
 
-  TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
+}  // namespace
 
-  if (opt_level_ == RewriterConfig::AGGRESSIVE) {
-    optimized_graph_ = optimized_graph;
-    // Set up helper data structures.
-    node_map_.reset(new NodeMap(optimized_graph_));
-    int num_frames;
-    TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
-                                                 &frame_map_, &num_frames));
-    TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+                               GraphDef* optimized_graph) {
+  *optimized_graph = item.graph;
+  // Set up helper data structures.
+  if (options_.enable_loop_invariant_node_motion) {
+    LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
+    TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
+  }
+  if (options_.enable_stack_push_removal) {
+    TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
   }
 
   return Status::OK();
index c1b0321..83c499b 100644 (file)
@@ -30,9 +30,13 @@ constexpr char kLoopOptimizer[] = "LoopOptimizer";
 
 class LoopOptimizer : public GraphOptimizer {
  public:
-  LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
+  LoopOptimizer()
+      : opt_level_(RewriterConfig::ON),
+        options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
   explicit LoopOptimizer(RewriterConfig::Toggle opt_level)
-      : opt_level_(opt_level) {}
+      : opt_level_(opt_level),
+        options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
+
   ~LoopOptimizer() override {}
 
   string name() const override { return "loop_optimizer"; };
@@ -44,29 +48,24 @@ class LoopOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
-  Status LoopInvariantNodeMotion();
-  Status FindInvariantNodes(NodeDef* node);
-  Status RevertInvariantNodes();
-  Status MoveInvariantNodes(const int frame_id);
-  Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs,
-      const int frame_id);
-  Status LINMHandleConst(NodeDef* node, const int num_outputs,
-      const int frame_id);
-  Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs);
-
-  std::map<NodeDef*, int> invariant_nodes_;
-  std::set<int> empty_set_;
-  std::map<int, std::set<int>> frame_children_;
-  std::map<int, int> frame_parent_;
-  std::map<int, const NodeDef*> loop_cond_;
-  std::map<int, std::vector<NodeDef*>> invariant_enters_;
-  int new_enter_id_;
-  RewriterConfig::Toggle opt_level_;
+  friend class LoopOptimizerTest;
+
+  // Granular control for loop optimizer stages.
+  struct LoopOptimizerOptions {
+    bool enable_loop_invariant_node_motion = false;
+    bool enable_stack_push_removal = true;
+
+    static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
+      LoopOptimizerOptions options;
+      if (opt_level == RewriterConfig::AGGRESSIVE) {
+        options.enable_loop_invariant_node_motion = true;
+      }
+      return options;
+    }
+  };
 
-  std::unique_ptr<NodeMap> node_map_;
-  FrameMap frame_map_;
-  std::unique_ptr<GraphProperties> graph_properties_;
-  GraphDef* optimized_graph_;  // Not owned.
+  RewriterConfig::Toggle opt_level_;
+  LoopOptimizerOptions options_;
 };
 
 }  // end namespace grappler
index a0bd335..10ec544 100644 (file)
@@ -25,7 +25,6 @@ limitations under the License.
 
 namespace tensorflow {
 namespace grappler {
-namespace {
 
 class LoopOptimizerTest : public GrapplerTest {
  protected:
@@ -57,6 +56,23 @@ class LoopOptimizerTest : public GrapplerTest {
     attributes.emplace_back("T", type);
     AddNode(name, op, inputs, attributes, graph);
   }
+
+  void DisableAllStages(LoopOptimizer* optimizer) {
+    LoopOptimizer::LoopOptimizerOptions options;
+    options.enable_loop_invariant_node_motion = false;
+    options.enable_stack_push_removal = false;
+    optimizer->options_ = options;
+  }
+
+  void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.enable_loop_invariant_node_motion = true;
+  }
+
+  void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.enable_stack_push_removal = true;
+  }
 };
 
 TEST_F(LoopOptimizerTest, Basic) {
@@ -81,7 +97,8 @@ TEST_F(LoopOptimizerTest, Basic) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -128,7 +145,8 @@ TEST_F(LoopOptimizerTest, Const) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -175,7 +193,8 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -235,7 +254,8 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -302,7 +322,8 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -365,7 +386,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -429,7 +451,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
   GrapplerItem item;
   item.graph = graph;
 
-  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  LoopOptimizer optimizer;
+  EnableOnlyLoopInvariantNodeMotion(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -475,6 +498,7 @@ TEST_F(LoopOptimizerTest, NoOp) {
   CHECK(fake_input.NextItem(&item));
 
   LoopOptimizer optimizer;
+  EnableOnlyStackPushRemoval(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -504,6 +528,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
   AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
 
   LoopOptimizer optimizer;
+  EnableOnlyStackPushRemoval(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -534,6 +559,7 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
   item.fetch.push_back("pop4");
 
   LoopOptimizer optimizer;
+  EnableOnlyStackPushRemoval(&optimizer);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -563,6 +589,5 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
   }
 }
 
-}  // namespace
 }  // namespace grappler
 }  // namespace tensorflow