run evaluate nodes on parts of arithmetic optimizer tests.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 21:28:22 +0000 (14:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 21:31:49 +0000 (14:31 -0700)
PiperOrigin-RevId: 191647386

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/core/grappler/utils/grappler_test.h

index 4ce3e73..0c6549d 100644 (file)
@@ -274,6 +274,11 @@ tf_cuda_cc_test(
         ":constant_folding",
         ":model_pruner",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
index ef3ed35..48f1dd5 100644 (file)
@@ -156,7 +156,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch = {"div"};
 
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
   EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
@@ -164,7 +164,6 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
   OptimizeTwice(&optimizer, &item, &output);
   NodeMap node_map(&output);
   EXPECT_EQ(2, output.node_size());
-
   const NodeDef* new_c1 = node_map.GetNode("c1");
   ASSERT_NE(new_c1, nullptr);
 
@@ -174,7 +173,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
   EXPECT_EQ("c1", new_div->input(0));
   EXPECT_EQ("c1", new_div->input(1));
 
-  auto tensors = EvaluateNodes(output, item.fetch, {});
+  auto tensors = EvaluateNodes(output, item.fetch);
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
 }
@@ -193,6 +192,11 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch = {"div"};
+  Tensor bool_t(DT_BOOL, TensorShape({}));
+  bool_t.scalar<bool>().setConstant(true);
+  auto tensors_expected =
+      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -208,6 +212,10 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
   EXPECT_EQ("check1", new_div->input(1));
   EXPECT_EQ("^assert1", new_div->input(2));
   EXPECT_EQ("^assert1", new_div->input(3));
+
+  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
@@ -219,7 +227,9 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
   Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  item.fetch = {"div"};
+  item.fetch = {"div1"};
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -241,6 +251,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
   EXPECT_EQ(2, new_div1->input_size());
   EXPECT_EQ("mul1", new_div1->input(0));
   EXPECT_EQ("mul1", new_div1->input(1));
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, MulToSquare) {
@@ -251,6 +265,9 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) {
   Output id = ops::Identity(s.WithOpName("id"), mul);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"id"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -265,6 +282,10 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) {
   EXPECT_EQ(2, output.node(4).input_size());
   EXPECT_EQ("c", output.node(4).input(0));
   EXPECT_EQ("^d", output.node(4).input(1));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
@@ -277,6 +298,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
   Output id = ops::Identity(s.WithOpName("id"), recip2);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"id"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -287,6 +311,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
   EXPECT_EQ("c", output.node(1).input(0));
   EXPECT_EQ("c", output.node(3).input(0));
   EXPECT_EQ("c", output.node(5).input(0));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
@@ -299,6 +327,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
   Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"id2"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -312,6 +343,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
   EXPECT_EQ(6, output.node_size());
   EXPECT_EQ("squeeze", output.node(5).input(0));
   EXPECT_EQ("c", output.node(2).input(0));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
@@ -326,6 +361,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
+  std::vector<string> fetch = {"id2"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
   ArithmeticOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
@@ -343,6 +382,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
       EXPECT_EQ(original.input(j), optimized.input(j));
     }
   }
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
@@ -354,6 +397,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
+  std::vector<string> fetch = {"id"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
   ArithmeticOptimizer optimizer;
   GraphDef output;
   OptimizeTwice(&optimizer, &item, &output);
@@ -375,6 +422,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
   const NodeDef* new_id = node_map.GetNode("id");
   ASSERT_NE(new_id, nullptr);
   EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
@@ -387,6 +438,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
+  std::vector<string> fetch = {"id"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
   ArithmeticOptimizer optimizer;
   GraphDef output;
   OptimizeTwice(&optimizer, &item, &output);
@@ -409,6 +464,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
   const NodeDef* new_id = node_map.GetNode("id");
   ASSERT_NE(new_id, nullptr);
   EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
@@ -424,6 +483,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
   const std::vector<string> devices{
       "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
       "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
@@ -515,7 +575,8 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
       GrapplerItem item;
       item.fetch = {"id"};
       TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+      EXPECT_EQ(1, tensors_expected.size());
       ArithmeticOptimizer optimizer;
       EnableOnlyHoistCommonFactor(&optimizer);
 
@@ -554,21 +615,26 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
         EXPECT_EQ("id", id_node->name());
         EXPECT_EQ(HoistMulName("add"), id_node->input(0));
       }
+      auto tensors = EvaluateNodes(output, item.fetch);
+      EXPECT_EQ(1, tensors.size());
+      test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
     }
   }
 }
 
 TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
-  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
   Output z = ops::Complex(s.WithOpName("z"), re, im);
   Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
   Output conj = ops::Conj(s.WithOpName("conj"), z);
   Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+  std::vector<string> fetch = {"trans"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
   ArithmeticOptimizer optimizer;
   GraphDef output;
   OptimizeTwice(&optimizer, &item, &output);
@@ -582,12 +648,16 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
   EXPECT_EQ("ConjugateTranspose", trans_fused_node->op());
   EXPECT_EQ("z", trans_fused_node->input(0));
   EXPECT_EQ("perm", trans_fused_node->input(1));
+
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
 }
 
 TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
-  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
   Output z = ops::Complex(s.WithOpName("z"), re, im);
   Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
   Output conj = ops::Conj(s.WithOpName("conj"), z);
@@ -595,6 +665,9 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
       ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"conjugate_trans"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -608,18 +681,24 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
   EXPECT_EQ("Transpose", conjugate_trans_fused_node->op());
   EXPECT_EQ("z", conjugate_trans_fused_node->input(0));
   EXPECT_EQ("perm", conjugate_trans_fused_node->input(1));
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
 }
 
 TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
-  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+  Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
   Output z = ops::Complex(s.WithOpName("z"), re, im);
   Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
   Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
   Output conj = ops::Conj(s.WithOpName("conj"), trans);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"conj"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -633,6 +712,9 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
   EXPECT_EQ("ConjugateTranspose", conj_fused_node->op());
   EXPECT_EQ("z", conj_fused_node->input(0));
   EXPECT_EQ("perm", conj_fused_node->input(1));
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
 }
 
 TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
@@ -654,6 +736,9 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
     }
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
+    std::vector<string> fetch = {"matmul"};
+    auto tensors_expected = EvaluateNodes(item.graph, fetch);
+    EXPECT_EQ(1, tensors_expected.size());
 
     ArithmeticOptimizer optimizer;
     GraphDef output;
@@ -674,6 +759,9 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
     }
+    auto tensors = EvaluateNodes(output, fetch);
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
   }
 }
 
@@ -695,6 +783,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  std::vector<string> fetch = {"matmul"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
@@ -707,6 +798,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   EXPECT_EQ("b", output.node(10).input(1));
   EXPECT_TRUE(output.node(10).attr().at("adj_x").b());
   EXPECT_TRUE(output.node(10).attr().at("adj_y").b());
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
@@ -727,7 +821,10 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
   GrapplerItem item;
   item.fetch = {"outputs"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+  auto tensors_expected =
+      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+  EXPECT_EQ(1, tensors_expected.size());
   GraphDef output;
   TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
@@ -735,6 +832,9 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
   TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
 
   EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
@@ -749,7 +849,10 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
   GrapplerItem item;
   item.fetch = {"outputs"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
+  item.feed = {{"Placeholder", x_t}};
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
   GraphDef output;
   TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
@@ -757,6 +860,9 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
   TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
 
   EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
 TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
@@ -769,7 +875,6 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
   GrapplerItem item;
   item.fetch = {"outputs"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
   GraphDef output;
   TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
@@ -800,7 +905,10 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
   GrapplerItem item;
   item.fetch = {"outputs"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+  auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4}));
+  item.feed = {{"nchw_vect_c", x_t}};
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
   GraphDef output;
   TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
@@ -808,6 +916,9 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
   TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
 
   EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
 }
 
 TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
index 7faa68a..8d14663 100644 (file)
@@ -83,14 +83,6 @@ class ConstantFoldingTest : public GrapplerTest {
   }
 };
 
-template <DataType DTYPE>
-Tensor GetRandomTensor(const TensorShape& shape) {
-  typedef typename EnumToDataType<DTYPE>::Type T;
-  Tensor tensor(DTYPE, shape);
-  tensor.flat<T>() = tensor.flat<T>().random();
-  return tensor;
-}
-
 TEST_F(ConstantFoldingTest, SimpleFolding) {
   // Build a simple graph with a few trivially prunable ops.
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -380,11 +372,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(2, t.tensor_shape().dim(1).size());
       }
     }
-    auto a_t = GetRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
-    auto b_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
-    auto x_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
-    auto y_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
-    auto bias_t = GetRandomTensor<DT_FLOAT>(TensorShape({2}));
+    auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
+    auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+    auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+    auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
 
     auto tensors_expected = EvaluateNodes(
         item.graph, item.fetch,
index 3bc7bea..e1394b9 100644 (file)
@@ -57,6 +57,15 @@ class GrapplerTest : public ::testing::Test {
   // Count nodes of the given op-type in a graph.
   int CountOpNodes(const GraphDef& graph, const string& op);
 
+  // Get a random tansor with given shape.
+  template <DataType DTYPE>
+  Tensor GenerateRandomTensor(const TensorShape& shape) const {
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    Tensor tensor(DTYPE, shape);
+    tensor.flat<T>() = tensor.flat<T>().random();
+    return tensor;
+  }
+
  private:
   SessionOptions options_;
 };