[nnc] Join `import` and `createIR` methods into `importModel` (#5984)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Mon, 29 Jul 2019 11:35:55 +0000 (14:35 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 29 Jul 2019 11:35:55 +0000 (14:35 +0300)
Methods `import` and `createIR` are always used together. Join them to simplify usage.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
14 files changed:
compiler/nnc/include/passes/common_frontend/NNImporter.h
compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp
compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.h
compiler/nnc/passes/caffe_frontend/caffe_importer_pass.cpp
compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h
compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp
compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h
compiler/nnc/passes/tflite_frontend/tflite_importer_pass.cpp
compiler/nnc/passes/tflite_frontend/tflite_importer_pass.h
compiler/nnc/tests/import/caffe.cpp
compiler/nnc/tests/import/tflite.cpp
compiler/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp
compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp
compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp

index e4c866b..33526f4 100644 (file)
@@ -29,25 +29,14 @@ class NNImporter : public Pass {
 public:
   // template method pattern
   PassData run(PassData /*data*/) final {
-    import();
-    return createIR();
+    return importModel();
   }
 
   static std::unique_ptr<NNImporter> createNNImporter();
 
   void cleanup() override {}
 
-  /**
-  * @brief Import model from file, must be called before 'createIR' method
-  * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it
-  */
-  virtual void import() = 0;
-
-  /**
-   * @brief Create MIR graph from caffe model, must be called after 'import' method
-   * @return MIR graph, corresponding to processed caffe model
-   */
-  virtual mir::Graph *createIR() = 0;
+  virtual mir::Graph *importModel() = 0;
 
   std::string getName() override { return "importer";};
 };
index b3d8cf0..4e49809 100644 (file)
@@ -26,22 +26,11 @@ Caffe2ImporterPass::Caffe2ImporterPass(const std::string &predict_net, const std
 {
 }
 
-void Caffe2ImporterPass::import()
+mir::Graph *Caffe2ImporterPass::importModel()
 {
   try
   {
     _pimpl->import();
-  }
-  catch (const std::exception &e)
-  {
-    throw PassException(e.what());
-  }
-}
-
-mir::Graph *Caffe2ImporterPass::createIR()
-{
-  try
-  {
     return _pimpl->createIR();
   }
   catch (const std::exception &e)
index 5b4f1a4..0a7d743 100644 (file)
@@ -30,17 +30,7 @@ public:
   explicit Caffe2ImporterPass(const std::string &predict_net, const std::string &init_net,
                               const std::vector<std::vector<int>> &input_shapes);
 
-  /**
-   * @brief Import model from file, must be called before 'createIR' method
-   * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it
-   */
-  void import() override;
-
-  /**
-   * @brief Create MIR graph from caffe model, must be called after 'import' method
-   * @return MIR graph, corresponding to processed caffe model
-   */
-  mir::Graph *createIR() override;
+  mir::Graph *importModel() override;
 
   void cleanup() override;
 
index 6adb856..3bb7d51 100644 (file)
@@ -26,23 +26,11 @@ CaffeImporterPass::CaffeImporterPass(const std::string &filename)
 {
 }
 
-void CaffeImporterPass::import()
+mir::Graph *CaffeImporterPass::importModel()
 {
   try
   {
     _pimpl->import();
-  }
-  catch (const std::exception &e)
-  {
-    throw PassException(e.what());
-  }
-}
-
-mir::Graph *CaffeImporterPass::createIR()
-{
-  try
-  {
-
     return _pimpl->createIR();
   }
   catch (const std::exception &e)
index 29615e2..34f1e8e 100644 (file)
@@ -29,9 +29,7 @@ class CaffeImporterPass : public NNImporter
 public:
   explicit CaffeImporterPass(const std::string &filename);
 
-  void import() override;
-
-  mir::Graph *createIR() override;
+  mir::Graph *importModel() override;
 
   void cleanup() override;
 
index e8497ea..e8bb064 100644 (file)
@@ -26,22 +26,11 @@ ONNXImporterPass::ONNXImporterPass(const std::string &filename)
 {
 }
 
-void ONNXImporterPass::import()
+mir::Graph *ONNXImporterPass::importModel()
 {
   try
   {
     _pimpl->import();
-  }
-  catch (const std::exception &e)
-  {
-    throw PassException(e.what());
-  }
-}
-
-mir::Graph *ONNXImporterPass::createIR()
-{
-  try
-  {
     return _pimpl->createIR();
   }
   catch (const std::exception &e)
index 6089fe0..322cddf 100644 (file)
@@ -29,9 +29,7 @@ class ONNXImporterPass : public NNImporter
 public:
   explicit ONNXImporterPass(const std::string &filename);
 
-  void import() override;
-
-  mir::Graph *createIR() override;
+  mir::Graph *importModel() override;
 
   ~ONNXImporterPass() override;
 
index fff3a20..f70a4d1 100644 (file)
@@ -26,22 +26,11 @@ TfliteImporterPass::TfliteImporterPass(const std::string &filename)
 {
 }
 
-void TfliteImporterPass::import()
+mir::Graph *TfliteImporterPass::importModel()
 {
   try
   {
     _pimpl->import();
-  }
-  catch (const std::exception &e)
-  {
-    throw PassException(e.what());
-  }
-}
-
-mir::Graph *TfliteImporterPass::createIR()
-{
-  try
-  {
     return _pimpl->createIR();
   }
   catch (const std::exception &e)
index c70f1db..ad55fdb 100644 (file)
@@ -29,17 +29,7 @@ class TfliteImporterPass : public NNImporter
 public:
   explicit TfliteImporterPass(const std::string &filename);
 
-  /**
-   * @brief Import model from file, must be called before 'createIR' method
-   * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it
-   */
-  void import() override;
-
-  /**
-   * @brief Create MIR graph from caffe model, must be called after 'import' method
-   * @return MIR graph, corresponding to processed caffe model
-   */
-  mir::Graph *createIR() override;
+  mir::Graph *importModel() override;
 
   void cleanup() override;
 
index f480da9..88a42d4 100644 (file)
@@ -30,10 +30,8 @@ int main(int argc, const char **argv) {
 
   nnc::CaffeImporterPass importer{cli::inputFile};
 
-  importer.import();
-
   try {
-    importer.createIR();
+    importer.importModel();
   }
   catch (...) {
     std::cout << "Could not create IR for model \"" << cli::inputFile << "\"" << std::endl;
index e0f202a..a5fe85e 100644 (file)
@@ -32,11 +32,9 @@ int main(int argc, const char **argv)
   cli::CommandLine::getParser()->parseCommandLine(argc, argv);
   nnc::TfliteImporterPass importer{cli::inputFile};
 
-  importer.import();
-
   try
   {
-    importer.createIR();
+    importer.importModel();
   }
   catch (...)
   {
index 7f93c16..0e0c387 100644 (file)
@@ -16,8 +16,7 @@ TEST(CAFFE_IMPORT_UNSUPPORTED, ImportAModelWithUnsupportedLayers) {
 
   nnc::CaffeImporterPass importer{filename};
   try {
-    importer.import();
-    importer.createIR();
+    importer.importModel();
   }
   catch (nnc::PassException &e) {
     ASSERT_EQ(std::string(ErrorMsg), e.what());
index fe0943c..2ae8a55 100644 (file)
@@ -34,9 +34,8 @@ int main(int argc, const char **argv) {
   nnc::Caffe2ImporterPass importer{cli::inputFile, cli::initNet, {cli::inputShapes}};
 
   try {
-    importer.import();
     IrDotDumper dotDumper;
-    auto g = static_cast<Graph *>(importer.createIR());
+    auto g = static_cast<Graph *>(importer.importModel());
     g->accept(&dotDumper);
     dotDumper.writeDot(std::cout);
   }
index 1d5677c..83d7d36 100644 (file)
@@ -32,9 +32,8 @@ int main(int argc, const char **argv) {
   nnc::CaffeImporterPass importer{cli::inputFile};
 
   try {
-    importer.import();
     IrDotDumper dotDumper;
-    auto g = importer.createIR();
+    auto g = importer.importModel();
     g->accept(&dotDumper);
     dotDumper.writeDot(std::cout);
   }