#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace grappler {
namespace {
-class LoopOptimizerTest : public ::testing::Test {
+class LoopOptimizerTest : public GrapplerTest {
protected:
- static NodeDef CreateNode(const string& name,
- const std::vector<string>& inputs) {
- return CreateNode(name, "Identity", "", false, 0, inputs);
- }
- static NodeDef CreateNode(const string& name, const string& op,
- const std::vector<string>& inputs) {
- return CreateNode(name, op, "", false, 0, inputs);
+ // These helpers always sets T=DT_FLOAT.
+ void AddEnterNode(const string& name, const string& frame,
+ const bool is_constant, const int piterations,
+ const std::vector<string>& inputs, GraphDef* graph) const {
+ std::vector<std::pair<string, AttrValue>> attributes;
+ AttrValue type;
+ type.set_type(DT_FLOAT);
+ attributes.emplace_back("T", type);
+ AttrValue frame_name;
+ frame_name.set_s(frame);
+ attributes.emplace_back("frame_name", frame_name);
+ AttrValue is_const;
+ is_const.set_b(is_constant);
+ attributes.emplace_back("is_constant", is_const);
+ AttrValue parallel_iterations;
+ parallel_iterations.set_i(piterations);
+ attributes.emplace_back("parallel_iterations", parallel_iterations);
+ AddNode(name, "Enter", inputs, attributes, graph);
}
- static NodeDef CreateNode(const string& name, const string& op,
- const string& frame,
- const bool is_constant,
- const int piterations,
- const std::vector<string>& inputs) {
- NodeDef node;
- node.set_name(name);
- if (!op.empty()) {
- node.set_op(op);
- }
- if (!frame.empty()) {
- AttrValue frame_name;
- frame_name.set_s(frame);
- node.mutable_attr()->insert({"frame_name", frame_name});
- }
- if (op == "Enter") {
- AttrValue is_const;
- is_const.set_b(is_constant);
- node.mutable_attr()->insert({"is_constant", is_const});
- AttrValue parallel_iterations;
- parallel_iterations.set_i(piterations);
- node.mutable_attr()->insert(
- {"parallel_iterations", parallel_iterations});
- }
+
+ void AddSimpleNode(const string& name, const string& op,
+ const std::vector<string>& inputs, GraphDef* graph) const {
+ std::vector<std::pair<string, AttrValue>> attributes;
AttrValue type;
type.set_type(DT_FLOAT);
- node.mutable_attr()->insert({"T", type});
- for (const string& input : inputs) {
- node.add_input(input);
- }
- return node;
+ attributes.emplace_back("T", type);
+ AddNode(name, op, inputs, attributes, graph);
}
};
TEST_F(LoopOptimizerTest, Basic) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"VariantAdd", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"VariantAdd"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, Const) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode("Const", "Const", {"^Identity"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "Const"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"VariantAdd", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"VariantAdd"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("Const", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "Const"}, &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, ControlOutput) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode(
- "Less", "Less", {"VariantAdd", "less/y", "^InvariantAdd"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"VariantAdd"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y", "^InvariantAdd"},
+ &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, NestedLoop1) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"Exit2"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
-
- *graph.add_node() = CreateNode(
- "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
- {"VariantAdd"});
- *graph.add_node() = CreateNode(
- "InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"});
- *graph.add_node() = CreateNode(
- "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
- *graph.add_node() = CreateNode(
- "VariantEnter2", "Enter", "while/while/while_context", false, 1,
- {"VariantEnter"});
- *graph.add_node() = CreateNode(
- "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
- *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
- *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
- *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
- *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
- *graph.add_node() = CreateNode(
- "NextIteration2", "NextIteration", {"VariantAdd2"});
- *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+ AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+ {"VariantAdd"}, &graph);
+ AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
+ &graph);
+ AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+ AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+ {"VariantEnter"}, &graph);
+ AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+ AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+ AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+ AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+ AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+ AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+ AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, NestedLoop2) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"Exit2"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
-
- *graph.add_node() = CreateNode(
- "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
- {"InvariantAdd"});
- *graph.add_node() = CreateNode(
- "InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"});
- *graph.add_node() = CreateNode(
- "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
- *graph.add_node() = CreateNode(
- "VariantEnter2", "Enter", "while/while/while_context", false, 1,
- {"VariantEnter"});
- *graph.add_node() = CreateNode(
- "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
- *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
- *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
- *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
- *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
- *graph.add_node() = CreateNode(
- "NextIteration2", "NextIteration", {"VariantAdd2"});
- *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+ AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+ {"InvariantAdd"}, &graph);
+ AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
+ &graph);
+ AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+ AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+ {"VariantEnter"}, &graph);
+ AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+ AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+ AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+ AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+ AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+ AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+ AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, NestedLoopConst1) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"Exit2"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
-
- *graph.add_node() = CreateNode(
- "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
- {"VariantAdd"});
- *graph.add_node() = CreateNode("Const2", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode(
- "InvariantAdd2", "Add", {"InvariantEnter2", "Const2"});
- *graph.add_node() = CreateNode(
- "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
- *graph.add_node() = CreateNode(
- "VariantEnter2", "Enter", "while/while/while_context", false, 1,
- {"VariantEnter"});
- *graph.add_node() = CreateNode(
- "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
- *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
- *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
- *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
- *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
- *graph.add_node() = CreateNode(
- "NextIteration2", "NextIteration", {"VariantAdd2"});
- *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+ AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+ {"VariantAdd"}, &graph);
+ AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
+ AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+ AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+ {"VariantEnter"}, &graph);
+ AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+ AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+ AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+ AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+ AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+ AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+ AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
GrapplerItem item;
item.graph = graph;
TEST_F(LoopOptimizerTest, NestedLoopConst2) {
GraphDef graph;
- *graph.add_node() = CreateNode("0", {});
- *graph.add_node() = CreateNode(
- "InvariantEnter", "Enter", "while/while_context", true, 1, {"0"});
- *graph.add_node() = CreateNode(
- "InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"});
- *graph.add_node() = CreateNode(
- "VariantAdd", "Add", {"InvariantAdd", "Identity"});
- *graph.add_node() = CreateNode(
- "VariantEnter", "Enter", "while/while_context", false, 1, {"0"});
- *graph.add_node() = CreateNode(
- "Merge", "Merge", {"VariantEnter", "NextIteration"});
- *graph.add_node() = CreateNode("Less/y", "Const", {"^Identity"});
- *graph.add_node() = CreateNode("Less", "Less", {"Exit2", "less/y"});
- *graph.add_node() = CreateNode("LoopCond", "LoopCond", {"Less"});
- *graph.add_node() = CreateNode("Switch", "Switch", {"Merge", "LoopCond"});
- *graph.add_node() = CreateNode("Identity", {"Switch:1"});
- *graph.add_node() = CreateNode(
- "NextIteration", "NextIteration", {"Exit2"});
- *graph.add_node() = CreateNode("Exit", "Exit", {"Switch"});
- *graph.add_node() = CreateNode("1", {"Exit"});
-
- *graph.add_node() = CreateNode(
- "InvariantEnter2", "Enter", "while/while/while_context", true, 1,
- {"InvariantAdd"});
- *graph.add_node() = CreateNode("Const2", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode(
- "InvariantAdd2", "Add", {"InvariantEnter2", "Const2"});
- *graph.add_node() = CreateNode(
- "VariantAdd2", "Add", {"InvariantAdd2", "Identity2"});
- *graph.add_node() = CreateNode(
- "VariantEnter2", "Enter", "while/while/while_context", false, 1,
- {"VariantEnter"});
- *graph.add_node() = CreateNode(
- "Merge2", "Merge", {"VariantEnter2", "NextIteration2"});
- *graph.add_node() = CreateNode("Less2/y", "Const", {"^Identity2"});
- *graph.add_node() = CreateNode("Less2", "Less", {"VariantAdd2", "less2/y"});
- *graph.add_node() = CreateNode("LoopCond2", "LoopCond", {"Less2"});
- *graph.add_node() = CreateNode("Switch2", "Switch", {"Merge2", "LoopCond2"});
- *graph.add_node() = CreateNode("Identity2", {"Switch2:1"});
- *graph.add_node() = CreateNode(
- "NextIteration2", "NextIteration", {"VariantAdd2"});
- *graph.add_node() = CreateNode("Exit2", "Exit", {"Switch2"});
+ AddSimpleNode("In", "Identity", {}, &graph);
+ AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
+ &graph);
+ AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
+ &graph);
+ AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
+ AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
+ AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
+ AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
+ AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
+ AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
+ AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
+ AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
+ AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
+ AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
+ AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
+
+ AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
+ {"InvariantAdd"}, &graph);
+ AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
+ AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
+ AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
+ {"VariantEnter"}, &graph);
+ AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
+ AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
+ AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
+ AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
+ AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
+ AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
+ AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
+ AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
GrapplerItem item;
item.graph = graph;
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
-namespace {
-NodeDef* AddNode(const string& name, const string& op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- GraphDef* graph) {
- NodeDef* node = graph->add_node();
- node->set_name(name);
- node->set_op(op);
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-} // namespace
-
TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
GrapplerItem item;
- AttrValue frame_name;
- frame_name.set_s("foo");
- AttrValue type;
- type.set_type(DT_RESOURCE);
GraphDef& graph = item.graph;
- AddNode("c", "Const", {}, {}, &graph);
+ AddSimpleNode("c", "Const", {}, &graph);
// Stack with corresponding push/pop.
- AddNode("stack1", "StackV2", {}, {}, &graph);
- AddNode("push1", "StackPushV2", {"stack1", "c"}, {}, &graph);
- AddNode("pop1", "StackPopV2", {"stack1"}, {}, &graph);
- AddNode("id1", "Identity", {"pop1"}, {}, &graph);
+ AddSimpleNode("stack1", "StackV2", {}, &graph);
+ AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
+ AddSimpleNode("pop1", "StackPopV2", {"stack1"}, &graph);
+ AddSimpleNode("id1", "Identity", {"pop1"}, &graph);
// Stack with corresponding push/pop behind Enter.
- AddNode("stack2", "StackV2", {}, {}, &graph);
- AddNode("push_enter", "Enter", {"stack2"},
- {{"T", type}, {"frame_name", frame_name}}, &graph);
- AddNode("push2", "StackPushV2", {"push_enter", "c"}, {}, &graph);
- AddNode("pop_enter", "Enter", {"stack2"},
- {{"T", type}, {"frame_name", frame_name}}, &graph);
- AddNode("pop2", "StackPopV2", {"pop_enter"}, {}, &graph);
- AddNode("id2", "Identity", {"pop2"}, {}, &graph);
+ AddSimpleNode("stack2", "StackV2", {}, &graph);
+ AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
+ AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+ AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
+ AddSimpleNode("pop2", "StackPopV2", {"enter2_stack2"}, &graph);
+ AddSimpleNode("id2", "Identity", {"pop2"}, &graph);
// Stack with unexpected op type in fanout of Stack.
- AddNode("stack3", "StackV2", {}, {}, &graph);
- AddNode("push3", "StackPushV2", {"stack3", "c"}, {}, &graph);
- AddNode("stop", "StopGradient", {"stack3"}, {}, &graph);
+ AddSimpleNode("stack3", "StackV2", {}, &graph);
+ AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
+ AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
LoopOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
GrapplerItem item;
GraphDef& graph = item.graph;
- AttrValue frame_name;
- frame_name.set_s("foo");
- AttrValue type;
- type.set_type(DT_RESOURCE);
- AddNode("c", "Const", {}, {}, &graph);
+ AddSimpleNode("c", "Const", {}, &graph);
// Push without Pop.
- AddNode("stack1", "StackV2", {}, {}, &graph);
- AddNode("push1", "StackPushV2", {"stack1", "c"}, {}, &graph);
+ AddSimpleNode("stack1", "StackV2", {}, &graph);
+ AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
// Push without Pop behind Enter.
- AddNode("stack2", "StackV2", {}, {}, &graph);
- AddNode("push_enter", "Enter", {"stack2"},
- {{"T", type}, {"frame_name", frame_name}}, &graph);
- AddNode("push2", "StackPushV2", {"push_enter", "c"}, {}, &graph);
+ AddSimpleNode("stack2", "StackV2", {}, &graph);
+ AddEnterNode("enter_c", "frame_name", false, 1, {"c"}, &graph);
+ AddEnterNode("enter_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+ AddSimpleNode("push2", "StackPushV2", {"enter_stack2", "enter_c"}, &graph);
// Pop without consumer.
- AddNode("stack3", "StackV2", {}, {}, &graph);
- AddNode("push3", "StackPushV2", {"stack3", "c"}, {}, &graph);
- AddNode("pop3", "StackPopV2", {"stack3"}, {}, &graph);
+ AddSimpleNode("stack3", "StackV2", {}, &graph);
+ AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
+ AddSimpleNode("pop3", "StackPopV2", {"stack3"}, &graph);
LoopOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(9, output.node_size());
+ EXPECT_EQ(10, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
if (node.name() == "push1") {
} else if (node.name() == "push2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("c", node.input(0));
- EXPECT_EQ("^push_enter", node.input(1));
+ EXPECT_EQ("enter_c", node.input(0));
+ EXPECT_EQ("^enter_stack2", node.input(1));
} else if (node.name() == "push3") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());