#include "schema_generated.h"
+#include "TestGraph.h"
#include "GraphBlock.h"
+#include "Knob.h"
#include <loco/IR/PermutingCodec.h>
#include <stdex/Memory.h>
// TODO TFLTanh
-// TODO TFLTranspose
+TEST(TFLExporterImplTest, Transpose_simple)
+{
+ // TODO Remove this check when we remove related-knobs
+ // Without setting these knob, exo::TFLExporter::Impl exporter{g.graph()} fails.
+ // For this reason, this check is needed.
+ // Note: the default value of these knobs is `true`
+ if (!(exo::get<exo::Knob::EnableTFLDialect>() && exo::get<exo::Knob::ConvertTensorTranspose>()))
+ return;
+
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::Transpose> g;
+
+ // pull attribute
+ {
+ g.pull->dtype(loco::DataType::FLOAT32);
+ g.pull->shape({1, 2, 2, 3});
+ }
+
+ // transpose attribute
+ {
+ g.transpose->perm()->size(4);
+ g.transpose->perm()->axis(0) = 1;
+ g.transpose->perm()->axis(1) = 2;
+ g.transpose->perm()->axis(2) = 3;
+ g.transpose->perm()->axis(3) = 0;
+ }
+
+ exo::TFLExporter::Impl exporter{g.graph()};
+ {
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto operators = model->subgraphs()->Get(0)->operators();
+
+ assert(operators->Length() == 1);
+
+ int n = 0; // op index of Transpose in tflite file
+
+ auto opcode_index = operators->Get(n)->opcode_index();
+
+ ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+ tflite::BuiltinOperator_TRANSPOSE);
+
+ auto perm = operators->Get(n)->inputs()->Get(1);
+
+ auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+ ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+ ASSERT_EQ(perm_tensor->shape()->size(), 1);
+ ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+ auto bufs = (model->buffers());
+ auto *perm_buf =
+ reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+
+ ASSERT_EQ(perm_buf[0], 1);
+ ASSERT_EQ(perm_buf[1], 2);
+ ASSERT_EQ(perm_buf[2], 3);
+ ASSERT_EQ(perm_buf[3], 0);
+ }
+}
+
+/*
+ test case:
+ Pull ----- FeatureEncode ---- FeatureDecode --- Push
+ 0 -----------> H ---------+ O 0
+ 1 W +----> H -----------> 1
+ 2 I(depth) W 2
+ 3 O(coutn) I 3
+
+ axis 0 ----------> H --------------> H -----------> 1
+ axis 1 ----------> W --------------> W -----------> 2
+ axis 2 ----------> I --------------> I -----------> 3
+ axis 3 ----------> O --------------> O -----------> 0
+
+ So, perm vector of Tranpose = [3, 0, 1, 2].
+ Please refer to loco::TensorTranspose about the definition of perm vector.
+*/
+TEST(TFLExporterImplTest, Transpose_from_FilterEncode_FilterDecode)
+{
+ // TODO Remove this check when we remove related-knobs
+ // Without setting these knob, exo::TFLExporter::Impl exporter{g.graph()} fails.
+ // For this reason, this check is needed.
+ // Note: the default value of these knobs is `true`
+ if (!(exo::get<exo::Knob::EnableTFLDialect>() && exo::get<exo::Knob::ConvertTensorTranspose>()))
+ return;
+
+ exo::test::ExampleGraph<exo::test::ExampleGraphType::FilterEncode_FilterDecode> g;
+
+ // pull attribute
+ {
+ g.pull->dtype(loco::DataType::FLOAT32);
+ g.pull->shape({1, 2, 3, 4}); // whatever value of rank 4
+ }
+
+ exo::TFLExporter::Impl exporter{g.graph()};
+ {
+ auto model = tflite::GetModel(exporter.getBufferPointer());
+ auto operators = model->subgraphs()->Get(0)->operators();
+
+ assert(operators->Length() == 1);
+
+ int n = 0; // op index of Transpose in tflite file
+
+ auto opcode_index = operators->Get(n)->opcode_index();
+
+ ASSERT_EQ(model->operator_codes()->Get(opcode_index)->builtin_code(),
+ tflite::BuiltinOperator_TRANSPOSE);
+
+ auto perm = operators->Get(n)->inputs()->Get(1);
+
+ auto perm_tensor = model->subgraphs()->Get(0)->tensors()->Get(perm);
+ ASSERT_EQ(perm_tensor->type(), tflite::TensorType::TensorType_INT32);
+ ASSERT_EQ(perm_tensor->shape()->size(), 1);
+ ASSERT_EQ(perm_tensor->shape()->Get(0), 4);
+
+ auto bufs = (model->buffers());
+ auto *perm_buf =
+ reinterpret_cast<const int32_t *>(bufs->Get(perm_tensor->buffer())->data()->data());
+ ASSERT_EQ(perm_buf[0], 3);
+ ASSERT_EQ(perm_buf[1], 0);
+ ASSERT_EQ(perm_buf[2], 1);
+ ASSERT_EQ(perm_buf[3], 2);
+ }
+}
/**
* What happens when there is a mismatch between generation and execution order!?
void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); };
void setInput(locoex::TFLAveragePool2D *node, loco::Node *input) { node->value(input); };
void setInput(locoex::TFLMaxPool2D *node, loco::Node *input) { node->value(input); };
{
FeatureBiasAdd,
ConstGen_ReLU,
+ FilterEncode_FilterDecode,
+ Transpose,
TFLTranspose,
};
};
/**
+ * @brief Class to creates the following:
+ *
+ * Pull -- Transpose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+{
+public:
+ loco::TensorTranspose *transpose = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ transpose = append<loco::TensorTranspose>(pull);
+ complete(transpose);
+ }
+
+ loco::Graph *graph() { return g.get(); }
+};
+
+/**
+ * @brief Class to creates the following:
+ *
+ * Pull -- FilterEncode -- FilterDecode -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+{
+public:
+ loco::FilterEncode *filterEncode = nullptr;
+ loco::FilterDecode *filterDecode = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ filterEncode = exo::make_filter_encode<exo::FilterLayout::HWIO>(pull); // from Tensorflow
+ filterDecode =
+ exo::make_filter_decode<exo::FilterLayout::OHWI>(filterEncode); // to Tensorflow Lite
+ complete(filterDecode);
+ }
+
+ loco::Graph *graph() { return g.get(); }
+};
+
+/**
* @brief Class to create the following:
*
* Pull -- TFLTranspose -- Push