EXPECT_TRUE(parse_graphdef(avgpool2d_01_pbtxtdata, graph_def));
std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
- // test 1.
- // loco node : ConstGen - FeatureEncode - AvgPool2D - FeatureDecode - Push
- loco::Graph::NodeContext *loco_nodes = graph->nodes();
-
- loco::Graph::InputContext *loco_inputs = graph->inputs();
- ASSERT_EQ(loco_inputs->size(), 0);
- ASSERT_EQ(loco_nodes->size(), 5);
+ // what to test:
+ // - there should exist AvgPool2D
+ // - input node should be FeatureEncode
+ // - following node should be FeatureDecode
+ // - stride values should match
+ // - window values should match
+
+ loco::AvgPool2D *avgpool_node =
+ moco::tf::test::find_first_node_bytype<loco::AvgPool2D>(graph.get());
+ ASSERT_NE(avgpool_node, nullptr);
- int idx = 0;
+ loco::Node *previous_node = avgpool_node->ifm();
+ auto following_nodes = loco::succs(avgpool_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
- loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_nodes->at(idx++));
- loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(loco_nodes->at(idx++));
- loco::AvgPool2D *avgpool_node = dynamic_cast<loco::AvgPool2D *>(loco_nodes->at(idx++));
- loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(loco_nodes->at(idx++));
- loco::Push *push_node = dynamic_cast<loco::Push *>(loco_nodes->at(idx++));
+ loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(previous_node);
+ loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(following_node);
- ASSERT_NE(const_node, nullptr);
ASSERT_NE(enc_node, nullptr);
- ASSERT_NE(avgpool_node, nullptr);
ASSERT_NE(dec_node, nullptr);
- ASSERT_NE(push_node, nullptr);
- // check their connection is all OK
- ASSERT_TRUE(enc_node->input() == const_node);
- ASSERT_TRUE(avgpool_node->ifm() == enc_node);
- ASSERT_TRUE(dec_node->input() == avgpool_node);
- ASSERT_TRUE(push_node->from() == dec_node);
-
- // test 2.
// attrs inside AvgPool2D
- auto avgpool2d = dynamic_cast<loco::AvgPool2D *>(loco_nodes->at(2));
- ASSERT_NE(avgpool2d, nullptr);
-
+ auto avgpool2d = avgpool_node; // TODO remove this new variable
// convention
ASSERT_EQ(avgpool2d->convention(), loco::AvgPool2D::Convention::Valid);