[moco-tf] Revise importer test (#5960)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 30 Jul 2019 00:18:00 +0000 (09:18 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 30 Jul 2019 00:18:00 +0000 (09:18 +0900)
This will revise importer test to do test with TFIdentity IR that uses Identity node as a test material

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Importer.test.cpp

index 771dbb5..770984b 100644 (file)
@@ -18,6 +18,9 @@
 
 #include "TestHelper.h"
 
+#include "IR/TFIdentity.h"
+#include "Op/Identity.h"
+
 #include <loco.h>
 #include <plier/tf/TestHelper.h>
 
@@ -79,7 +82,6 @@ node {
 
 TEST(TensorFlowImport, load_model_withio)
 {
-  moco::tf::Importer importer;
   moco::tf::ModelSignature signature;
 
   signature.add_input(moco::tf::TensorName("Placeholder", 0));
@@ -87,13 +89,25 @@ TEST(TensorFlowImport, load_model_withio)
 
   tensorflow::GraphDef graph_def;
   EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+  using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::Canonical>;
+
+  moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+  r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+  moco::tf::Importer importer{&r};
+
   std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
-  loco::Graph::InputContext *inputs = graph->inputs();
-  ASSERT_EQ(inputs->size(), 1);
-  loco::GraphInput *input = inputs->at(0);
-  ASSERT_EQ(input->dtype(), loco::DataType::FLOAT32);
-  loco::Pull *pull = input->node();
+  // what to test:
+  // - import reads Pull
+  // - import reads Forward
+  // - attribute values should match
+
+  auto pull = find_first_node_bytype<loco::Pull>(graph.get());
+  ASSERT_NE(pull, nullptr);
+  auto forward = find_first_node_bytype<loco::Forward>(graph.get());
+  ASSERT_NE(forward, nullptr);
+
   ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32);
   ASSERT_EQ(pull->rank(), 4);
   loco::Dimension dim1 = 1;
@@ -102,25 +116,33 @@ TEST(TensorFlowImport, load_model_withio)
   ASSERT_EQ(pull->dim(1).value(), dim2.value());
   ASSERT_EQ(pull->dim(2).value(), dim1.value());
   ASSERT_EQ(pull->dim(3).value(), dim2.value());
+}
+
+TEST(TensorFlowImport, load_model_withio_tf)
+{
+  moco::tf::ModelSignature signature;
+
+  signature.add_input(moco::tf::TensorName("Placeholder", 0));
+  signature.add_output(moco::tf::TensorName("output/identity", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+  using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+  moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+  // TODO add Placeholder
+  r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+  moco::tf::Importer importer{&r};
+
+  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // what to test:
+  // - import reads Placeholder
+  // - import reads Identity
+  // - attribute values should match
 
-  loco::Graph::OutputContext *outputs = graph->outputs();
-  ASSERT_EQ(outputs->size(), 1);
-  loco::GraphOutput *output = outputs->at(0);
-  ASSERT_EQ(output->dtype(), loco::DataType::FLOAT32);
-  loco::Push *push = output->node();
-  // Currently we don't know the shape of output node(s)
-  // ASSERT_EQ(push->rank(), 4);
-  // ASSERT_EQ(push->dim(0).value(), dim1.value());
-  // ASSERT_EQ(push->dim(1).value(), dim2.value());
-  // ASSERT_EQ(push->dim(2).value(), dim1.value());
-  // ASSERT_EQ(push->dim(3).value(), dim2.value());
-
-  loco::Graph::NodeContext *nodes = graph->nodes();
-  ASSERT_EQ(nodes->size(), 3);
-  loco::Pull *node0 = dynamic_cast<loco::Pull *>(nodes->at(0));
-  ASSERT_EQ(node0, pull);
-  loco::Push *node2 = dynamic_cast<loco::Push *>(nodes->at(2));
-  ASSERT_EQ(node2, push);
-  loco::Forward *node1 = dynamic_cast<loco::Forward *>(nodes->at(1));
-  ASSERT_NE(node1, nullptr);
+  auto tfidentity = find_first_node_bytype<moco::tf::TFIdentity>(graph.get());
+  ASSERT_NE(tfidentity, nullptr);
+  ASSERT_NE(tfidentity->input(), nullptr);
 }