[exo] Add tests for TFLConcatenate (#8273)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Thu, 17 Oct 2019 10:15:17 +0000 (19:15 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 17 Oct 2019 10:15:17 +0000 (19:15 +0900)
This commit will add tests for `TFLConcatenate` in `exo`

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
compiler/exo/src/Dialect/IR/TFLNodes.test.cpp
compiler/exo/src/TFLite/TFLExporterImpl.test.cpp

index f04505b..f4289cd 100644 (file)
@@ -34,7 +34,19 @@ TEST(TFLAddTest, constructor)
 
 // TODO TFLAveragePool2D
 
-// TODO TFLConcatenation
+TEST(TFLConcatTest, constructor)
+{
+  locoex::TFLConcatenation concat_node(3);
+
+  ASSERT_EQ(concat_node.dialect(), locoex::TFLDialect::get());
+  ASSERT_EQ(concat_node.opcode(), locoex::TFLOpcode::CONCATENATION);
+
+  ASSERT_EQ(concat_node.numValues(), 3);
+  ASSERT_EQ(concat_node.values(0), nullptr);
+  ASSERT_EQ(concat_node.values(1), nullptr);
+  ASSERT_EQ(concat_node.values(2), nullptr);
+  ASSERT_EQ(concat_node.fusedActivationFunction(), locoex::FusedActFunc::UNDEFINED);
+}
 
 // TODO TFLConv2D
 
index 4746df9..25fa721 100644 (file)
@@ -86,7 +86,49 @@ template <> loco::FeatureDecode *TFLExporterImplTests::make_node(void)
 
 // TODO TFLAveragePool2D
 
-// TODO TFLConcatenation
+TEST_F(TFLExporterImplTests, Concatenate)
+{
+  auto pull1 = make_node<loco::Pull>();
+  {
+    pull1->dtype(loco::DataType::FLOAT32);
+    pull1->shape({1, 2, 3, 4});
+  }
+  auto pull2 = make_node<loco::Pull>();
+  {
+    pull2->dtype(loco::DataType::FLOAT32);
+    pull2->shape({1, 2, 3, 4});
+  }
+  auto concat = make_node<loco::TensorConcat>();
+  {
+    concat->lhs(pull1);
+    concat->rhs(pull2);
+  }
+  auto push = make_node<loco::Push>();
+  {
+    push->from(concat);
+  }
+
+  auto input1 = graph()->inputs()->create();
+  {
+    input1->name("input1");
+    loco::link(input1, pull1);
+  }
+  auto input2 = graph()->inputs()->create();
+  {
+    input2->name("input2");
+    loco::link(input2, pull2);
+  }
+  auto output = graph()->outputs()->create();
+  {
+    output->name("output");
+    loco::link(output, push);
+  }
+
+  exo::TFLExporter::Impl exporter{graph()};
+
+  // TODO Add more checks
+  SUCCEED();
+}
 
 // TODO TFLConv2D