[moco/tf] Revise AvgPool2D unit test (#4242)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 15 Jul 2019 01:58:09 +0000 (10:58 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 15 Jul 2019 01:58:09 +0000 (10:58 +0900)
This will revise AvgPool2D unit test to test only related properties

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Op/AvgPool2D.test.cpp

index 3cdc2b9..3373ac0 100644 (file)
@@ -126,39 +126,31 @@ TEST(TensorFlowImport, AvgPool2D_01)
   EXPECT_TRUE(parse_graphdef(avgpool2d_01_pbtxtdata, graph_def));
   std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
-  // test 1.
-  // loco node : ConstGen - FeatureEncode - AvgPool2D - FeatureDecode - Push
-  loco::Graph::NodeContext *loco_nodes = graph->nodes();
-
-  loco::Graph::InputContext *loco_inputs = graph->inputs();
-  ASSERT_EQ(loco_inputs->size(), 0);
-  ASSERT_EQ(loco_nodes->size(), 5);
+  // what to test:
+  // - there should exist AvgPool2D
+  // - input node should be FeatureEncode
+  // - following node should be FeatureDecode
+  // - stride values should match
+  // - window values should match
+
+  loco::AvgPool2D *avgpool_node =
+      moco::tf::test::find_first_node_bytype<loco::AvgPool2D>(graph.get());
+  ASSERT_NE(avgpool_node, nullptr);
 
-  int idx = 0;
+  loco::Node *previous_node = avgpool_node->ifm();
+  auto following_nodes = loco::succs(avgpool_node);
+  ASSERT_EQ(following_nodes.size(), 1);
+  loco::Node *following_node = *following_nodes.begin();
+  ASSERT_NE(following_node, nullptr);
 
-  loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_nodes->at(idx++));
-  loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(loco_nodes->at(idx++));
-  loco::AvgPool2D *avgpool_node = dynamic_cast<loco::AvgPool2D *>(loco_nodes->at(idx++));
-  loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(loco_nodes->at(idx++));
-  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_nodes->at(idx++));
+  loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(previous_node);
+  loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(following_node);
 
-  ASSERT_NE(const_node, nullptr);
   ASSERT_NE(enc_node, nullptr);
-  ASSERT_NE(avgpool_node, nullptr);
   ASSERT_NE(dec_node, nullptr);
-  ASSERT_NE(push_node, nullptr);
 
-  // check their connection is all OK
-  ASSERT_TRUE(enc_node->input() == const_node);
-  ASSERT_TRUE(avgpool_node->ifm() == enc_node);
-  ASSERT_TRUE(dec_node->input() == avgpool_node);
-  ASSERT_TRUE(push_node->from() == dec_node);
-
-  // test 2.
   // attrs inside AvgPool2D
-  auto avgpool2d = dynamic_cast<loco::AvgPool2D *>(loco_nodes->at(2));
-  ASSERT_NE(avgpool2d, nullptr);
-
+  auto avgpool2d = avgpool_node; // TODO remove this new variable
   // convention
   ASSERT_EQ(avgpool2d->convention(), loco::AvgPool2D::Convention::Valid);