Add TFLite importer plugin (#420)
authorDmitry Mozolev/AI Tools Lab /SRR/Engineer/삼성전자 <d.mozolev@samsung.com>
Tue, 3 Jul 2018 07:36:44 +0000 (10:36 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Tue, 3 Jul 2018 07:36:44 +0000 (16:36 +0900)
Add TFLite importer plugin

This commit introduces TFLite importer to the plugin system.

Signed-off-by: Dmitry Mozolev <d.mozolev@samsung.com>
contrib/nnc/libs/frontend/tflite/CMakeLists.txt
contrib/nnc/libs/frontend/tflite/src/tflite_plugin.cpp [new file with mode: 0644]

index c4c6e40..9965286 100644 (file)
@@ -16,7 +16,8 @@ set(tflite_importer_sources ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_walker.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_dump_visitor.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_ir_visitor.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_op_creator.cpp
-                            ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_v3_importer.cpp)
+                            ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_v3_importer.cpp
+                            ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_plugin.cpp)
 file(GLOB tflite_importer_headers include/*.h)
 list(APPEND tflite_importer_headers ${FB_GEN_SOURCES})
 
diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_plugin.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_plugin.cpp
new file mode 100644 (file)
index 0000000..eaf36d7
--- /dev/null
@@ -0,0 +1,109 @@
+#include <map>
+#include <vector>
+#include <iostream>
+
+#include "PluginType.h"
+#include "PluginInstance.h"
+#include "PluginException.h"
+#include "ConfigException.h"
+
+#include "tflite_v3_importer.h"
+
+namespace
+{
+const std::string pluginName = "Tensorflow Lite importer";
+const std::string pluginVersion = "0.0.1";
+const std::string pluginDesc = "Converts Tensorflow Lite v3 model to Model IR";
+
+const auto pluginType = nncc::contrib::plugin::typeFrontEnd;
+
+const auto inputFilenameOption = "input-filename";
+
+using namespace nncc::contrib::config;
+using namespace nncc::contrib::plugin;
+
+class ImporterPlugin : public AbstractPluginInstance
+{
+public:
+    ImporterPlugin &operator=(const ImporterPlugin &) = delete;
+    ImporterPlugin(const ImporterPlugin &) = delete;
+
+    static AbstractPluginInstance &getInstance();
+    void fillSession() override;
+    void checkConfig() override;
+    void *execute(void *data) override;
+
+    void setParam(const std::string &name) override;
+    void setParam(const std::string &name, const std::string &value) override;
+
+private:
+    std::string _filename;
+
+private:
+    ImporterPlugin() = default;
+    ~ImporterPlugin() override = default;
+};
+
+AbstractPluginInstance &ImporterPlugin::getInstance()
+{
+  static ImporterPlugin instance;
+  return instance;
+}
+
+void ImporterPlugin::fillSession()
+{
+  static std::map<std::string, std::string> info = {{"module description", pluginDesc}};
+
+  static std::vector<PluginParam> moduleParams =
+          {{inputFilenameOption, "path to Tensorflow Lite model file", false}};
+
+  AbstractPluginInstance::fillSessionBase(pluginType, pluginVersion, pluginName);
+
+  for (auto &i : info)
+    getSession()->addInfo(i.first, i.second);
+
+  for (auto &p : moduleParams)
+    getSession()->registerParam(p);
+}
+
+void ImporterPlugin::checkConfig()
+{
+}
+
+void *ImporterPlugin::execute(void *)
+{
+  nncc::contrib::frontend::tflite::v3::TfliteImporter importer{_filename};
+
+  bool success = importer.import();
+
+  if (!success)
+  {
+    throw nncc::contrib::PluginException("Could not load model: " + _filename + "\n");
+  };
+
+  return importer.createIR();
+}
+
+void ImporterPlugin::setParam(const std::string &name)
+{
+  throw nncc::contrib::ConfigException("unsupported parameter <" + name + ">");
+}
+
+void ImporterPlugin::setParam(const std::string &name, const std::string &value)
+{
+  if (name == inputFilenameOption)
+  {
+    _filename = value;
+  }
+  else
+  {
+    throw nncc::contrib::ConfigException("unsupported parameter <" + name + ">");
+  }
+}
+
+} // anonymous namespace
+
+extern "C" AbstractPluginInstance *get_instance()
+{
+  return &ImporterPlugin::getInstance();
+}