[nnc] Create interface for NN importers (#2728)
authorРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Wed, 19 Dec 2018 18:20:47 +0000 (21:20 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 19 Dec 2018 18:20:47 +0000 (21:20 +0300)
* added factory method to create importer objects
* added template method to manage `run` method for all importers
* ONNX importer wasn't changed

Signed-off-by: Roman Rusyaev <r.rusyaev@samsung.com>
19 files changed:
contrib/nnc/driver/Driver.cpp
contrib/nnc/include/passes/common_frontend/NNImporter.h [moved from contrib/nnc/include/passes/common_frontend/nn_importer.h with 57% similarity]
contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_importer.h [moved from contrib/nnc/include/passes/caffe2_frontend/caffe2_importer.h with 90% similarity]
contrib/nnc/passes/caffe2_frontend/caffe2_op_types.h
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_importer.h [moved from contrib/nnc/include/passes/caffe_frontend/caffe_importer.h with 74% similarity]
contrib/nnc/passes/caffe_frontend/caffe_op_types.h
contrib/nnc/passes/common_frontend/CMakeLists.txt
contrib/nnc/passes/common_frontend/NNImporter.cpp [new file with mode: 0644]
contrib/nnc/passes/tflite_frontend/CMakeLists.txt
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_importer.h [moved from contrib/nnc/include/passes/tflite_frontend/tflite_importer.h with 88% similarity]
contrib/nnc/tests/import/caffe.cpp
contrib/nnc/tests/import/tflite.cpp
contrib/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp
contrib/nnc/utils/caffe2_dot_dumper/model_dump.cpp
contrib/nnc/utils/caffe_dot_dumper/model_dump.cpp
contrib/nnc/utils/tflite_dot_dumper/sanity_check.cpp

index f8798e1..ee57a3a 100644 (file)
 
 #include "pass/PassData.h"
 
-#include "passes/caffe_frontend/caffe_importer.h"
-#include "passes/caffe2_frontend/caffe2_importer.h"
-#include "passes/tflite_frontend/tflite_importer.h"
+#include "passes/common_frontend/NNImporter.h"
+
 #include "passes/interpreter/InterpreterPass.h"
 #include "passes/soft_backend/CPPGenerator.h"
 #include "passes/acl_soft_backend/AclCppGenerator.h"
-#include "passes/onnx_frontend/ONNXImporter.h"
 #include "support/CommandLine.h"
 #include "Definitions.h"
 #include "option/Options.h"
@@ -52,21 +50,17 @@ void Driver::runPasses() {
 static std::string getFrontendOptionsString() {
   std::string res;
 
-#ifdef NNC_FRONTEND_CAFFE_ENABLED
-  res += " '" + cli::caffeFrontend.getNames()[0] + "' ";
-#endif // NNC_FRONTEND_CAFFE_ENABLED
+  if (!cli::caffeFrontend.isDisabled())
+    res += " '" + cli::caffeFrontend.getNames()[0] + "' ";
 
-#ifdef NNC_FRONTEND_CAFFE2_ENABLED
-  res += " '" + cli::caffe2Frontend.getNames()[0] + "' ";
-#endif // NNC_FRONTEND_CAFFE2_ENABLED
+  if (!cli::caffe2Frontend.isDisabled())
+    res += " '" + cli::caffe2Frontend.getNames()[0] + "' ";
 
-#ifdef NNC_FRONTEND_ONNX_ENABLED
-  res += " '" + cli::onnxFrontend.getNames()[0] + "' ";
-#endif // NNC_FRONTEND_ONNX_ENABLED
+  if (!cli::onnxFrontend.isDisabled())
+    res += " '" + cli::onnxFrontend.getNames()[0] + "' ";
 
-#ifdef NNC_FRONTEND_TFLITE_ENABLED
-  res += " '" + cli::tflFrontend.getNames()[0] + "' ";
-#endif // NNC_FRONTEND_TFLITE_ENABLED
+  if (!cli::tflFrontend.isDisabled())
+    res += " '" + cli::tflFrontend.getNames()[0] + "' ";
 
   return res;
 }
@@ -77,40 +71,20 @@ static std::string getFrontendOptionsString() {
  */
 void Driver::registerFrontendPass() {
 
-  std::unique_ptr<Pass> pass;
-
   // For bool, the value false is converted to zero and the value true is converted to one
   if (cli::caffeFrontend + cli::caffe2Frontend + cli::tflFrontend + cli::onnxFrontend != 1)
     throw DriverException("One and only one of the following options are allowed and have to be set"
                           " to be set in the same time: " + getFrontendOptionsString());
 
-  if (cli::caffeFrontend) {
-#ifdef NNC_FRONTEND_CAFFE_ENABLED
-    pass = std::move(std::unique_ptr<Pass>(new CaffeImporter(cli::inputFile)));
-#endif // NNC_FRONTEND_CAFFE_ENABLED
-  } else if (cli::caffe2Frontend) {
-#ifdef NNC_FRONTEND_CAFFE2_ENABLED
-    // FIXME: caffe2 input shapes are not provided by model and must be set from cli
-    // current 'inputShapes' could provide only one shape, while model could has several inputs
-    pass = std::move(std::unique_ptr<Pass>(new Caffe2Importer(cli::inputFile, cli::initNet,
-                                                              {cli::inputShapes})));
-#endif // NNC_FRONTEND_CAFFE2_ENABLED
-  } else if ( cli::onnxFrontend ) {
-#ifdef NNC_FRONTEND_ONNX_ENABLED
-    pass = std::move(std::unique_ptr<Pass>(new ONNXImporter()));
-#endif // NNC_FRONTEND_ONNX_ENABLED
-  }
-  else if ( cli::tflFrontend ) {
-#ifdef NNC_FRONTEND_TFLITE_ENABLED
-    pass = std::move(std::unique_ptr<Pass>(new TfliteImporter(cli::inputFile)));
-#endif // NNC_FRONTEND_TFLITE_ENABLED
+  std::unique_ptr<Pass> pass = NNImporter::createNNImporter();
+
+  if (pass) {
+    _passManager.registerPass(std::move(pass));
   } else {
     throw DriverException("One of the following options must be defined: '"
                           + getFrontendOptionsString());
   }
 
-  _passManager.registerPass(std::move(pass));
-
 } // registerFrontendPass
 
 /**
 #ifndef FRONTEND_COMMON_INCLUDE_NN_IMPORTER_
 #define FRONTEND_COMMON_INCLUDE_NN_IMPORTER_
 
+#include "pass/Pass.h"
 #include "core/modelIR/Graph.h"
 
 namespace nnc {
 
-class NNImporter {
+/**
+ * @brief Interface for all frontends. All who uses frontends must do it thought this interface
+ */
+class NNImporter : public Pass {
 public:
-  virtual ~NNImporter() = default;
+  // template method pattern
+  PassData run(PassData) final {
+    import();
+    return createIR();
+  }
+
+  static std::unique_ptr<NNImporter> createNNImporter();
 
+  void cleanup() override {}
+
+  /**
+  * @brief Import model from file, must be called before 'createIR' method
+  * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it
+  */
   virtual void import() = 0;
+
+  /**
+   * @brief Create MIR graph from caffe model, must be called after 'import' method
+   * @return MIR graph, corresponding to processed caffe model
+   */
   virtual mir::Graph *createIR() = 0;
 };
 
index 1b874c3..e2a5fa8 100644 (file)
 #include <sstream>
 #include <cassert>
 
-#include "passes/caffe2_frontend/caffe2_importer.h"
+#include "caffe2_importer.h"
 #include "passes/common_frontend/shape_helper.h"
 #include "passes/common_frontend/proto_helper.h"
 
-#include "caffe2/proto/caffe2.pb.h"
-
 #include "caffe2_op_types.h"
 #include "caffe2_op_creator.h"
 
@@ -52,11 +50,6 @@ Caffe2Importer::Caffe2Importer(std::string predictNet, std::string initNet,
 
 Caffe2Importer::~Caffe2Importer() = default;
 
-PassData Caffe2Importer::run(PassData) {
-  import();
-  return createIR();
-}
-
 void Caffe2Importer::cleanup() {
   delete _graph;
 }
 #include <string>
 #include <memory>
 
-#include "passes/common_frontend/nn_importer.h"
+#include "caffe2/proto/caffe2.pb.h"
+#include "caffe2_op_creator.h"
+#include "caffe2_op_types.h"
+#include "passes/common_frontend/NNImporter.h"
 
 #include "pass/Pass.h"
 #include "pass/PassData.h"
 
-// Use forward declarations for non interface classes
-namespace caffe2 {
-class OperatorDef;
-class NetDef;
-}
 namespace nnc {
-class Caffe2OpCreator;
-enum class SupportedCaffe2OpType : uint8_t;
-}
 
-namespace nnc {
-
-class Caffe2Importer : public NNImporter, public Pass {
+class Caffe2Importer : public NNImporter {
 public:
   explicit Caffe2Importer(std::string predictNet, std::string initNet,
                           std::vector<std::vector<int>> inputShapes);
@@ -56,7 +49,6 @@ public:
   */
   mir::Graph* createIR() override;
 
-  PassData run(PassData) override;
   void cleanup() override;
 
   ~Caffe2Importer();
index 2a1de48..414bd87 100644 (file)
@@ -19,7 +19,7 @@
 
 namespace nnc {
 
-enum class SupportedCaffe2OpType : uint8_t {
+enum class SupportedCaffe2OpType {
   add,
   averagePool,
   concat,
index cc4a8eb..149c90c 100644 (file)
@@ -19,9 +19,7 @@
 #include <sstream>
 #include <cassert>
 
-#include "caffe/proto/caffe.pb.h"
-
-#include "passes/caffe_frontend/caffe_importer.h"
+#include "caffe_importer.h"
 #include "caffe_op_creator.h"
 #include "caffe_op_types.h"
 
@@ -220,11 +218,6 @@ void CaffeImporter::setGraphOutputs() {
   _graph->markOutput(_blobNameToIODescriptor[last_layer.top(0)].op);
 }
 
-PassData CaffeImporter::run(PassData) {
-  import();
-  return createIR();
-}
-
 void CaffeImporter::cleanup() {
   delete _graph;
 }
 #include <string>
 #include <memory>
 
-#include "passes/common_frontend/nn_importer.h"
-
-#include "pass/Pass.h"
-#include "pass/PassData.h"
-
-// Use forward declarations for non interface classes
-namespace caffe {
-class BlobProto;
-class LayerParameter;
-class NetParameter;
-}
-namespace nnc {
-class CaffeOpCreator;
-enum class CaffeOpType : uint8_t;
-}
+#include "caffe/proto/caffe.pb.h"
+#include "caffe_op_creator.h"
+#include "caffe_op_types.h"
+#include "passes/common_frontend/NNImporter.h"
 
 namespace nnc {
 
-class CaffeImporter : public NNImporter, public Pass {
+class CaffeImporter : public NNImporter {
 public:
   explicit CaffeImporter(std::string filename);
 
-  /**
-  * @brief Import model from file, must be called before 'createIR' method
-  * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it
-  */
   void import() override;
-
-  /**
-  * @brief Create MIR graph from caffe model, must be called after 'import' method
-  * @return MIR graph, corresponding to processed caffe model
-  */
   mir::Graph* createIR() override;
 
-  PassData run(PassData) override;
   void cleanup() override;
 
   ~CaffeImporter();
index 1e76f4e..887112b 100644 (file)
@@ -19,7 +19,7 @@
 
 namespace nnc {
 
-enum class CaffeOpType : uint8_t {
+enum class CaffeOpType {
   absVal,
   accuracy,
   argMax,
index cb3cc2e..d9508c3 100644 (file)
@@ -7,3 +7,9 @@ set(COMMON_SOURCES model_allocation.cpp op_creator_helper.cpp)
 add_library(nn_import_common STATIC ${COMMON_SOURCES})
 set_target_properties(nn_import_common PROPERTIES POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(nn_import_common PRIVATE nnc_core nnc_support)
+
+# This library depends on other frontends to provide uniform interface for those who use frontends
+set(IMPORTER_SUPPORT_SOURCES NNImporter.cpp)
+add_nnc_library(nn_importer_support STATIC ${IMPORTER_SUPPORT_SOURCES})
+target_include_directories(nn_importer_support PRIVATE ${NNC_CAFFE_FRONTEND_DIR} ${NNC_CAFFE2_FRONTEND_DIR} ${NNC_TFLITE_FRONTEND_DIR})
+target_link_libraries(nn_importer_support caffe_importer caffe2_importer tflite_importer)
diff --git a/contrib/nnc/passes/common_frontend/NNImporter.cpp b/contrib/nnc/passes/common_frontend/NNImporter.cpp
new file mode 100644 (file)
index 0000000..0654a8d
--- /dev/null
@@ -0,0 +1,41 @@
+#include <memory>
+
+#include "Definitions.h"
+#include "option/Options.h"
+#include "passes/common_frontend/NNImporter.h"
+
+#include "tflite_importer.h"
+#include "caffe_importer.h"
+#include "caffe2_importer.h"
+
+namespace nnc {
+
+std::unique_ptr<NNImporter> NNImporter::createNNImporter() {
+
+  std::unique_ptr<NNImporter> importer(nullptr);
+
+  if (cli::caffeFrontend) {
+#ifdef NNC_FRONTEND_CAFFE_ENABLED
+    importer.reset(new CaffeImporter(cli::inputFile));
+#endif // NNC_FRONTEND_CAFFE_ENABLED
+  } else if (cli::caffe2Frontend) {
+#ifdef NNC_FRONTEND_CAFFE2_ENABLED
+    // FIXME: caffe2 input shapes are not provided by model and must be set from cli
+    // current 'inputShapes' could provide only one shape, while model could has several inputs
+    importer.reset(new Caffe2Importer(cli::inputFile, cli::initNet, {cli::inputShapes}));
+#endif // NNC_FRONTEND_CAFFE2_ENABLED
+  } else if (cli::onnxFrontend) {
+#ifdef NNC_FRONTEND_ONNX_ENABLED
+    importer.reset(new ONNXImporter());
+#endif // NNC_FRONTEND_ONNX_ENABLED
+  } else if (cli::tflFrontend) {
+#ifdef NNC_FRONTEND_TFLITE_ENABLED
+    importer.reset(new TfliteImporter(cli::inputFile));
+#endif // NNC_FRONTEND_TFLITE_ENABLED
+  }
+
+  return importer;
+
+} // createNNImporter
+
+} // namespace nnc
index fce09e3..2b6d5af 100644 (file)
@@ -26,14 +26,13 @@ set(tflite_importer_sources tflite_op_creator.cpp
                             tflite_importer.cpp)
 file(GLOB tflite_importer_headers *.h)
 
-set(tflite_import tflite_import)
-add_nnc_library(${tflite_import} SHARED ${tflite_importer_sources} ${tflite_importer_headers})
+add_nnc_library(tflite_importer SHARED ${tflite_importer_sources} ${tflite_importer_headers})
 
-target_link_libraries(${tflite_import} PUBLIC tflite_schema)
-target_link_libraries(${tflite_import} PUBLIC flatbuffers)
-target_link_libraries(${tflite_import} PUBLIC nn_import_common)
-target_link_libraries(${tflite_import} PUBLIC nnc_support)
-target_link_libraries(${tflite_import} PUBLIC nnc_core)
+target_link_libraries(tflite_importer PUBLIC tflite_schema)
+target_link_libraries(tflite_importer PUBLIC flatbuffers)
+target_link_libraries(tflite_importer PUBLIC nn_import_common)
+target_link_libraries(tflite_importer PUBLIC nnc_support)
+target_link_libraries(tflite_importer PUBLIC nnc_core)
 
 # install tflite frontend library
-install_nnc_library(tflite_import)
+install_nnc_library(tflite_importer)
index f990c59..4b8296b 100644 (file)
@@ -15,7 +15,7 @@
  */
 
 #include "schema_generated.h"
-#include "passes/tflite_frontend/tflite_importer.h"
+#include "tflite_importer.h"
 #include "tflite_op_creator.h"
 
 using namespace ::tflite;
@@ -371,12 +371,6 @@ void TfliteImporter::setIrNodeNames() {
     item.second->setName((*_tensors)[item.first]->name()->c_str());
 }
 
-
-PassData TfliteImporter::run(PassData) {
-  import();
-  return createIR();
-}
-
 void TfliteImporter ::cleanup() {
   delete _graph;
 }
 #include "pass/Pass.h"
 #include "pass/PassException.h"
 #include "pass/PassData.h"
-#include "passes/common_frontend/nn_importer.h"
+#include "passes/common_frontend/NNImporter.h"
 #include "passes/common_frontend/model_allocation.h"
+#include "tflite_op_creator.h"
 
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/TensorUtil.h"
 #include "core/modelIR/TensorVariant.h"
 
-// Use forward declarations for non interface classes
-namespace flatbuffers {
-template<typename T> class Vector;
-template<typename T> struct Offset;
-}
-namespace tflite {
-struct Buffer;
-struct Model;
-struct ModelT;
-struct Operator;
-struct OperatorCode;
-struct SubGraph;
-struct Tensor;
-}
 namespace nnc {
-class TFLiteOpCreator;
-}
 
-namespace nnc {
-
-class TfliteImporter : public NNImporter, public Pass {
+class TfliteImporter : public NNImporter {
 public:
   explicit TfliteImporter(std::string filename);
 
@@ -69,7 +52,6 @@ public:
 
   void importUnpacked();
 
-  PassData run(PassData) override;
   void cleanup() override;
 
   ~TfliteImporter();
index ff0d66b..99db79f 100644 (file)
@@ -18,7 +18,7 @@
 #include "support/CommandLine.h"
 #include "option/Options.h"
 
-#include "passes/caffe_frontend/caffe_importer.h"
+#include "caffe_importer.h"
 
 using namespace nnc;
 
index bebc32e..0f20921 100644 (file)
@@ -18,7 +18,7 @@
 #include "support/CommandLine.h"
 #include "option/Options.h"
 
-#include "passes/tflite_frontend/tflite_importer.h"
+#include "tflite_importer.h"
 
 using namespace nnc;
 
index 86b0ea7..c3ee986 100644 (file)
@@ -1,4 +1,4 @@
-#include "passes/caffe_frontend/caffe_importer.h"
+#include "caffe_importer.h"
 #include "gtest/gtest.h"
 #include "pass/PassException.h"
 #include <string>
index ba4e86a..d0f9ea8 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "support/CommandLine.h"
 #include "option/Options.h"
-#include "passes/caffe2_frontend/caffe2_importer.h"
+#include "caffe2_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"
 #include "pass/PassException.h"
index 997d657..62d60e6 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "support/CommandLine.h"
 #include "option/Options.h"
-#include "passes/caffe_frontend/caffe_importer.h"
+#include "caffe_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"
 #include "pass/PassException.h"
index 481da52..be9d07a 100644 (file)
@@ -19,7 +19,7 @@
 #include "support/CommandLine.h"
 #include "pass/PassException.h"
 #include "option/Options.h"
-#include "passes/tflite_frontend/tflite_importer.h"
+#include "tflite_importer.h"
 #include "core/modelIR/Graph.h"
 #include "core/modelIR/IrDotDumper.h"