Refactor and enable loop optimizer tests.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Mar 2018 18:16:21 +0000 (11:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 18:20:34 +0000 (11:20 -0700)
PiperOrigin-RevId: 189215781

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc

index ffeaa38..0df5307 100644 (file)
@@ -538,13 +538,7 @@ cc_library(
 
 tf_cc_test(
     name = "loop_optimizer_test",
-    size = "small",
     srcs = ["loop_optimizer_test.cc"],
-    tags = [
-        "manual",
-        "no_oss",  # b/74111495
-        "notap",
-    ],
     deps = [
         ":loop_optimizer",
         "//tensorflow/cc:cc_ops",
index 0bd202a..0d45ba9 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -26,72 +27,56 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
-class LoopOptimizerTest : public ::testing::Test {
+class LoopOptimizerTest : public GrapplerTest {
  protected:
-  static NodeDef CreateNode(const string& name,
-                            const std::vector<string>& inputs) {
-    return CreateNode(name, "Identity", "", false, 0, inputs);
-  }
-  static NodeDef CreateNode(const string& name, const string& op,
-                            const std::vector<string>& inputs) {
-    return CreateNode(name, op, "", false, 0, inputs);
+  // These helpers always sets T=DT_FLOAT.
+  void AddEnterNode(const string& name, const string& frame,
+                    const bool is_constant, const int piterations,
+                    const std::vector<string>& inputs, GraphDef* graph) const {
+    std::vector<std::pair<string, AttrValue>> attributes;
+    AttrValue type;
+    type.set_type(DT_FLOAT);
+    attributes.emplace_back("T", type);
+    AttrValue frame_name;
+    frame_name.set_s(frame);
+    attributes.emplace_back("frame_name", frame_name);
+    AttrValue is_const;
+    is_const.set_b(is_constant);
+    attributes.emplace_back("is_constant", is_const);
+    AttrValue parallel_iterations;
+    parallel_iterations.set_i(piterations);
+    attributes.emplace_back("parallel_iterations", parallel_iterations);
+    AddNode(name, "Enter", inputs, attributes, graph);
   }
-  static NodeDef CreateNode(const string& name, const string& op,
-                            const string& frame,
-                            const bool is_constant,
-                            const int piterations,
-                            const std::vector<string>& inputs) {
-    NodeDef node;
-    node.set_name(name);
-    if (!op.empty()) {
-      node.set_op(op);
-    }
-    if (!frame.empty()) {
-      AttrValue frame_name;
-      frame_name.set_s(frame);
-      node.mutable_attr()->insert({"frame_name", frame_name});
-    }
-    if (op == "Enter") {
-      AttrValue is_const;
-      is_const.set_b(is_constant);
-      node.mutable_attr()->insert({"is_constant", is_const});
-      AttrValue parallel_iterations;
-      parallel_iterations.set_i(piterations);
-      node.mutable_attr()->insert(
-          {"parallel_iterations", parallel_iterations});
-    }
+
+  void AddSimpleNode(const string& name, const string& op,
+                     const std::vector<string>& inputs, GraphDef* graph) const {
+    std::vector<std::pair<string, AttrValue>> attributes;
     AttrValue type;
     type.set_type(DT_FLOAT);
-    node.mutable_attr()->insert({"T", type});
-    for (const string& input : inputs) {
-      node.add_input(input);
-    }
-    return node;
+    attributes.emplace_back("T", type);
+    AddNode(name, op, inputs, attributes, graph);
   }
 };
 
 TEST_F(LoopOptimizerTest, Basic) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"VariantAdd", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"VariantAdd"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -123,27 +108,22 @@ TEST_F(LoopOptimizerTest, Basic) {
 
 TEST_F(LoopOptimizerTest, Const) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode("Const", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "Const"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"VariantAdd", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"VariantAdd"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("Const", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "Const"}, &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -174,27 +154,23 @@ TEST_F(LoopOptimizerTest, Const) {
 
 TEST_F(LoopOptimizerTest, ControlOutput) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode(
-      "Less", "Less", {"VariantAdd", "less/y", "^InvariantAdd"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"VariantAdd"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y", "^InvariantAdd"},
+                &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -223,47 +199,38 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
 
 TEST_F(LoopOptimizerTest, NestedLoop1) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"Exit2"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
-
-  *graph.add_node() = CreateNode(
-      "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
-      {"VariantAdd"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter2", "Enter", "while/while/while_context", false, 1,
-      {"VariantEnter"});
-  *graph.add_node() = CreateNode(
-      "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
-  *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
-  *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
-  *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
-  *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration2", "NextIteration", {"VariantAdd2"});
-  *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+  AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+               {"VariantAdd"}, &graph);
+  AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
+                &graph);
+  AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+  AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+               {"VariantEnter"}, &graph);
+  AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+  AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+  AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+  AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+  AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+  AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+  AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -299,47 +266,38 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
 
 TEST_F(LoopOptimizerTest, NestedLoop2) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"Exit2"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
-
-  *graph.add_node() = CreateNode(
-      "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
-      {"InvariantAdd"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter2", "Enter", "while/while/while_context", false, 1,
-      {"VariantEnter"});
-  *graph.add_node() = CreateNode(
-      "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
-  *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
-  *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
-  *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
-  *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration2", "NextIteration", {"VariantAdd2"});
-  *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+  AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+               {"InvariantAdd"}, &graph);
+  AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
+                &graph);
+  AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+  AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+               {"VariantEnter"}, &graph);
+  AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+  AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+  AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+  AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+  AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+  AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+  AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -371,48 +329,38 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
 
 TEST_F(LoopOptimizerTest, NestedLoopConst1) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"Exit2"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
-
-  *graph.add_node() = CreateNode(
-      "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
-      {"VariantAdd"});
-  *graph.add_node() = CreateNode("Const2", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd2", "Add", {"InvariantEnter2", "Const2"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter2", "Enter", "while/while/while_context", false, 1,
-      {"VariantEnter"});
-  *graph.add_node() = CreateNode(
-      "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
-  *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
-  *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
-  *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
-  *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration2", "NextIteration", {"VariantAdd2"});
-  *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+  AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+               {"VariantAdd"}, &graph);
+  AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
+  AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+  AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+               {"VariantEnter"}, &graph);
+  AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+  AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+  AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+  AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+  AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+  AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+  AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -445,48 +393,38 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
 
 TEST_F(LoopOptimizerTest, NestedLoopConst2) {
   GraphDef graph;
-  *graph.add_node() = CreateNode("0", {});
-  *graph.add_node() = CreateNode(
-      "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd", "Add", {"InvariantAdd", "Identity"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
-  *graph.add_node() = CreateNode(
-      "Merge", "Merge", {"VariantEnter", "NextIteration"});
-  *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
-  *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
-  *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
-  *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
-  *graph.add_node() = CreateNode("Identity", {"Switch:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration", "NextIteration", {"Exit2"});
-  *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
-  *graph.add_node() = CreateNode("1", {"Exit"});
-
-  *graph.add_node() = CreateNode(
-      "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
-      {"InvariantAdd"});
-  *graph.add_node() = CreateNode("Const2", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode(
-      "InvariantAdd2", "Add", {"InvariantEnter2", "Const2"});
-  *graph.add_node() = CreateNode(
-      "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
-  *graph.add_node() = CreateNode(
-      "VariantEnter2", "Enter", "while/while/while_context", false, 1,
-      {"VariantEnter"});
-  *graph.add_node() = CreateNode(
-      "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
-  *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
-  *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
-  *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
-  *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
-  *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
-  *graph.add_node() = CreateNode(
-      "NextIteration2", "NextIteration", {"VariantAdd2"});
-  *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+  AddSimpleNode("In", "Identity", {}, &graph);
+  AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+               &graph);
+  AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+                &graph);
+  AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+  AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+  AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+  AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+  AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+  AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+  AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+  AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+  AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+  AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+  AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+  AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+               {"InvariantAdd"}, &graph);
+  AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
+  AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+  AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+               {"VariantEnter"}, &graph);
+  AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+  AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+  AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+  AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+  AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+  AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+  AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+  AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
 
   GrapplerItem item;
   item.graph = graph;
@@ -544,50 +482,26 @@ TEST_F(LoopOptimizerTest, NoOp) {
   VerifyGraphsEqual(item.graph, output, __FUNCTION__);
 }
 
-namespace {
-NodeDef* AddNode(const string& name, const string& op,
-                 const std::vector<string>& inputs,
-                 const std::vector<std::pair<string, AttrValue>>& attributes,
-                 GraphDef* graph) {
-  NodeDef* node = graph->add_node();
-  node->set_name(name);
-  node->set_op(op);
-  for (const string& input : inputs) {
-    node->add_input(input);
-  }
-  for (auto attr : attributes) {
-    (*node->mutable_attr())[attr.first] = attr.second;
-  }
-  return node;
-}
-}  // namespace
-
 TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
   GrapplerItem item;
-  AttrValue frame_name;
-  frame_name.set_s("foo");
-  AttrValue type;
-  type.set_type(DT_RESOURCE);
   GraphDef& graph = item.graph;
-  AddNode("c", "Const", {}, {}, &graph);
+  AddSimpleNode("c", "Const", {}, &graph);
   // Stack with corresponding push/pop.
-  AddNode("stack1", "StackV2", {}, {}, &graph);
-  AddNode("push1", "StackPushV2", {"stack1", "c"}, {}, &graph);
-  AddNode("pop1", "StackPopV2", {"stack1"}, {}, &graph);
-  AddNode("id1", "Identity", {"pop1"}, {}, &graph);
+  AddSimpleNode("stack1", "StackV2", {}, &graph);
+  AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
+  AddSimpleNode("pop1", "StackPopV2", {"stack1"}, &graph);
+  AddSimpleNode("id1", "Identity", {"pop1"}, &graph);
   // Stack with corresponding push/pop behind Enter.
-  AddNode("stack2", "StackV2", {}, {}, &graph);
-  AddNode("push_enter", "Enter", {"stack2"},
-          {{"T", type}, {"frame_name", frame_name}}, &graph);
-  AddNode("push2", "StackPushV2", {"push_enter", "c"}, {}, &graph);
-  AddNode("pop_enter", "Enter", {"stack2"},
-          {{"T", type}, {"frame_name", frame_name}}, &graph);
-  AddNode("pop2", "StackPopV2", {"pop_enter"}, {}, &graph);
-  AddNode("id2", "Identity", {"pop2"}, {}, &graph);
+  AddSimpleNode("stack2", "StackV2", {}, &graph);
+  AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
+  AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+  AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
+  AddSimpleNode("pop2", "StackPopV2", {"enter2_stack2"}, &graph);
+  AddSimpleNode("id2", "Identity", {"pop2"}, &graph);
   // Stack with unexpected op type in fanout of Stack.
-  AddNode("stack3", "StackV2", {}, {}, &graph);
-  AddNode("push3", "StackPushV2", {"stack3", "c"}, {}, &graph);
-  AddNode("stop", "StopGradient", {"stack3"}, {}, &graph);
+  AddSimpleNode("stack3", "StackV2", {}, &graph);
+  AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
+  AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
   LoopOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
@@ -598,29 +512,25 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
 TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
   GrapplerItem item;
   GraphDef& graph = item.graph;
-  AttrValue frame_name;
-  frame_name.set_s("foo");
-  AttrValue type;
-  type.set_type(DT_RESOURCE);
-  AddNode("c", "Const", {}, {}, &graph);
+  AddSimpleNode("c", "Const", {}, &graph);
   // Push without Pop.
-  AddNode("stack1", "StackV2", {}, {}, &graph);
-  AddNode("push1", "StackPushV2", {"stack1", "c"}, {}, &graph);
+  AddSimpleNode("stack1", "StackV2", {}, &graph);
+  AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
   // Push without Pop behind Enter.
-  AddNode("stack2", "StackV2", {}, {}, &graph);
-  AddNode("push_enter", "Enter", {"stack2"},
-          {{"T", type}, {"frame_name", frame_name}}, &graph);
-  AddNode("push2", "StackPushV2", {"push_enter", "c"}, {}, &graph);
+  AddSimpleNode("stack2", "StackV2", {}, &graph);
+  AddEnterNode("enter_c", "frame_name", false, 1, {"c"}, &graph);
+  AddEnterNode("enter_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+  AddSimpleNode("push2", "StackPushV2", {"enter_stack2", "enter_c"}, &graph);
   // Pop without consumer.
-  AddNode("stack3", "StackV2", {}, {}, &graph);
-  AddNode("push3", "StackPushV2", {"stack3", "c"}, {}, &graph);
-  AddNode("pop3", "StackPopV2", {"stack3"}, {}, &graph);
+  AddSimpleNode("stack3", "StackV2", {}, &graph);
+  AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
+  AddSimpleNode("pop3", "StackPopV2", {"stack3"}, &graph);
 
   LoopOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
-  EXPECT_EQ(9, output.node_size());
+  EXPECT_EQ(10, output.node_size());
   for (int i = 0; i < output.node_size(); ++i) {
     const NodeDef& node = output.node(i);
     if (node.name() == "push1") {
@@ -631,8 +541,8 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
     } else if (node.name() == "push2") {
       EXPECT_EQ("Identity", node.op());
       EXPECT_EQ(2, node.input_size());
-      EXPECT_EQ("c", node.input(0));
-      EXPECT_EQ("^push_enter", node.input(1));
+      EXPECT_EQ("enter_c", node.input(0));
+      EXPECT_EQ("^enter_stack2", node.input(1));
     } else if (node.name() == "push3") {
       EXPECT_EQ("Identity", node.op());
       EXPECT_EQ(2, node.input_size());