Add TFLite v3 model console dumper (#171)
authorDmitry Mozolev/AI Tools Lab/Engineer/삼성전자 <d.mozolev@samsung.com>
Thu, 24 May 2018 13:27:33 +0000 (16:27 +0300)
committerSergey Vostokov/AI Tools Lab/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Thu, 24 May 2018 13:27:33 +0000 (16:27 +0300)
Using visitor interface to print TFLite v3 model contents.
Not all content is printed so far, but the code can easily be extended.

Signed-off-by: Dmitry Mozolev <d.mozolev@samsung.com>
contrib/nnc/libs/frontend/tflite/include/tflite_dump_visitor.h [new file with mode: 0644]
contrib/nnc/libs/frontend/tflite/src/tflite_dump_visitor.cpp [new file with mode: 0644]
contrib/nnc/libs/frontend/tflite/src/tflite_importer.cpp

diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_dump_visitor.h b/contrib/nnc/libs/frontend/tflite/include/tflite_dump_visitor.h
new file mode 100644 (file)
index 0000000..8b489f7
--- /dev/null
@@ -0,0 +1,41 @@
+#ifndef NNCC_TFLITE_DUMP_VISITOR_H
+#define NNCC_TFLITE_DUMP_VISITOR_H
+
+#include <vector>
+
+#include "schema_v3_generated.h"
+#include "tflite_visitor.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace frontend
+{
+namespace tflite
+{
+
+class DumpVisitor : public Visitor
+{
+public:
+  void visit(const Model *) override;
+  void visit(const SubGraph *) override;
+  void visit(const Tensor *) override;
+  void visit(const OperatorCode *) override;
+  void visit(const Operator *) override;
+  void visit(const Buffer *) override;
+
+private:
+  // TODO: add counter reset mechanism or restructure the code
+  int tensorCnt = 0;
+  int bufferCnt = 0;
+
+  std::vector<const char *> opNames;
+};
+
+} // namespace tflite
+} // namespace frontend
+} // namespace contrib
+} // namespace nnc
+
+#endif // NNCC_TFLITE_DUMP_VISITOR_H
diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_dump_visitor.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_dump_visitor.cpp
new file mode 100644 (file)
index 0000000..a0f4b93
--- /dev/null
@@ -0,0 +1,116 @@
+#include <iostream>
+
+#include "tflite_dump_visitor.h"
+
+using std::cout;
+using std::endl;
+
+static std::ostream &operator<<(std::ostream &os, const flatbuffers::Vector<int32_t> *v);
+static std::ostream &operator<<(std::ostream &os, Padding pad);
+static std::ostream &operator<<(std::ostream &os, ActivationFunctionType act);
+
+namespace nncc
+{
+namespace contrib
+{
+namespace frontend
+{
+namespace tflite
+{
+
+void DumpVisitor::visit(const Model *m)
+{
+  cout << "[Model version]: " << m->version() << endl;
+  cout << "[Model description]: " << m->description()->data() << endl;
+  cout << "[Model info]: " << m->subgraphs()->size() << " subgraphs" << endl;
+  cout << "[Model info]: " << m->buffers()->size() << " buffers" << endl;
+}
+
+void DumpVisitor::visit(const SubGraph *s)
+{
+  cout << "[Subgraph]: \"" << (s->name() ? s->name()->c_str() : "\0") << "\"" << endl;
+  cout << "[Subgraph inputs]: " << s->inputs() << endl;
+  cout << "[Subgraph outputs]: " << s->outputs() << endl;
+  cout << "[Subgraph info]: " << s->tensors()->size() << " tensors" << endl;
+}
+
+void DumpVisitor::visit(const Buffer *b)
+{
+  cout << "[Buffer " << bufferCnt++ << "]: size: " << (b->data() ? b->data()->size() : 0) << endl;
+}
+
+void DumpVisitor::visit(const Tensor *t)
+{
+  cout << "[Tensor " << tensorCnt++ << "]: \"" << t->name()->data() << "\"" << endl;
+  cout << "  [Tensor shape]: " << t->shape() << endl;
+  cout << "  [Tensor buffer]: " << t->buffer() << endl;
+}
+
+void DumpVisitor::visit(const Operator *op)
+{
+  cout << "[Operator]: " << opNames[op->opcode_index()] << endl;
+  cout << "  [Operator inputs]: " << op->inputs() << endl;
+  cout << "  [Operator outputs]: " << op->outputs() << endl;
+
+  switch (op->builtin_options_type())
+  {
+  case BuiltinOptions::BuiltinOptions_Conv2DOptions:
+  {
+    const Conv2DOptions *opts = op->builtin_options_as<Conv2DOptions>();
+    cout << "  [Padding]: " << opts->padding() << endl;
+    cout << "  [Strides]: " << opts->stride_w() << ", " << opts->stride_h() << endl;
+    cout << "  [Activation]: " << opts->fused_activation_function() << endl;
+    break;
+  }
+  case BuiltinOptions::BuiltinOptions_DepthwiseConv2DOptions:
+  {
+    const DepthwiseConv2DOptions *opts = op->builtin_options_as<DepthwiseConv2DOptions>();
+    cout << "  [Padding]: " << opts->padding() << endl;
+    cout << "  [Strides]: " << opts->stride_w() << ", " << opts->stride_h() << endl;
+    cout << "  [Activation]: " << opts->fused_activation_function() << endl;
+    cout << "  [DepthMultiplier]: " << opts->depth_multiplier() << endl;
+    break;
+  }
+  case BuiltinOptions::BuiltinOptions_ReshapeOptions:
+  {
+    const ReshapeOptions *opts = op->builtin_options_as<ReshapeOptions>();
+    cout << "  [New shape]: " << opts->new_shape() << endl;
+    break;
+  }
+  }
+}
+
+void DumpVisitor::visit(const OperatorCode *oc)
+{
+  opNames.push_back(EnumNamesBuiltinOperator()[oc->builtin_code()]);
+  cout << "[Model operator]: " << opNames.back() << endl;
+}
+
+} // namespace tflite
+} // namespace frontend
+} // namespace contrib
+} // namespace nncc
+
+static std::ostream &operator<<(std::ostream &os, const flatbuffers::Vector<int32_t> *v)
+{
+  for (int i = 0; i < v->size(); ++i)
+  {
+    if (i != 0)
+      os << ", ";
+    os << (*v)[i];
+  }
+
+  return os;
+}
+
+static std::ostream &operator<<(std::ostream &os, Padding pad)
+{
+  os << EnumNamesPadding()[pad];
+  return os;
+}
+
+static std::ostream &operator<<(std::ostream &os, ActivationFunctionType act)
+{
+  os << EnumNamesActivationFunctionType()[act];
+  return os;
+}
index dbe2122..a542bcd 100644 (file)
@@ -1,4 +1,6 @@
 #include "tflite_importer.h"
+#include "tflite_dump_visitor.h"
+#include "tflite_walker.h"
 
 namespace nncc
 {
@@ -55,7 +57,10 @@ void *TfliteImporter::createIR()
 
 void TfliteImporter::dump()
 {
-  // TODO: implement
+  DumpVisitor dumper{};
+  ModelWalker walker{std::vector<Visitor *>{&dumper}};
+
+  walker.walk(modelPacked);
 }
 
 } // namespace tflite