protected:
template <DataType DTYPE>
void SimpleNeutralElementTest() {
- typedef typename EnumToDataType<DTYPE>::Type T;
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
- ops::Placeholder::Shape(TensorShape({2, 2})));
- Tensor zeros_t(DTYPE, TensorShape({2, 2}));
- Tensor ones_t(DTYPE, TensorShape({2, 2}));
- Tensor x_t(DTYPE, TensorShape({2, 2}));
- for (int i = 0; i < 4; ++i) {
- zeros_t.flat<T>()(i) = T(0);
- ones_t.flat<T>()(i) = T(1);
- x_t.flat<T>()(i) = T(i + 1);
- }
- Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
- Output ones = ops::Const(s.WithOpName("ones"), ones_t);
- Output mul1;
- Output mul2;
- Output add1;
- Output add2;
- if (DTYPE == DT_BOOL) {
- mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
- mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
- add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
- add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
- } else {
- mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
- mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
- add1 = ops::Add(s.WithOpName("add1"), x, zeros);
- add1 = ops::Add(s.WithOpName("add2"), x, ones);
- }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"mul1", "mul2", "add1", "add2"};
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(7, output.node_size());
- for (int i = 0; i < output.node_size(); ++i) {
- const NodeDef& node = output.node(i);
- const string& name = node.name();
- if (name == "mul1") {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
- } else if (name == "mul2") {
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
- } else if (name == "add1") {
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
- } else if (name == "add2") {
- if (DTYPE == DT_BOOL) {
+ for (bool use_snapshot : {false, true}) {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
+ Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+ Tensor ones_t(DTYPE, TensorShape({2, 2}));
+ Tensor x_t(DTYPE, TensorShape({2, 2}));
+ for (int i = 0; i < 4; ++i) {
+ zeros_t.flat<T>()(i) = T(0);
+ ones_t.flat<T>()(i) = T(1);
+ x_t.flat<T>()(i) = T(i + 1);
+ }
+ Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+ Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+ Output mul1;
+ Output mul2;
+ Output add1;
+ Output add2;
+ if (DTYPE == DT_BOOL) {
+ mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+ mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+ add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+ add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+ } else {
+ mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+ add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+ add1 = ops::Add(s.WithOpName("add2"), x, ones);
+ }
+ if (use_snapshot) {
+ // Add an op with ref input to prevent Snapshot from being
+ // turned into Identity.
+ ops::Assign(s.WithOpName("assign"), v, ones);
+ }
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"mul1", "mul2", "add1", "add2"};
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(7, output.node_size());
+ const string snapshot_or_identity =
+ use_snapshot ? "Snapshot" : "Identity";
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "mul2") {
+ EXPECT_EQ(snapshot_or_identity, node.op());
+ EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^ones", node.input(1));
- } else {
- EXPECT_EQ("Add", node.op());
+ } else if (name == "add1") {
+ EXPECT_EQ(snapshot_or_identity, node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("ones", node.input(1));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "add2") {
+ if (DTYPE == DT_BOOL) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else {
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ones", node.input(1));
+ }
}
}
- }
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
- auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
- EXPECT_EQ(4, tensors_expected.size());
- EXPECT_EQ(4, tensors.size());
- for (int i = 0; i < item.fetch.size(); ++i) {
- test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+ auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+ EXPECT_EQ(4, tensors_expected.size());
+ EXPECT_EQ(4, tensors.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ }
}
}
};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"stack", "matmul3", "matmul4"};
- ConstantFolding optimizer(nullptr /* cpu_device */);
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul4") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul5") {
EXPECT_EQ("^zeros_1d", node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "div1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "div2") {
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
EXPECT_EQ(3, t.tensor_shape().dim(1).size());
} else if (name == "add1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add2") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "bias_add1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_1d", node.input(1));
} else if (name == "bias_add2") {
EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
} else if (name == "sub1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "sub2") {