item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
EXPECT_EQ(node.input(2), "Split:2");
}
}
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+ item.feed = {{"Placeholder", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
EXPECT_EQ(2, outputs_node->input_size());
EXPECT_EQ(outputs_node->input(0), "outputs_const");
EXPECT_EQ(outputs_node->input(1), "^Placeholder");
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
EXPECT_EQ(3, output.node_size());
EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantCast(&optimizer);
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Cast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {