}
TEST_F(ConstantFoldingTest, NeutralElement) {
- for (bool use_const : {true, false}) {
+ int kConst = 0;
+ int kLike = 1;
+ int kFill = 2;
+ for (int const_type : {kConst, kLike, kFill}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
ops::Placeholder::Shape(TensorShape({2, 3})));
Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2})));
- Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
- : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
- Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
- : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
+ Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
+ Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
+ Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
+ Output zeros = const_type == kConst
+ ? zeros_const
+ : (const_type == kLike ? zeros_like : zeros_fill);
+ Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
+ Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
+ Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
+ Output ones = const_type == kConst
+ ? ones_const
+ : (const_type == kLike ? ones_like : ones_fill);
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ const string suffix =
+ (const_type == kConst ? "_const"
+ : (const_type == kLike ? "_like" : "_fill"));
+ const string zeros_name = strings::StrCat("zeros", suffix);
+ const string ones_name = strings::StrCat("ones", suffix);
+ const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
+ const string ctrl_ones_name = strings::StrCat("^ones", suffix);
EXPECT_EQ(28, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
if (name == "mul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "mul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul4") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul5") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
} else if (name == "div1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "div2") {
EXPECT_EQ("Reciprocal", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "matmul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "matmul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "matmul3") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^a", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
EXPECT_EQ(0, t.float_val(0));
EXPECT_EQ(2, t.tensor_shape().dim(1).size());
} else if (name == "matmul4") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^b", node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
} else if (name == "add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add2") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "bias_add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
} else if (name == "bias_add2") {
// We don't eliminate this one, because it requires broadcasting.
EXPECT_EQ("BiasAdd", node.op());
- EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
} else if (name == "sub1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "sub2") {
EXPECT_EQ("Neg", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
}
const std::set<string> square_zero_const{"mul1", "mul2", "mul5",
"mul6", "matmul1", "matmul2"};