From 5d9ba4e997a1ba2d19ee8891282cfd34c2f2de28 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 29 Jul 2019 14:35:55 +0300 Subject: [PATCH] [nnc] Join `import` and `createIR` methods into `importModel` (#5984) Methods `import` and `createIR` are always used together. Join them to simplify usage. Signed-off-by: Sergei Barannikov --- compiler/nnc/include/passes/common_frontend/NNImporter.h | 15 ++------------- .../nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp | 13 +------------ .../nnc/passes/caffe2_frontend/caffe2_importer_pass.h | 12 +----------- .../nnc/passes/caffe_frontend/caffe_importer_pass.cpp | 14 +------------- compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h | 4 +--- compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp | 13 +------------ compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h | 4 +--- .../nnc/passes/tflite_frontend/tflite_importer_pass.cpp | 13 +------------ .../nnc/passes/tflite_frontend/tflite_importer_pass.h | 12 +----------- compiler/nnc/tests/import/caffe.cpp | 4 +--- compiler/nnc/tests/import/tflite.cpp | 4 +--- .../unittests/caffe_frontend/unsupported_caffe_model.cpp | 3 +-- compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp | 3 +-- compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp | 3 +-- 14 files changed, 15 insertions(+), 102 deletions(-) diff --git a/compiler/nnc/include/passes/common_frontend/NNImporter.h b/compiler/nnc/include/passes/common_frontend/NNImporter.h index e4c866b..33526f4 100644 --- a/compiler/nnc/include/passes/common_frontend/NNImporter.h +++ b/compiler/nnc/include/passes/common_frontend/NNImporter.h @@ -29,25 +29,14 @@ class NNImporter : public Pass { public: // template method pattern PassData run(PassData /*data*/) final { - import(); - return createIR(); + return importModel(); } static std::unique_ptr createNNImporter(); void cleanup() override {} - /** - * @brief Import model from file, must be called before 'createIR' method - * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it - */ - virtual void import() = 0; - - /** - * @brief Create MIR graph from caffe model, must be called after 'import' method - * @return MIR graph, corresponding to processed caffe model - */ - virtual mir::Graph *createIR() = 0; + virtual mir::Graph *importModel() = 0; std::string getName() override { return "importer";}; }; diff --git a/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp b/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp index b3d8cf0..4e49809 100644 --- a/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp +++ b/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.cpp @@ -26,22 +26,11 @@ Caffe2ImporterPass::Caffe2ImporterPass(const std::string &predict_net, const std { } -void Caffe2ImporterPass::import() +mir::Graph *Caffe2ImporterPass::importModel() { try { _pimpl->import(); - } - catch (const std::exception &e) - { - throw PassException(e.what()); - } -} - -mir::Graph *Caffe2ImporterPass::createIR() -{ - try - { return _pimpl->createIR(); } catch (const std::exception &e) diff --git a/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.h b/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.h index 5b4f1a4..0a7d743 100644 --- a/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.h +++ b/compiler/nnc/passes/caffe2_frontend/caffe2_importer_pass.h @@ -30,17 +30,7 @@ public: explicit Caffe2ImporterPass(const std::string &predict_net, const std::string &init_net, const std::vector> &input_shapes); - /** - * @brief Import model from file, must be called before 'createIR' method - * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it - */ - void import() override; - - /** - * @brief Create MIR graph from caffe model, must be called after 'import' method - * @return MIR graph, corresponding to processed caffe model - */ - mir::Graph *createIR() override; + mir::Graph *importModel() override; void cleanup() override; diff --git a/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.cpp b/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.cpp index 6adb856..3bb7d51 100644 --- a/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.cpp +++ b/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.cpp @@ -26,23 +26,11 @@ CaffeImporterPass::CaffeImporterPass(const std::string &filename) { } -void CaffeImporterPass::import() +mir::Graph *CaffeImporterPass::importModel() { try { _pimpl->import(); - } - catch (const std::exception &e) - { - throw PassException(e.what()); - } -} - -mir::Graph *CaffeImporterPass::createIR() -{ - try - { - return _pimpl->createIR(); } catch (const std::exception &e) diff --git a/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h b/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h index 29615e2..34f1e8e 100644 --- a/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h +++ b/compiler/nnc/passes/caffe_frontend/caffe_importer_pass.h @@ -29,9 +29,7 @@ class CaffeImporterPass : public NNImporter public: explicit CaffeImporterPass(const std::string &filename); - void import() override; - - mir::Graph *createIR() override; + mir::Graph *importModel() override; void cleanup() override; diff --git a/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp b/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp index e8497ea..e8bb064 100644 --- a/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp +++ b/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.cpp @@ -26,22 +26,11 @@ ONNXImporterPass::ONNXImporterPass(const std::string &filename) { } -void ONNXImporterPass::import() +mir::Graph *ONNXImporterPass::importModel() { try { _pimpl->import(); - } - catch (const std::exception &e) - { - throw PassException(e.what()); - } -} - -mir::Graph *ONNXImporterPass::createIR() -{ - try - { return _pimpl->createIR(); } catch (const std::exception &e) diff --git a/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h b/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h index 6089fe0..322cddf 100644 --- a/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h +++ b/compiler/nnc/passes/onnx_frontend/ONNXImporterPass.h @@ -29,9 +29,7 @@ class ONNXImporterPass : public NNImporter public: explicit ONNXImporterPass(const std::string &filename); - void import() override; - - mir::Graph *createIR() override; + mir::Graph *importModel() override; ~ONNXImporterPass() override; diff --git a/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.cpp b/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.cpp index fff3a20..f70a4d1 100644 --- a/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.cpp +++ b/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.cpp @@ -26,22 +26,11 @@ TfliteImporterPass::TfliteImporterPass(const std::string &filename) { } -void TfliteImporterPass::import() +mir::Graph *TfliteImporterPass::importModel() { try { _pimpl->import(); - } - catch (const std::exception &e) - { - throw PassException(e.what()); - } -} - -mir::Graph *TfliteImporterPass::createIR() -{ - try - { return _pimpl->createIR(); } catch (const std::exception &e) diff --git a/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.h b/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.h index c70f1db..ad55fdb 100644 --- a/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.h +++ b/compiler/nnc/passes/tflite_frontend/tflite_importer_pass.h @@ -29,17 +29,7 @@ class TfliteImporterPass : public NNImporter public: explicit TfliteImporterPass(const std::string &filename); - /** - * @brief Import model from file, must be called before 'createIR' method - * @throw PassException in case, if model couldn't be parsed or NNC doesn't support it - */ - void import() override; - - /** - * @brief Create MIR graph from caffe model, must be called after 'import' method - * @return MIR graph, corresponding to processed caffe model - */ - mir::Graph *createIR() override; + mir::Graph *importModel() override; void cleanup() override; diff --git a/compiler/nnc/tests/import/caffe.cpp b/compiler/nnc/tests/import/caffe.cpp index f480da9..88a42d4 100644 --- a/compiler/nnc/tests/import/caffe.cpp +++ b/compiler/nnc/tests/import/caffe.cpp @@ -30,10 +30,8 @@ int main(int argc, const char **argv) { nnc::CaffeImporterPass importer{cli::inputFile}; - importer.import(); - try { - importer.createIR(); + importer.importModel(); } catch (...) { std::cout << "Could not create IR for model \"" << cli::inputFile << "\"" << std::endl; diff --git a/compiler/nnc/tests/import/tflite.cpp b/compiler/nnc/tests/import/tflite.cpp index e0f202a..a5fe85e 100644 --- a/compiler/nnc/tests/import/tflite.cpp +++ b/compiler/nnc/tests/import/tflite.cpp @@ -32,11 +32,9 @@ int main(int argc, const char **argv) cli::CommandLine::getParser()->parseCommandLine(argc, argv); nnc::TfliteImporterPass importer{cli::inputFile}; - importer.import(); - try { - importer.createIR(); + importer.importModel(); } catch (...) { diff --git a/compiler/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp b/compiler/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp index 7f93c16..0e0c387 100644 --- a/compiler/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp +++ b/compiler/nnc/unittests/caffe_frontend/unsupported_caffe_model.cpp @@ -16,8 +16,7 @@ TEST(CAFFE_IMPORT_UNSUPPORTED, ImportAModelWithUnsupportedLayers) { nnc::CaffeImporterPass importer{filename}; try { - importer.import(); - importer.createIR(); + importer.importModel(); } catch (nnc::PassException &e) { ASSERT_EQ(std::string(ErrorMsg), e.what()); diff --git a/compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp b/compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp index fe0943c..2ae8a55 100644 --- a/compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp +++ b/compiler/nnc/utils/caffe2_dot_dumper/model_dump.cpp @@ -34,9 +34,8 @@ int main(int argc, const char **argv) { nnc::Caffe2ImporterPass importer{cli::inputFile, cli::initNet, {cli::inputShapes}}; try { - importer.import(); IrDotDumper dotDumper; - auto g = static_cast(importer.createIR()); + auto g = static_cast(importer.importModel()); g->accept(&dotDumper); dotDumper.writeDot(std::cout); } diff --git a/compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp b/compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp index 1d5677c..83d7d36 100644 --- a/compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp +++ b/compiler/nnc/utils/caffe_dot_dumper/model_dump.cpp @@ -32,9 +32,8 @@ int main(int argc, const char **argv) { nnc::CaffeImporterPass importer{cli::inputFile}; try { - importer.import(); IrDotDumper dotDumper; - auto g = importer.createIR(); + auto g = importer.importModel(); g->accept(&dotDumper); dotDumper.writeDot(std::cout); } -- 2.7.4