[exo-tflite] Test to check if transpose is exported OK (#8000)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Fri, 11 Oct 2019 06:59:11 +0000 (15:59 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 11 Oct 2019 06:59:11 +0000 (15:59 +0900)
* [exo-tflite] Test to check if transpose is exported OK

This adds two tests to see if transpose is exported OK.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* knob check

* revise comment for perm vector; removed unused vars

* add todo; replace 'if' to 'ASSERT_EQ'

compiler/exo-tflite/src/TFLExporterImpl.test.cpp
compiler/exo-tflite/src/TestGraph.h

index 14224b3..4746df9 100644 (file)
@@ -18,7 +18,9 @@
 
 #include "schema_generated.h"
 
+#include "TestGraph.h"
 #include "GraphBlock.h"
+#include "Knob.h"
 
 #include <loco/IR/PermutingCodec.h>
 #include <stdex/Memory.h>
@@ -141,7 +143,127 @@ TEST_F(TFLExporterImplTests, Relu6)
 
 // TODO TFLTanh
 
-// TODO TFLTranspose
+TEST(TFLExporterImplTest, Transpose_simple)
+{
+  // TODO Remove this check when we remove related-knobs
+  // Without setting these knob, exo::TFLExporter::Impl exporter{g.graph()} fails.
+  // For this reason, this check is needed.
+  // Note: the default value of these knobs is `true`
+  if (!(exo::get<exo::Knob::EnableTFLDialect>() && exo::get<exo::Knob::ConvertTensorTranspose>()))
+    return;
+
+  exo::test::ExampleGraph<exo::test::ExampleGraphType::Transpose> g;
+
+  // pull attribute
+  {
+    g.pull->dtype(loco::DataType::FLOAT32);
+    g.pull->shape({1, 2, 2, 3});
+  }
+
+  // transpose attribute
+  {
+    g.transpose->perm()->size(4);
+    g.transpose->perm()->axis(0) = 1;
+    g.transpose->perm()->axis(1) = 2;
+    g.transpose->perm()->axis(2) = 3;
+    g.transpose->perm()->axis(3) = 0;
+  }
+
+  exo::TFLExporter::Impl exporter{g.graph()};
+  {
+    auto model = tflite::GetModel(exporter.getBufferPointer());
+    auto operators = model->subgraphs()->Get(0)->operators();
+
+    assert(operators->Length() == 1);
+
+    int n = 0; // op index of Transpose in tflite file
+
+    auto opcode_index = operators->Get(n)->opcode_index();
+
+    ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+              tflite::BuiltinOperator_TRANSPOSE);
+
+    auto perm = operators->Get(n)->inputs()->Get(1);
+
+    auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+    ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+    ASSERT_EQ(perm_tensor->shape()->size(), 1);
+    ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+    auto bufs = (model->buffers());
+    auto *perm_buf =
+        reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+
+    ASSERT_EQ(perm_buf[0], 1);
+    ASSERT_EQ(perm_buf[1], 2);
+    ASSERT_EQ(perm_buf[2], 3);
+    ASSERT_EQ(perm_buf[3], 0);
+  }
+}
+
+/*
+  test case:
+        Pull ----- FeatureEncode ---- FeatureDecode --- Push
+          0 -----------> H ---------+      O              0
+          1              W          +----> H -----------> 1
+          2              I(depth)          W              2
+          3              O(coutn)          I              3
+
+      axis 0 ----------> H --------------> H -----------> 1
+      axis 1 ----------> W --------------> W -----------> 2
+      axis 2 ----------> I --------------> I -----------> 3
+      axis 3 ----------> O --------------> O -----------> 0
+
+      So, perm vector of Tranpose = [3, 0, 1, 2].
+      Please refer to loco::TensorTranspose about the definition of perm vector.
+*/
+TEST(TFLExporterImplTest, Transpose_from_FilterEncode_FilterDecode)
+{
+  // TODO Remove this check when we remove related-knobs
+  // Without setting these knob, exo::TFLExporter::Impl exporter{g.graph()} fails.
+  // For this reason, this check is needed.
+  // Note: the default value of these knobs is `true`
+  if (!(exo::get<exo::Knob::EnableTFLDialect>() && exo::get<exo::Knob::ConvertTensorTranspose>()))
+    return;
+
+  exo::test::ExampleGraph<exo::test::ExampleGraphType::FilterEncode_FilterDecode> g;
+
+  // pull attribute
+  {
+    g.pull->dtype(loco::DataType::FLOAT32);
+    g.pull->shape({1, 2, 3, 4}); // whatever value of rank 4
+  }
+
+  exo::TFLExporter::Impl exporter{g.graph()};
+  {
+    auto model = tflite::GetModel(exporter.getBufferPointer());
+    auto operators = model->subgraphs()->Get(0)->operators();
+
+    assert(operators->Length() == 1);
+
+    int n = 0; // op index of Transpose in tflite file
+
+    auto opcode_index = operators->Get(n)->opcode_index();
+
+    ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+              tflite::BuiltinOperator_TRANSPOSE);
+
+    auto perm = operators->Get(n)->inputs()->Get(1);
+
+    auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+    ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+    ASSERT_EQ(perm_tensor->shape()->size(), 1);
+    ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+    auto bufs = (model->buffers());
+    auto *perm_buf =
+        reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+    ASSERT_EQ(perm_buf[0], 3);
+    ASSERT_EQ(perm_buf[1], 0);
+    ASSERT_EQ(perm_buf[2], 1);
+    ASSERT_EQ(perm_buf[3], 2);
+  }
+}
 
 /**
  * What happens when there is a mismatch between generation and execution order!?
index 7afc9e9..9b5a2a2 100644 (file)
@@ -109,6 +109,7 @@ private:
   void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
   void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
   void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
+  void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); };
 
   void setInput(locoex::TFLAveragePool2D *node, loco::Node *input) { node->value(input); };
   void setInput(locoex::TFLMaxPool2D *node, loco::Node *input) { node->value(input); };
@@ -155,6 +156,8 @@ enum class ExampleGraphType
 {
   FeatureBiasAdd,
   ConstGen_ReLU,
+  FilterEncode_FilterDecode,
+  Transpose,
 
   TFLTranspose,
 };
@@ -220,6 +223,49 @@ public:
 };
 
 /**
+ * @brief Class to creates the following:
+ *
+ *     Pull -- Transpose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+{
+public:
+  loco::TensorTranspose *transpose = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    transpose = append<loco::TensorTranspose>(pull);
+    complete(transpose);
+  }
+
+  loco::Graph *graph() { return g.get(); }
+};
+
+/**
+ * @brief Class to creates the following:
+ *
+ *     Pull -- FilterEncode -- FilterDecode -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+{
+public:
+  loco::FilterEncode *filterEncode = nullptr;
+  loco::FilterDecode *filterDecode = nullptr;
+
+public:
+  ExampleGraph()
+  {
+    filterEncode = exo::make_filter_encode<exo::FilterLayout::HWIO>(pull); // from Tensorflow
+    filterDecode =
+        exo::make_filter_decode<exo::FilterLayout::OHWI>(filterEncode); // to Tensorflow Lite
+    complete(filterDecode);
+  }
+
+  loco::Graph *graph() { return g.get(); }
+};
+
+/**
  * @brief Class to create the following:
  *
  *     Pull -- TFLTranspose -- Push