From 1c163997abfe99353bba6f22b8425aeb45911858 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD=20=D0=9C=D0=B8=D1=85=D0=B0?= =?utf8?q?=D0=B9=D0=BB=D0=BE=D0=B2=D0=B8=D1=87=20=D0=A0=D1=83=D1=81=D1=8F?= =?utf8?q?=D0=B5=D0=B2/AI=20Tools=20Lab=20/SRR/Staff=20Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 12 Sep 2018 19:34:28 +0300 Subject: [PATCH] Implement Passes in nnc (#1474) * implement passes system * remove plugins Signed-off-by: Roman Rusyaev --- contrib/nnc/CMakeLists.txt | 66 ++++------- contrib/nnc/README.md | 1 - contrib/nnc/core/CMakeLists.txt | 2 +- contrib/nnc/driver/Driver.cpp | 83 ++++++------- contrib/nnc/driver/Driver.h | 2 +- contrib/nnc/driver/main.cpp | 13 +-- contrib/nnc/examples/caffe_frontend/model_dump.cpp | 5 +- contrib/nnc/examples/plugin/CMakeLists.txt | 9 -- contrib/nnc/examples/plugin/samplePlugin.cpp | 48 -------- .../nnc/examples/tflite_frontend/sanity_check.cpp | 5 +- contrib/nnc/include/Definitions.h.in | 29 +---- contrib/nnc/include/pass/Pass.h | 42 +++++++ contrib/nnc/include/pass/PassData.h | 76 ++++++++++++ contrib/nnc/include/pass/PassException.h | 39 +++++++ contrib/nnc/include/pass/PassManager.h | 51 ++++++++ .../include/passes/caffe_frontend/CaffeFrontend.h | 39 +++++++ .../common_frontend/model_allocation.h | 0 .../common_frontend/nn_importer.h | 0 .../common_frontend/shape_helper.h | 0 .../{plugin => passes}/interpreter/Interpreter.h | 0 .../include/passes/interpreter/InterpreterPass.h | 40 +++++++ .../passes/soft_backend/BaseGenerator.h} | 9 +- .../passes/soft_backend/CGenerator.h} | 8 +- .../passes/soft_backend/CPPGenerator.h} | 8 +- .../passes/tflite_frontend/TfliteFrontend.h | 40 +++++++ contrib/nnc/include/support/CommandLine.h | 1 - contrib/nnc/include/support/PluginException.h | 34 ------ contrib/nnc/include/support/PluginInstance.h | 60 ---------- contrib/nnc/include/support/PluginManager.h | 62 ---------- contrib/nnc/include/support/PluginProxy.h | 79 ------------- contrib/nnc/include/support/shared_library.h | 129 --------------------- contrib/nnc/pass/CMakeLists.txt | 4 + contrib/nnc/pass/PassManager.cpp | 27 +++++ contrib/nnc/{plugin => passes}/CMakeLists.txt | 0 .../caffe_frontend/CMakeLists.txt | 3 +- .../caffe_frontend/caffe_dump_visitor.cpp | 0 .../caffe_frontend/caffe_dump_visitor.h | 0 .../nnc/passes/caffe_frontend/caffe_frontend.cpp | 47 ++++++++ .../caffe_frontend/caffe_importer.cpp | 2 +- .../caffe_frontend/caffe_importer.h | 4 +- .../caffe_frontend/caffe_model_visitor.cpp | 12 +- .../caffe_frontend/caffe_model_visitor.h | 0 .../caffe_frontend/caffe_op_creator.cpp | 29 ++--- .../caffe_frontend/caffe_op_creator.h | 1 - .../caffe_frontend/caffe_visitor.h | 0 .../caffe_frontend/caffe_walker.cpp | 0 .../caffe_frontend/caffe_walker.h | 0 .../caffe_frontend/proto_reader.cpp | 0 .../caffe_frontend/proto_reader.h | 0 .../common_frontend/CMakeLists.txt | 0 .../common_frontend/model_allocation.cpp | 2 +- .../nnc/passes/common_frontend/shape_helper.cpp | 39 +++++++ .../{plugin => passes}/interpreter/CMakeLists.txt | 3 +- .../{plugin => passes}/interpreter/Interpreter.cpp | 2 +- .../interpreter/interpreter_pass.cpp} | 39 +++---- .../{plugin => passes}/interpreter/ops/Bias.cpp | 0 .../nnc/{plugin => passes}/interpreter/ops/Bias.h | 0 .../{plugin => passes}/interpreter/ops/Concat.cpp | 0 .../{plugin => passes}/interpreter/ops/Concat.h | 0 .../interpreter/ops/Depthwise_conv_2D.cpp | 0 .../interpreter/ops/Depthwise_conv_2D.h | 0 .../interpreter/ops/Elementwise.cpp | 0 .../interpreter/ops/Elementwise.h | 0 .../{plugin => passes}/interpreter/ops/Fill.cpp | 0 .../nnc/{plugin => passes}/interpreter/ops/Fill.h | 0 .../interpreter/ops/FullyConnected.cpp | 0 .../interpreter/ops/FullyConnected.h | 0 .../interpreter/ops/OperationImpl.h | 0 .../{plugin => passes}/interpreter/ops/Pool.cpp | 0 .../nnc/{plugin => passes}/interpreter/ops/Pool.h | 0 .../{plugin => passes}/interpreter/ops/Reduce.cpp | 0 .../{plugin => passes}/interpreter/ops/Reduce.h | 0 .../{plugin => passes}/interpreter/ops/Reshape.cpp | 0 .../{plugin => passes}/interpreter/ops/Reshape.h | 0 .../{plugin => passes}/interpreter/ops/Softmax.cpp | 0 .../{plugin => passes}/interpreter/ops/Softmax.h | 0 .../{plugin => passes}/interpreter/ops/common.cpp | 0 .../{plugin => passes}/interpreter/ops/common.h | 0 .../{plugin => passes}/interpreter/ops/conv_2D.cpp | 0 .../{plugin => passes}/interpreter/ops/conv_2D.h | 0 .../interpreter/ops/conv_FFT.cpp | 0 .../{plugin => passes}/interpreter/ops/conv_FFT.h | 0 .../soft_backend/BaseGenerator.cpp} | 22 ++-- .../soft_backend/CGenerator.cpp} | 14 +-- .../{plugin => passes}/soft_backend/CMakeLists.txt | 9 +- .../soft_backend/code_snippets/cpp_add_bias.def | 0 .../soft_backend/code_snippets/cpp_capped_relu.def | 0 .../code_snippets/cpp_common_funcs.def | 0 .../soft_backend/code_snippets/cpp_concat.def | 0 .../soft_backend/code_snippets/cpp_conv.def | 0 .../code_snippets/cpp_depthwise_conv.def | 0 .../code_snippets/cpp_fully_connected.def | 0 .../code_snippets/cpp_header_types.def | 0 .../soft_backend/code_snippets/cpp_operations.def | 0 .../soft_backend/code_snippets/cpp_pool.def | 0 .../soft_backend/code_snippets/cpp_relu.def | 0 .../soft_backend/code_snippets/cpp_softmax.def | 0 .../soft_backend/code_snippets/eigen.def | 0 .../soft_backend/cpp_generator.cpp | 15 ++- .../soft_backend/model_analyzer.cpp | 0 .../soft_backend/model_analyzer.h | 0 .../soft_backend/param_constants.def | 0 .../{plugin => passes}/soft_backend/serializer.cpp | 0 .../{plugin => passes}/soft_backend/serializer.h | 0 .../tflite_frontend/CMakeLists.txt | 5 +- .../tflite_frontend/schema/schema.fbs | 0 .../tflite_frontend/schema/schema.meta | 0 .../tflite_frontend/schema/schema_v0.fbs | 0 .../tflite_frontend/schema/schema_v0.meta | 0 .../tflite_frontend/schema/schema_v1.fbs | 0 .../tflite_frontend/schema/schema_v1.meta | 0 .../tflite_frontend/schema/schema_v2.fbs | 0 .../tflite_frontend/schema/schema_v2.meta | 0 .../tflite_frontend/schema/schema_v3.fbs | 0 .../tflite_frontend/schema/schema_v3.meta | 0 .../{plugin => passes}/tflite_frontend/schema_v3.h | 0 .../tflite_frontend/tflite_dump_visitor.cpp | 0 .../tflite_frontend/tflite_dump_visitor.h | 0 .../nnc/passes/tflite_frontend/tflite_frontend.cpp | 47 ++++++++ .../tflite_frontend/tflite_importer.inline.cpp | 0 .../tflite_frontend/tflite_importer.inline.h | 0 .../tflite_frontend/tflite_ir_visitor.cpp | 14 ++- .../tflite_frontend/tflite_ir_visitor.h | 0 .../tflite_frontend/tflite_op_creator.cpp | 7 +- .../tflite_frontend/tflite_op_creator.h | 3 +- .../tflite_frontend/tflite_v3_importer.cpp | 0 .../tflite_frontend/tflite_v3_importer.h | 4 +- .../tflite_frontend/tflite_visitor.h | 0 .../tflite_frontend/tflite_walker.cpp | 0 .../tflite_frontend/tflite_walker.h | 0 contrib/nnc/plugin/caffe_frontend/caffe_plugin.cpp | 55 --------- .../nnc/plugin/common_frontend/shape_helper.cpp | 37 ------ .../nnc/plugin/interpreter/interpreter_plugin.h | 44 ------- .../nnc/plugin/tflite_frontend/tflite_plugin.cpp | 55 --------- contrib/nnc/support/CLOptionChecker.cpp | 26 ----- contrib/nnc/support/CMakeLists.txt | 6 +- contrib/nnc/support/PluginManager.cpp | 73 ------------ contrib/nnc/support/PluginProxy.cpp | 66 ----------- contrib/nnc/tests/interpreter/graph_creator.cpp | 2 +- contrib/nnc/tests/interpreter/op_info_util.h | 2 +- contrib/nnc/tests/interpreter/op_test.cpp | 2 +- contrib/nnc/tests/soft_backend/compile_cpp.cpp | 4 +- contrib/nnc/unittests/CMakeLists.txt | 2 +- contrib/nnc/unittests/module/CMakeLists.txt | 13 --- contrib/nnc/unittests/module/PluginManager.cpp | 68 ----------- contrib/nnc/unittests/module/PluginProxy.cpp | 42 ------- contrib/nnc/unittests/module/shared_library.cpp | 33 ------ contrib/nnc/unittests/pass/CMakeLists.txt | 6 + contrib/nnc/unittests/pass/PassExceptionTest.cpp | 37 ++++++ contrib/nnc/unittests/pass/PassManagerTest.cpp | 57 +++++++++ .../nnc/unittests/soft_backend/cpp_operations.cpp | 2 +- contrib/nnc/unittests/soft_backend/generator.cpp | 19 +-- contrib/nnc/unittests/support/PluginException.cpp | 37 ------ 153 files changed, 798 insertions(+), 1224 deletions(-) delete mode 100644 contrib/nnc/examples/plugin/CMakeLists.txt delete mode 100644 contrib/nnc/examples/plugin/samplePlugin.cpp create mode 100644 contrib/nnc/include/pass/Pass.h create mode 100644 contrib/nnc/include/pass/PassData.h create mode 100644 contrib/nnc/include/pass/PassException.h create mode 100644 contrib/nnc/include/pass/PassManager.h create mode 100644 contrib/nnc/include/passes/caffe_frontend/CaffeFrontend.h rename contrib/nnc/include/{plugin => passes}/common_frontend/model_allocation.h (100%) rename contrib/nnc/include/{plugin => passes}/common_frontend/nn_importer.h (100%) rename contrib/nnc/include/{plugin => passes}/common_frontend/shape_helper.h (100%) rename contrib/nnc/include/{plugin => passes}/interpreter/Interpreter.h (100%) create mode 100644 contrib/nnc/include/passes/interpreter/InterpreterPass.h rename contrib/nnc/{plugin/soft_backend/base_generator.h => include/passes/soft_backend/BaseGenerator.h} (84%) rename contrib/nnc/{plugin/soft_backend/c_generator.h => include/passes/soft_backend/CGenerator.h} (85%) rename contrib/nnc/{plugin/soft_backend/cpp_generator.h => include/passes/soft_backend/CPPGenerator.h} (91%) create mode 100644 contrib/nnc/include/passes/tflite_frontend/TfliteFrontend.h delete mode 100644 contrib/nnc/include/support/PluginException.h delete mode 100644 contrib/nnc/include/support/PluginInstance.h delete mode 100644 contrib/nnc/include/support/PluginManager.h delete mode 100644 contrib/nnc/include/support/PluginProxy.h delete mode 100644 contrib/nnc/include/support/shared_library.h create mode 100644 contrib/nnc/pass/CMakeLists.txt create mode 100644 contrib/nnc/pass/PassManager.cpp rename contrib/nnc/{plugin => passes}/CMakeLists.txt (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/CMakeLists.txt (82%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_dump_visitor.cpp (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_dump_visitor.h (100%) create mode 100644 contrib/nnc/passes/caffe_frontend/caffe_frontend.cpp rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_importer.cpp (96%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_importer.h (89%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_model_visitor.cpp (95%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_model_visitor.h (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_op_creator.cpp (93%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_op_creator.h (98%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_visitor.h (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_walker.cpp (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/caffe_walker.h (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/proto_reader.cpp (100%) rename contrib/nnc/{plugin => passes}/caffe_frontend/proto_reader.h (100%) rename contrib/nnc/{plugin => passes}/common_frontend/CMakeLists.txt (100%) rename contrib/nnc/{plugin => passes}/common_frontend/model_allocation.cpp (94%) create mode 100644 contrib/nnc/passes/common_frontend/shape_helper.cpp rename contrib/nnc/{plugin => passes}/interpreter/CMakeLists.txt (75%) rename contrib/nnc/{plugin => passes}/interpreter/Interpreter.cpp (99%) rename contrib/nnc/{plugin/interpreter/interpreter_plugin.cpp => passes/interpreter/interpreter_pass.cpp} (82%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Bias.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Bias.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Concat.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Concat.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Depthwise_conv_2D.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Depthwise_conv_2D.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Elementwise.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Elementwise.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Fill.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Fill.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/FullyConnected.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/FullyConnected.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/OperationImpl.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Pool.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Pool.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Reduce.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Reduce.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Reshape.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Reshape.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Softmax.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/Softmax.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/common.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/common.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/conv_2D.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/conv_2D.h (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/conv_FFT.cpp (100%) rename contrib/nnc/{plugin => passes}/interpreter/ops/conv_FFT.h (100%) rename contrib/nnc/{plugin/soft_backend/base_generator.cpp => passes/soft_backend/BaseGenerator.cpp} (83%) rename contrib/nnc/{plugin/soft_backend/c_generator.cpp => passes/soft_backend/CGenerator.cpp} (79%) rename contrib/nnc/{plugin => passes}/soft_backend/CMakeLists.txt (75%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_add_bias.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_capped_relu.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_common_funcs.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_concat.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_conv.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_depthwise_conv.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_fully_connected.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_header_types.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_operations.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_pool.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_relu.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/cpp_softmax.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/code_snippets/eigen.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/cpp_generator.cpp (97%) rename contrib/nnc/{plugin => passes}/soft_backend/model_analyzer.cpp (100%) rename contrib/nnc/{plugin => passes}/soft_backend/model_analyzer.h (100%) rename contrib/nnc/{plugin => passes}/soft_backend/param_constants.def (100%) rename contrib/nnc/{plugin => passes}/soft_backend/serializer.cpp (100%) rename contrib/nnc/{plugin => passes}/soft_backend/serializer.h (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/CMakeLists.txt (89%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema.fbs (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema.meta (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v0.fbs (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v0.meta (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v1.fbs (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v1.meta (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v2.fbs (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v2.meta (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v3.fbs (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema/schema_v3.meta (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/schema_v3.h (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_dump_visitor.cpp (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_dump_visitor.h (100%) create mode 100644 contrib/nnc/passes/tflite_frontend/tflite_frontend.cpp rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_importer.inline.cpp (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_importer.inline.h (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_ir_visitor.cpp (95%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_ir_visitor.h (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_op_creator.cpp (96%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_op_creator.h (97%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_v3_importer.cpp (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_v3_importer.h (81%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_visitor.h (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_walker.cpp (100%) rename contrib/nnc/{plugin => passes}/tflite_frontend/tflite_walker.h (100%) delete mode 100644 contrib/nnc/plugin/caffe_frontend/caffe_plugin.cpp delete mode 100644 contrib/nnc/plugin/common_frontend/shape_helper.cpp delete mode 100644 contrib/nnc/plugin/interpreter/interpreter_plugin.h delete mode 100644 contrib/nnc/plugin/tflite_frontend/tflite_plugin.cpp delete mode 100644 contrib/nnc/support/PluginManager.cpp delete mode 100644 contrib/nnc/support/PluginProxy.cpp delete mode 100644 contrib/nnc/unittests/module/CMakeLists.txt delete mode 100644 contrib/nnc/unittests/module/PluginManager.cpp delete mode 100644 contrib/nnc/unittests/module/PluginProxy.cpp delete mode 100644 contrib/nnc/unittests/module/shared_library.cpp create mode 100644 contrib/nnc/unittests/pass/CMakeLists.txt create mode 100644 contrib/nnc/unittests/pass/PassExceptionTest.cpp create mode 100644 contrib/nnc/unittests/pass/PassManagerTest.cpp delete mode 100644 contrib/nnc/unittests/support/PluginException.cpp diff --git a/contrib/nnc/CMakeLists.txt b/contrib/nnc/CMakeLists.txt index 8aec813..057719f 100644 --- a/contrib/nnc/CMakeLists.txt +++ b/contrib/nnc/CMakeLists.txt @@ -3,9 +3,7 @@ project(nnc) list(INSERT CMAKE_MODULE_PATH 0 ${CMAKE_CURRENT_SOURCE_DIR}/cmake) include(soft_backend) -set(DRIVER_HEADERS driver/Driver.h) -set(DRIVER_SOURCES driver/Driver.cpp) -set(MAIN "driver/main.cpp") +set(DRIVER_SOURCES driver/main.cpp driver/Driver.cpp) set(OPTIONS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/driver/Options.cpp) # add interface header files @@ -19,8 +17,6 @@ set(NNC_ROOT_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) ### set(NNC_INSTALL_PATH ${CMAKE_INSTALL_PREFIX}) # root path of installation directory set(NNC_INSTALL_LIB_PATH ${NNC_INSTALL_PATH}/lib) # directory that contains other directories with shared library -set(NNC_INSTALL_PLUGIN_PATH ${NNC_INSTALL_LIB_PATH}/plugin) # path to where plugins will be located -set(NNC_INSTALL_CORE_PATH ${NNC_INSTALL_LIB_PATH}/core) # path to where common part of nnc will be located # # find necessary packages @@ -53,37 +49,17 @@ endif() ### # -# plugins names -# -# NOTE. If names of plugins are changed then these -# variables will have to be also chagned -if (APPLE) - set(LIB_SUFFIX ".dylib") -else() - set(LIB_SUFFIX ".so") -endif() - -set(NNC_CAFFE_PLUGIN_NAME "libcaffe_importer${LIB_SUFFIX}") -set(NNC_TFLITE_PLUGIN_NAME "libtflite_import${LIB_SUFFIX}") -set(NNC_SOFT_CPP_PLUGIN_NAME "libsoft_backend_cpp${LIB_SUFFIX}") -set(NNC_SOFT_C_PLUGIN_NAME "libsoft_backend_c${LIB_SUFFIX}") -set(NNC_INTERPRETER_NAME "libnnc_interpreter${LIB_SUFFIX}") -### - -# # functions # -function(install_nnc_plugin PLUGIN) - install(TARGETS ${PLUGIN} DESTINATION ${NNC_INSTALL_PLUGIN_PATH}) +function(install_nnc_library LIB) + install(TARGETS ${LIB} DESTINATION ${NNC_INSTALL_LIB_PATH}) # set external RPATHs - set_target_properties(${PLUGIN} PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) + set_target_properties(${LIB} PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) + # use paths from build directoris + set_target_properties(${LIB} PROPERTIES CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) # set RPATH to core part of nnc - set_target_properties(${PLUGIN} PROPERTIES INSTALL_RPATH ${NNC_INSTALL_CORE_PATH}) -endfunction(install_nnc_plugin) - -function(install_common_library) - install(TARGETS ${ARGV} DESTINATION ${NNC_INSTALL_CORE_PATH}) -endfunction(install_common_library) + set_target_properties(${LIB} PROPERTIES INSTALL_RPATH ${NNC_INSTALL_LIB_PATH}) +endfunction() # # end functions # @@ -91,37 +67,35 @@ endfunction(install_common_library) # # Used by unit tests # -set(NNC_SOFT_BACKEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/plugin/soft_backend) -set(NNC_INTERPRETER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/plugin/interpreter) -set(NNC_CAFFE_FRONTEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/plugin/caffe_frontend) -set(NNC_TFLITE_FRONTEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/plugin/tflite_frontend) +set(NNC_SOFT_BACKEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/passes/soft_backend) +set(NNC_INTERPRETER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/passes/interpreter) +set(NNC_CAFFE_FRONTEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/passes/caffe_frontend) +set(NNC_TFLITE_FRONTEND_DIR ${CMAKE_CURRENT_SOURCE_DIR}/passes/tflite_frontend) set(NNC_CORE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/core) set(NNC_SUPPORT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/support) -# driver library -add_nncc_library(nnc_driver SHARED ${DRIVER_HEADERS} ${DRIVER_SOURCES}) -target_link_libraries(nnc_driver PRIVATE nnc_support) -install_common_library(nnc_driver) - # nnc executable -add_executable(nnc ${MAIN} ${OPTIONS_SRC}) -target_link_libraries(nnc PRIVATE nnc_support nnc_driver) +add_executable(nnc ${DRIVER_SOURCES} ${OPTIONS_SRC}) +target_link_libraries(nnc PRIVATE nnc_support nnc_pass) +target_link_libraries(nnc PRIVATE caffe_importer tflite_import soft_backend_cpp soft_backend_c nnc_interpreter) # configure file that contains extern definitions configure_file(${CMAKE_CURRENT_SOURCE_DIR}/include/Definitions.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/Definitions.h) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -set(NNC_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) add_subdirectory(support) add_subdirectory(core) -add_subdirectory(plugin) +add_subdirectory(pass) +add_subdirectory(passes) add_subdirectory(examples) add_subdirectory(unittests) add_subdirectory(tests) # install nnc install(TARGETS nnc DESTINATION ${NNC_INSTALL_PATH}/bin) -set_target_properties(nnc PROPERTIES INSTALL_RPATH "${NNC_INSTALL_CORE_PATH};${NNC_INSTALL_PLUGIN_PATH}") +# TODO when we upgrade our cmake to version 2.12 this is needed to use BUILD_RPATH variable NOCOMMIT +set_target_properties(nnc PROPERTIES INSTALL_RPATH "${NNC_INSTALL_LIB_PATH}") set_target_properties(nnc PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE) +set_target_properties(nnc PROPERTIES CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) diff --git a/contrib/nnc/README.md b/contrib/nnc/README.md index 96bb066..d1f0373 100644 --- a/contrib/nnc/README.md +++ b/contrib/nnc/README.md @@ -4,7 +4,6 @@ Neural Network Compiler ### DESCRIPTION nnc is a neural network compiler that transforms neural networks of various formats into source or machine code. -Most functionality of nnc is stored in dynamically linked plugins. > At this moment only two NN are supported (MobileNet and InceptionV3) in Tensorflow Lite or Caffe format. ### SYNOPSIS diff --git a/contrib/nnc/core/CMakeLists.txt b/contrib/nnc/core/CMakeLists.txt index 39afbe1..ab44691 100644 --- a/contrib/nnc/core/CMakeLists.txt +++ b/contrib/nnc/core/CMakeLists.txt @@ -4,4 +4,4 @@ add_nncc_library(nnc_core SHARED ${SOURCES}) set_target_properties(nnc_core PROPERTIES LINKER_LANGUAGE CXX) # install nnc core library -install_common_library(nnc_core) +install_nnc_library(nnc_core) diff --git a/contrib/nnc/driver/Driver.cpp b/contrib/nnc/driver/Driver.cpp index 5b2f566..b929d5d 100644 --- a/contrib/nnc/driver/Driver.cpp +++ b/contrib/nnc/driver/Driver.cpp @@ -1,11 +1,20 @@ -#include "support/PluginManager.h" +#include "pass/PassManager.h" +#include "pass/PassData.h" + +#include "passes/caffe_frontend/CaffeFrontend.h" +#include "passes/tflite_frontend/TfliteFrontend.h" +#include "passes/interpreter/InterpreterPass.h" +#include "passes/soft_backend/CPPGenerator.h" + #include "support/CommandLine.h" #include "Definitions.h" #include "option/Options.h" #include "Driver.h" using namespace nncc::contrib; -using namespace nncc::contrib::plugin; +using namespace nncc::contrib::pass; +using namespace nncc::contrib::frontend; +using namespace nncc::contrib::backend; namespace nncc { @@ -13,30 +22,28 @@ namespace contrib { /** - * @brief run plugin - * @param plugin_path - absolute path to plugin - * @param data - plugin input data - * @return pointer to data that plugin generated - * @throw PluginException, if errors occured + * @brief run all registered passes + * @throw PassException, if errors occured */ -static void *executePlugin(const std::string &plugin_path, void *data) +static void *runPasses() { - PluginManager pluginManager(plugin_path); - - Plugin *plugin = pluginManager.getPlugin(); - void *res = plugin->execute(data); + auto registeredPasses = PassManager::getPassManager()->getPasses(); + PassData passData(nullptr); - return res; + for ( auto &pass : registeredPasses ) + { + passData = pass->run(passData); + } -} // executePlugin +} // runPasses /** - * @return absolute path to frontend plugin + * @brief Register frontend pass * @throw DriverException if errors occurred */ -static std::string getFrontendPlugin() +static void registerFrontendPass() { - std::string plugin; + Pass *pass; if ( clopt::caffeFrontend.isDisabled() && clopt::tflFrontend.isDisabled() ) { @@ -53,11 +60,11 @@ static std::string getFrontendPlugin() if ( clopt::caffeFrontend ) { - plugin = NNC_FRONTEND_CAFFE_NAME; + pass = &caffe::CaffeFrontend::getInstance(); } else if ( clopt::tflFrontend ) { - plugin = NNC_FRONTEND_TFLITE_NAME; + pass = &tflite::TFLiteFrontend::getInstance(); } else { @@ -66,45 +73,43 @@ static std::string getFrontendPlugin() + clopt::tflFrontend.getNames()[0] + "'"); } - return plugin; + PassManager::getPassManager()->registerPass(pass); -} // getFrontendPlugin +} // registerFrontendPass /** - * @return absolute path to backend plugin + * @brief Register backend pass * @throw DriverException if errors occurred */ -static std::string getBackendPlugin() +static void registerBackendPass() { - std::string plugin; - - assert( clopt::target == NNC_TARGET_X86_CPP || clopt::target == NNC_TARGET_INTERPRETER ); + Pass *pass; if ( clopt::target == NNC_TARGET_X86_CPP ) { - plugin = NNC_BACKEND_SOFT_CPP_NAME; + pass = &soft::CPPCodeGenerator::getInstance(); } else if ( clopt::target == NNC_TARGET_INTERPRETER ) { - plugin = NNC_BACKEND_INTERPRETER_NAME; + pass = &interpreter::InterpreterPass::getInstance(); + + } else + { + assert(false && "invalid option value"); } - return plugin; + PassManager::getPassManager()->registerPass(pass); -} // getBackendPlugin +} // registerBackendPass void Driver::runDriver() { - std::string plugin; - void *plugin_result; - - // run frontend plugin - plugin = getFrontendPlugin(); - plugin_result = executePlugin(plugin, nullptr); + // register passes + registerFrontendPass(); + registerBackendPass(); - // run backend plugin - plugin = getBackendPlugin(); - executePlugin(plugin, plugin_result); + // run registered passes + runPasses(); } // runDriver diff --git a/contrib/nnc/driver/Driver.h b/contrib/nnc/driver/Driver.h index 7cb41d6..ec78d6f 100644 --- a/contrib/nnc/driver/Driver.h +++ b/contrib/nnc/driver/Driver.h @@ -38,7 +38,7 @@ public: /** * @brief main method to run compiler driver * @throw DriverException if errors occurred in driver - * PluginException if errors occurred in plugins + * PassException if errors occurred in passes */ static void runDriver(); diff --git a/contrib/nnc/driver/main.cpp b/contrib/nnc/driver/main.cpp index 3aba221..58b00f0 100644 --- a/contrib/nnc/driver/main.cpp +++ b/contrib/nnc/driver/main.cpp @@ -1,13 +1,14 @@ #include #include -#include "support/PluginException.h" #include "support/CommandLine.h" +#include "pass/PassException.h" #include "Driver.h" #define DEBUG_AREA using namespace nncc::contrib; +using namespace nncc::contrib::pass; int main(int argc, const char *argv[]) { @@ -21,10 +22,8 @@ int main(int argc, const char *argv[]) // // run compiler pipeline: // - // for_each(all_plugins): - // load plugin - // execute plugin - // unload plugin + // for_each(all_passes): + // run pass // Driver::runDriver(); @@ -36,9 +35,9 @@ int main(int argc, const char *argv[]) std::cerr << e.reason() << std::endl; std::cerr << "use --help for more information" << std::endl; } - catch ( const PluginException &e ) + catch ( const PassException &e ) { - std::cerr << e.what() << std::endl; + std::cerr << e.reason() << std::endl; } return exit_code; diff --git a/contrib/nnc/examples/caffe_frontend/model_dump.cpp b/contrib/nnc/examples/caffe_frontend/model_dump.cpp index a201fa5..41923b5 100644 --- a/contrib/nnc/examples/caffe_frontend/model_dump.cpp +++ b/contrib/nnc/examples/caffe_frontend/model_dump.cpp @@ -1,14 +1,15 @@ #include #include "support/CommandLine.h" -#include "support/PluginException.h" #include "option/Options.h" #include "caffe_importer.h" #include "core/modelIR/graph.h" #include "core/modelIR/ir_dot_dumper.h" #include "core/modelIR/ShapeInference.h" +#include "pass/PassException.h" using namespace nncc::contrib; +using namespace nncc::contrib::pass; using namespace nncc::contrib::clopt; using namespace nncc::contrib::core::dumper; @@ -49,7 +50,7 @@ int main(int argc, const char **argv) g->accept(&dotDumper); dotDumper.writeDot(std::cout); - } catch (PluginException &e) { + } catch (PassException &e) { std::cout << "Error: " << e.what() << std::endl; return -1; } diff --git a/contrib/nnc/examples/plugin/CMakeLists.txt b/contrib/nnc/examples/plugin/CMakeLists.txt deleted file mode 100644 index b0e7dbb..0000000 --- a/contrib/nnc/examples/plugin/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -file(GLOB_RECURSE PL_EXAMPLE_PARSE_SRC *.cpp) -file(GLOB_RECURSE PL_EXAMPLE_PARSE_H) - -add_library(some_parser SHARED ${PL_EXAMPLE_PARSE_SRC} ${PL_EXAMPLE_PARSE_H}) -add_library(some_parser_second SHARED ${PL_EXAMPLE_PARSE_SRC} ${PL_EXAMPLE_PARSE_H}) - -target_link_libraries(some_parser PRIVATE nnc_core nnc_support) - -target_link_libraries(some_parser_second PRIVATE nnc_core nnc_support) diff --git a/contrib/nnc/examples/plugin/samplePlugin.cpp b/contrib/nnc/examples/plugin/samplePlugin.cpp deleted file mode 100644 index 2d5fc37..0000000 --- a/contrib/nnc/examples/plugin/samplePlugin.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include -#include -#include "support/PluginInstance.h" -#include "support/PluginException.h" - -using namespace nncc::contrib::plugin; - -class SamplePluginInstance : public FrontendPlugin -{ -public: - SamplePluginInstance &operator=(const SamplePluginInstance &) = delete; - SamplePluginInstance(const SamplePluginInstance &) = delete; - - static FrontendPlugin &getInstance(); - void *execute(void *data) override; - -private: - SamplePluginInstance() = default; - ~SamplePluginInstance() override = default; -}; - -FrontendPlugin &SamplePluginInstance::getInstance() -{ - static SamplePluginInstance instance; - // FIXME: it's necessary to make this printing via debugging system (see issue #893) - //NNC_DEBUG(dbgs() << std::endl << "!!! plugin (" << pluginName << ") " << __func__ << std::endl); - return instance; -} - -void *SamplePluginInstance::execute(void *data) -{ - // FIXME: it's necessary to make this printing via debugging system (see issue #893) - //std::cout << std::endl << "!!! plugin (" << pluginName << ") " << __func__ << std::endl; - return data; -} - -extern "C" Plugin *get_instance() -{ - // FIXME: it's necessary to make this printing via debugging system (see issue #893) - //std::cout << std::endl << "!!! plugin (" << pluginName << ") " << __func__ << std::endl; - return &SamplePluginInstance::getInstance(); -} - -extern "C" int getSomeBeef() -{ - return 0xBEEF; -} diff --git a/contrib/nnc/examples/tflite_frontend/sanity_check.cpp b/contrib/nnc/examples/tflite_frontend/sanity_check.cpp index 905d20e..2d32d0f 100644 --- a/contrib/nnc/examples/tflite_frontend/sanity_check.cpp +++ b/contrib/nnc/examples/tflite_frontend/sanity_check.cpp @@ -1,7 +1,7 @@ #include #include "support/CommandLine.h" -#include "support/PluginException.h" +#include "pass/PassException.h" #include "option/Options.h" #include "tflite_v3_importer.h" #include "core/modelIR/graph.h" @@ -9,6 +9,7 @@ #include "core/modelIR/ShapeInference.h" using namespace nncc::contrib; +using namespace nncc::contrib::pass; using namespace nncc::contrib::clopt; using namespace nncc::contrib::core::dumper; @@ -49,7 +50,7 @@ int main(int argc, const char **argv) g->accept(&dotDumper); dotDumper.writeDot(std::cout); - } catch (PluginException &e) { + } catch (PassException &e) { std::cout << "Error: " << e.what() << std::endl; return -1; } diff --git a/contrib/nnc/include/Definitions.h.in b/contrib/nnc/include/Definitions.h.in index 73862d5..7299f23 100644 --- a/contrib/nnc/include/Definitions.h.in +++ b/contrib/nnc/include/Definitions.h.in @@ -12,34 +12,9 @@ #define NNC_ROOT_PATH "@NNC_INSTALL_PATH@" /** - * @breif absolute path to directory contains plugins + * @breif absolute path to directory contains libraries */ -#define NNC_PLUGINS_PATH "@NNC_INSTALL_PLUGIN_PATH@" - -/** - * @brief name of CAFFE frontend plugin - */ -#define NNC_FRONTEND_CAFFE_NAME "@NNC_CAFFE_PLUGIN_NAME@" - -/** - * @brief name of TensorFlow Lite frontend plugin - */ -#define NNC_FRONTEND_TFLITE_NAME "@NNC_TFLITE_PLUGIN_NAME@" - -/** - * @brief name of Soft backend plugin which generates C source code - */ -#define NNC_BACKEND_SOFT_C_NAME "@NNC_SOFT_C_PLUGIN_NAME@" - -/** - * @brief name of Soft backend plugin which generates C++ source code - */ -#define NNC_BACKEND_SOFT_CPP_NAME "@NNC_SOFT_CPP_PLUGIN_NAME@" - -/** - * @brief name of Interpreter plugin - */ -#define NNC_BACKEND_INTERPRETER_NAME "@NNC_INTERPRETER_NAME@" +#define NNC_LIB_PATH "@NNC_INSTALL_LIB_PATH@" /** * @brief defines if hdf5 package was found diff --git a/contrib/nnc/include/pass/Pass.h b/contrib/nnc/include/pass/Pass.h new file mode 100644 index 0000000..293713d --- /dev/null +++ b/contrib/nnc/include/pass/Pass.h @@ -0,0 +1,42 @@ +#ifndef NNCC_PASS_H +#define NNCC_PASS_H + +#include + +#include "pass/PassData.h" + +namespace nncc +{ +namespace contrib +{ +namespace pass +{ + +/** + * @brief this class represent an interface for all compiler passes like that frontend, backend etc + */ +class Pass +{ +public: + Pass() = default; + + // to prevent copy of object + Pass &operator=(const Pass &) = delete; + Pass(const Pass &) = delete; + + /** + * @brief run compiler pass + * @param data - data that pass is taken + * @return data that can be passed to the next pass + * @throw PassException object if errors occured + */ + virtual PassData run(PassData data) = 0; + + virtual ~Pass() = default; +}; + +} // namespace pass +} // namespace contrib +} // namespace nncc + +#endif //NNCC_PASS_H diff --git a/contrib/nnc/include/pass/PassData.h b/contrib/nnc/include/pass/PassData.h new file mode 100644 index 0000000..20ddbb8 --- /dev/null +++ b/contrib/nnc/include/pass/PassData.h @@ -0,0 +1,76 @@ +#ifndef NNCC_PASSDATA_H +#define NNCC_PASSDATA_H + +#include "core/modelIR/graph.h" +#include "core/modelIR/TensorVariant.h" + +using namespace nncc::contrib::core::IR::model; + +namespace nncc +{ +namespace contrib +{ +namespace pass +{ + +/** + * @brief class that encapsulate value returned and taken by pass + */ +class PassData +{ +public: + PassData(const PassData &) = default; + + PassData(std::nullptr_t data) { _dataContainer.unknown = data; _dataType = PDT::UNKNOWN; } + + /** + * @brief Implicit conversion from Graph* to PassData + */ + /* implicit */ PassData(Graph *graph) { _dataContainer.graph = graph; _dataType = PDT::GRAPH; } + /** + * @brief Implicit conversion from PassData to Graph* + */ + /* implicit */ operator Graph*() const { + if ( _dataType != PDT::GRAPH ) + return nullptr; + return _dataContainer.graph; + } + + /** + * @brief Implicit conversion from Graph* to PassData + */ + /* implicit */ PassData(TensorVariant *tv) { _dataContainer.tensorVariant = tv; _dataType = PDT::TENSOR_VARIANT; } + /** + * @brief Implicit conversion from PassData to Graph* + */ + /* implicit */ operator TensorVariant*() const { + if ( _dataType != PDT::TENSOR_VARIANT ) + return nullptr; + return _dataContainer.tensorVariant; + } + +private: + // types that PassData can contain + enum class PDT : char + { + GRAPH, + TENSOR_VARIANT, + UNKNOWN + + } _dataType; + + // union contains all pointers to objects that can be returned from passes + union + { + Graph *graph; + TensorVariant *tensorVariant; + void *unknown; + + } _dataContainer; +}; + +} // namespace pass +} // namespace contrib +} // namespace nncc + +#endif //NNCC_PASSDATA_H diff --git a/contrib/nnc/include/pass/PassException.h b/contrib/nnc/include/pass/PassException.h new file mode 100644 index 0000000..e925933 --- /dev/null +++ b/contrib/nnc/include/pass/PassException.h @@ -0,0 +1,39 @@ +#ifndef NNCC_PASSEXCEPTION_H +#define NNCC_PASSEXCEPTION_H + +#include +#include + +namespace nncc +{ +namespace contrib +{ +namespace pass +{ + +/** + * @brief objects of this class are to be thrown from Passes if errors are occurred + */ +class PassException : public std::exception +{ +public: + PassException() = default; + PassException(const PassException &) noexcept {}; + + PassException(const std::string &msg) : _msg(msg) {}; + PassException(const char *msg) : _msg(msg) {}; + + /** + * @brief get message describes reason why exception was thrown + */ + const std::string &reason() const { return _msg; } + +private: + std::string _msg; +}; + +} // namespace pass +} // namespace contrib +} // namespace nncc + +#endif //NNCC_PASSEXCEPTION_H diff --git a/contrib/nnc/include/pass/PassManager.h b/contrib/nnc/include/pass/PassManager.h new file mode 100644 index 0000000..f05038c --- /dev/null +++ b/contrib/nnc/include/pass/PassManager.h @@ -0,0 +1,51 @@ +#ifndef __PASS_MANAGER_H__ +#define __PASS_MANAGER_H__ + +#include + +namespace nncc +{ +namespace contrib +{ +namespace pass +{ + +// forward declaration +class Pass; + +/** + * @brief pass manager class. This class manages running of passes + */ +class PassManager +{ +public: + /** + * @brief singleton method to get PassManager instance + */ + static PassManager *getPassManager(); + + /** + * @brief register pass in pass manager + * @param pass - registered pass + */ + void registerPass(Pass *pass); + + /** + * @brief get all registered passes in order in which they were registered + */ + using Passes = std::vector; + Passes getPasses() const { return _passes; } + +private: + PassManager() = default; + ~PassManager() = default; + + // data + Passes _passes; // registered passes +}; + +} // namespace pass +} // namespace contrib +} // namespace nncc + +#endif // __PASS_MANAGER_H__ diff --git a/contrib/nnc/include/passes/caffe_frontend/CaffeFrontend.h b/contrib/nnc/include/passes/caffe_frontend/CaffeFrontend.h new file mode 100644 index 0000000..0337c18 --- /dev/null +++ b/contrib/nnc/include/passes/caffe_frontend/CaffeFrontend.h @@ -0,0 +1,39 @@ +#ifndef NNCC_CAFFEFRONTEND_H +#define NNCC_CAFFEFRONTEND_H + +#include "pass/Pass.h" +#include "pass/PassData.h" + +using namespace nncc::contrib::pass; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace caffe +{ + +/** + * @brief class represent frontend of caffe NN framework + */ +class CaffeFrontend : public Pass +{ +public: + CaffeFrontend &operator=(const CaffeFrontend &) = delete; + CaffeFrontend(const CaffeFrontend &) = delete; + + CaffeFrontend() = default; + ~CaffeFrontend() override = default; + + static Pass &getInstance(); + PassData run(PassData data) override; +}; + +} // namespace caffe +} // namespace frontend +} // namespace contrib +} // namespace nncc + +#endif //NNCC_CAFFEFRONTEND_H diff --git a/contrib/nnc/include/plugin/common_frontend/model_allocation.h b/contrib/nnc/include/passes/common_frontend/model_allocation.h similarity index 100% rename from contrib/nnc/include/plugin/common_frontend/model_allocation.h rename to contrib/nnc/include/passes/common_frontend/model_allocation.h diff --git a/contrib/nnc/include/plugin/common_frontend/nn_importer.h b/contrib/nnc/include/passes/common_frontend/nn_importer.h similarity index 100% rename from contrib/nnc/include/plugin/common_frontend/nn_importer.h rename to contrib/nnc/include/passes/common_frontend/nn_importer.h diff --git a/contrib/nnc/include/plugin/common_frontend/shape_helper.h b/contrib/nnc/include/passes/common_frontend/shape_helper.h similarity index 100% rename from contrib/nnc/include/plugin/common_frontend/shape_helper.h rename to contrib/nnc/include/passes/common_frontend/shape_helper.h diff --git a/contrib/nnc/include/plugin/interpreter/Interpreter.h b/contrib/nnc/include/passes/interpreter/Interpreter.h similarity index 100% rename from contrib/nnc/include/plugin/interpreter/Interpreter.h rename to contrib/nnc/include/passes/interpreter/Interpreter.h diff --git a/contrib/nnc/include/passes/interpreter/InterpreterPass.h b/contrib/nnc/include/passes/interpreter/InterpreterPass.h new file mode 100644 index 0000000..9c8a8b8 --- /dev/null +++ b/contrib/nnc/include/passes/interpreter/InterpreterPass.h @@ -0,0 +1,40 @@ +#ifndef NNCC_INTERPRETERPASS_H +#define NNCC_INTERPRETERPASS_H + +#include "core/modelIR/TensorVariant.h" +#include "core/modelIR/Shape.h" + +#include "pass/Pass.h" +#include "pass/PassData.h" + +namespace nncc +{ +namespace contrib +{ +namespace backend +{ +namespace interpreter +{ + +using namespace nncc::contrib; +using namespace nncc::contrib::pass; + +class InterpreterPass : public Pass +{ +public: + static Pass &getInstance(); + PassData run(PassData data) override; + + virtual ~InterpreterPass(); + +private: + nncc::contrib::core::ADT::TensorVariant loadInput(const nncc::contrib::core::data::Shape &); + nncc::contrib::core::ADT::TensorVariant *_out; +}; + +} // namespace interpreter +} // namespace backend +} // namespace contrib +} // namespace nncc + +#endif //NNCC_INTERPRETERPASS_H diff --git a/contrib/nnc/plugin/soft_backend/base_generator.h b/contrib/nnc/include/passes/soft_backend/BaseGenerator.h similarity index 84% rename from contrib/nnc/plugin/soft_backend/base_generator.h rename to contrib/nnc/include/passes/soft_backend/BaseGenerator.h index 369f1fb..7d6a3d9 100644 --- a/contrib/nnc/plugin/soft_backend/base_generator.h +++ b/contrib/nnc/include/passes/soft_backend/BaseGenerator.h @@ -2,11 +2,14 @@ #define _NNC_SOFT_BACKEND_BASE_GENERATOR_H_ #include "core/modelIR/graph.h" -#include "support/PluginInstance.h" +#include "pass/Pass.h" +#include "pass/PassData.h" #include #include +using namespace nncc::contrib::pass; + namespace nncc { namespace contrib @@ -20,10 +23,10 @@ class ModelAnalyzer; class Serializer; -class BaseCodeGenerator: public nncc::contrib::plugin::BackendPlugin +class BaseCodeGenerator : public Pass { public: - void *execute(void *data) override; + PassData run(PassData data) override; protected: virtual void formatTensorNames(const ModelAnalyzer &ma) = 0; diff --git a/contrib/nnc/plugin/soft_backend/c_generator.h b/contrib/nnc/include/passes/soft_backend/CGenerator.h similarity index 85% rename from contrib/nnc/plugin/soft_backend/c_generator.h rename to contrib/nnc/include/passes/soft_backend/CGenerator.h index 1fbde6d..e94d2c3 100644 --- a/contrib/nnc/plugin/soft_backend/c_generator.h +++ b/contrib/nnc/include/passes/soft_backend/CGenerator.h @@ -1,7 +1,8 @@ #ifndef _NNC_SOFT_BACKEND_C_GENERATOR_H_ #define _NNC_SOFT_BACKEND_C_GENERATOR_H_ -#include "base_generator.h" +#include "passes/soft_backend/BaseGenerator.h" +#include "pass/Pass.h" namespace nncc { @@ -16,12 +17,15 @@ namespace soft class CCodeGenerator: public BaseCodeGenerator { public: - CCodeGenerator() = default; + static Pass &getInstance(); protected: void formatTensorNames(const ModelAnalyzer &ma) override; void materializeHeader(std::ostream &out, const ModelAnalyzer &ma) override; void materializeCode(std::ostream &out, const ModelAnalyzer &ma, const Serializer &s) override; + +private: + CCodeGenerator() = default; }; } // namespace soft diff --git a/contrib/nnc/plugin/soft_backend/cpp_generator.h b/contrib/nnc/include/passes/soft_backend/CPPGenerator.h similarity index 91% rename from contrib/nnc/plugin/soft_backend/cpp_generator.h rename to contrib/nnc/include/passes/soft_backend/CPPGenerator.h index 6448b98..e39eab7 100644 --- a/contrib/nnc/plugin/soft_backend/cpp_generator.h +++ b/contrib/nnc/include/passes/soft_backend/CPPGenerator.h @@ -1,7 +1,8 @@ #ifndef _NNC_SOFT_BACKEND_CPP_GENERATOR_H_ #define _NNC_SOFT_BACKEND_CPP_GENERATOR_H_ -#include "base_generator.h" +#include "passes/soft_backend/BaseGenerator.h" +#include "pass/Pass.h" namespace nncc { @@ -16,7 +17,7 @@ namespace soft class CPPCodeGenerator: public BaseCodeGenerator { public: - CPPCodeGenerator(): BaseCodeGenerator() {} + static Pass &getInstance(); protected: void formatTensorNames(const ModelAnalyzer &ma) override; @@ -29,6 +30,9 @@ protected: void printGetter(std::ostream &out, const std::string &className, const std::string &setterName, const std::string &varName); void materializeInferenceSequence(std::ostream &out, const ModelAnalyzer &ma); void materializeCode(std::ostream &out, const ModelAnalyzer &ma, const Serializer &s) override; + +private: + CPPCodeGenerator(): BaseCodeGenerator() {} }; } // namespace soft diff --git a/contrib/nnc/include/passes/tflite_frontend/TfliteFrontend.h b/contrib/nnc/include/passes/tflite_frontend/TfliteFrontend.h new file mode 100644 index 0000000..c48dccb --- /dev/null +++ b/contrib/nnc/include/passes/tflite_frontend/TfliteFrontend.h @@ -0,0 +1,40 @@ +#ifndef NNCC_TFLITEFRONTEND_H +#define NNCC_TFLITEFRONTEND_H + +#include "pass/Pass.h" +#include "pass/PassData.h" + +using namespace nncc::contrib::pass; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +/** + * @brief class represent frontend of tensor flow lite NN framework + */ +class TFLiteFrontend : public Pass +{ +public: + TFLiteFrontend &operator=(const TFLiteFrontend &) = delete; + TFLiteFrontend(const TFLiteFrontend &) = delete; + + static Pass &getInstance(); + PassData run(PassData data) override; + +private: + TFLiteFrontend() = default; + ~TFLiteFrontend() override = default; +}; + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc + +#endif //NNCC_TFLITEFRONTEND_H diff --git a/contrib/nnc/include/support/CommandLine.h b/contrib/nnc/include/support/CommandLine.h index 6bfdd5a..8469419 100644 --- a/contrib/nnc/include/support/CommandLine.h +++ b/contrib/nnc/include/support/CommandLine.h @@ -483,7 +483,6 @@ Option::Option(const std::vector &optnames, void checkInFile(const Option &); void checkOutFile(const Option &); void checkOutDir(const Option &); -void checkPluginsPath(const Option &); void checkDebugFile(const Option &); } // namespace clopt diff --git a/contrib/nnc/include/support/PluginException.h b/contrib/nnc/include/support/PluginException.h deleted file mode 100644 index ef9c9e9..0000000 --- a/contrib/nnc/include/support/PluginException.h +++ /dev/null @@ -1,34 +0,0 @@ -// -// Created by v.cherepanov@samsung.com on 04.05.18. -// -#ifndef __PLUGIN_EXCEPTION_H__ -#define __PLUGIN_EXCEPTION_H__ - -#include - -namespace nncc -{ -namespace contrib -{ - -class PluginException -{ -public: - PluginException(const PluginException &) noexcept {}; - - explicit PluginException(const std::string &info) { _msg = info; } - - /** - * @brief get message from exception object - */ - std::string what() const { return _msg; } - -private: - std::string _msg; -}; - - -} // namespace contrib -} // namespace nncc - -#endif // __PLUGIN_EXCEPTION_H__ diff --git a/contrib/nnc/include/support/PluginInstance.h b/contrib/nnc/include/support/PluginInstance.h deleted file mode 100644 index 272e2b6..0000000 --- a/contrib/nnc/include/support/PluginInstance.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef __PLUGIN_INSTANCE_H__ -#define __PLUGIN_INSTANCE_H__ - -#include -#include - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -// -// This class and its methods are NOT thread safe -// -class Plugin -{ -public: - Plugin &operator=(const Plugin &) = delete; - Plugin(const Plugin &) = delete; - - virtual void *execute(void *data) = 0; - -protected: - Plugin() = default; - virtual ~Plugin() = default; -}; - -class FrontendPlugin : public Plugin -{ -public: - FrontendPlugin &operator=(const FrontendPlugin &) = delete; - FrontendPlugin(const FrontendPlugin &) = delete; - - void *execute(void *data) override = 0; - -protected: - FrontendPlugin() = default; - ~FrontendPlugin() override = default; -}; - -class BackendPlugin : public Plugin -{ -public: - BackendPlugin &operator=(const BackendPlugin &) = delete; - BackendPlugin (const BackendPlugin &) = delete; - - void *execute(void *data) override = 0; - -protected: - BackendPlugin() = default; - ~BackendPlugin() override = default; -}; - -} // namespace plugin -} // namespace contrib -} // namespace nncc - -#endif // __PLUGIN_INSTANCE_H__ diff --git a/contrib/nnc/include/support/PluginManager.h b/contrib/nnc/include/support/PluginManager.h deleted file mode 100644 index 1d55894..0000000 --- a/contrib/nnc/include/support/PluginManager.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef __PLUGIN_MANAGER_H__ -#define __PLUGIN_MANAGER_H__ - -#include -#include - -#include "PluginProxy.h" - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -/** - * @brief plugin manager class - * this class manages plugin loading, execution and unloading - */ -class PluginManager -{ -public: - /** - * @param plugin_path - absolute plugin path - * @throw PluginException if couldn't load plugin - */ - explicit PluginManager(const std::string &plugin_path); - - /** - * @throw PluginException if couldn't unload plugin - */ - ~PluginManager() noexcept(false); - - /** - * @brief get plugin - */ - Plugin *getPlugin(); - - /** - * @brief print plugin - * @param st - output stream - * @param pm - plugin manager object - */ - friend std::ostream &operator<<(std::ostream &st, const PluginManager &pm); - -private: - /** - * @brief load/unload plugin - */ - void loadPlugin(); - void unloadPlugin(); - - // data - std::string _plugin_path; // path to plugin - std::shared_ptr _plugin_proxy; // access to plugin -}; - -} // namespace plugin -} // namespace contrib -} // namespace nncc - -#endif // __PLUGIN_MANAGER_H__ diff --git a/contrib/nnc/include/support/PluginProxy.h b/contrib/nnc/include/support/PluginProxy.h deleted file mode 100644 index 2f26694..0000000 --- a/contrib/nnc/include/support/PluginProxy.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef __PLUGIN_PROXY_H__ -#define __PLUGIN_PROXY_H__ - -#include -#include -#include - -#include "shared_library.h" -#include "PluginException.h" -#include "PluginInstance.h" - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -/** - * @brief this proxy class provides access to plugins - */ -class PluginProxy -{ -public: - /** - * @brief factory method that load plugin - * @param pluginPath - absolute path to plugin - * @return proxy class instance - * @throw PluginException if couldn't load plugin library or plugin is inappropriate - */ - // TODO it is possible to eliminate factory method and create proxy in constructor - static std::shared_ptr create(const std::string &pluginPath); - - /** - * @brief unload plugin - * @throw PluginException if couldn't unload plugin library - */ - // TODO it is possible to eliminate this method and do unloading library in destructor - void remove(); - - /** - * @brief get plugin absolute path and plugin name - */ - const std::string &getPluginPath() const { return _lib->getPath(); } - const std::string &getPluginName() const { return _pluginName; } - - /** - * @brief get loaded plugin instance - * @return pointer to plugin (don't need to be freed) - */ - Plugin *getPluginInstance(); - - /** - * @brief name of function that provides information - * about plugin. Every plugin must contain this function - */ - static const std::string getInstanceFuncName; - -private: - // only factory method can create class instances - explicit PluginProxy(const std::string &pluginPath); - - friend std::ostream &operator<<(std::ostream &st, const PluginProxy &pl); - - // type of function that returns loaded plugin - using get_instance_t = Plugin* (*)(); - - // data - Plugin *_pluginInstance; // pointer to loaded plugin - std::shared_ptr> _lib; // plugin library - get_instance_t _getInstance; // function from plugin that returns pointer to plugin - std::string _pluginName; // name of plugin -}; - -} // namespace plugin -} // namespace contrib -} // namespace nncc - -#endif /* __PLUGIN_PROXY_H__ */ diff --git a/contrib/nnc/include/support/shared_library.h b/contrib/nnc/include/support/shared_library.h deleted file mode 100644 index 497dd37..0000000 --- a/contrib/nnc/include/support/shared_library.h +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef NNCC_SHAREDLIBRARY_H -#define NNCC_SHAREDLIBRARY_H - -#include -#include -#include -#include - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -/** - * @brief class provides access to shared library - * @tparam ExceptionT - thrown exception type if errors occur - */ -template -class SharedLibrary -{ -public: - explicit SharedLibrary(const std::string &fullPath) : _handle(nullptr), _path(fullPath), _isLoaded(false) {} - - /** - * @brief find function from shared library - * @param funcName - function name - * @return pointer to found function - * @throw ExceptionT if errors occurred - */ - void *findFunc(const std::string &funcName); - - /** - * @brief unload shared library - * @throw ExceptionT if couldn't unload library - */ - void unloadLibrary(); - - /** - * @brief get path of shared library - */ - const std::string& getPath() const { return _path; }; - -private: - // load shared library - void *loadLibrary(); - - // data - std::string _path; // path to shared library - void *_handle; // handle returned by dlopen - bool _isLoaded; -}; - - -template -void *SharedLibrary::findFunc(const std::string &funcName) -{ - if ( !_isLoaded ) - { - _handle = loadLibrary(); - _isLoaded = true; - } - - // reset errors - dlerror(); - - // get function address - assert(_handle); - void *func = dlsym(_handle, funcName.c_str()); - - char *dlsym_error = dlerror(); - - if ( dlsym_error ) - { - throw ExceptionT("Cannot load symbol: '" + funcName + "' : " + dlsym_error); - } - - return func; - -} // findFunc - -template -void *SharedLibrary::loadLibrary() -{ -// NNC_DEBUG(dbgs() << "Opening " << _path << std::endl); - - // open the library - void *handle = dlopen(_path.c_str(), RTLD_LAZY); - - if ( !handle ) - { - throw ExceptionT("Cannot open library: '" + _path + "' : " + dlerror()); - } - - return handle; - -} // loadLibrary - -template -void SharedLibrary::unloadLibrary() -{ - // reset errors - dlerror(); - - // This is a workaround for bug when nnc is received segfault when it can access - // to vtable from unloaded plugin. When we eliminate plugins it will be fixed - /* - // close the library - if ( dlclose(_handle) ) - { - throw ExceptionT("Cannot unloaded library: '" + _path + "' : " + dlerror()); - } - */ - -} // unloadLibrary - -template -std::ostream &operator<<(std::ostream &st, const SharedLibrary &lib) -{ - st << lib.getPath(); - return st; -} - -} // namespase plugin -} // namespace contrib -} // namespace nncc - -#endif // NNCC_SHAREDLIBRARY_H diff --git a/contrib/nnc/pass/CMakeLists.txt b/contrib/nnc/pass/CMakeLists.txt new file mode 100644 index 0000000..8a5094a --- /dev/null +++ b/contrib/nnc/pass/CMakeLists.txt @@ -0,0 +1,4 @@ +set(PASS_MANAGER_SRC PassManager.cpp) + +add_library(nnc_pass STATIC ${PASS_MANAGER_SRC}) +set_target_properties(nnc_pass PROPERTIES LINKER_LANGUAGE CXX) diff --git a/contrib/nnc/pass/PassManager.cpp b/contrib/nnc/pass/PassManager.cpp new file mode 100644 index 0000000..4bf82ab --- /dev/null +++ b/contrib/nnc/pass/PassManager.cpp @@ -0,0 +1,27 @@ +#include "pass/PassManager.h" + +namespace nncc +{ +namespace contrib +{ +namespace pass +{ + +PassManager *PassManager::getPassManager() +{ + static PassManager passManager; + + return &passManager; + +} // getPassManager + + +void PassManager::registerPass(Pass *pass) +{ + _passes.push_back(pass); + +} // registerPass + +} // namespace pass +} // namespace contrib +} // namespace nncc diff --git a/contrib/nnc/plugin/CMakeLists.txt b/contrib/nnc/passes/CMakeLists.txt similarity index 100% rename from contrib/nnc/plugin/CMakeLists.txt rename to contrib/nnc/passes/CMakeLists.txt diff --git a/contrib/nnc/plugin/caffe_frontend/CMakeLists.txt b/contrib/nnc/passes/caffe_frontend/CMakeLists.txt similarity index 82% rename from contrib/nnc/plugin/caffe_frontend/CMakeLists.txt rename to contrib/nnc/passes/caffe_frontend/CMakeLists.txt index 72bb011..6f3ab2c 100644 --- a/contrib/nnc/plugin/caffe_frontend/CMakeLists.txt +++ b/contrib/nnc/passes/caffe_frontend/CMakeLists.txt @@ -10,7 +10,6 @@ file(GLOB caffe_importer_headers *.h) add_nncc_library(caffe_importer SHARED ${caffe_importer_sources} ${caffe_importer_headers}) -set_target_properties(caffe_importer PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NNC_BINARY_DIR}) target_link_libraries(caffe_importer PUBLIC caffeproto) target_link_libraries(caffe_importer PUBLIC nn_import_common) @@ -18,4 +17,4 @@ target_link_libraries(caffe_importer PRIVATE nnc_support) target_link_libraries(caffe_importer PRIVATE nnc_core) # install caffe frontend library -install_nnc_plugin(caffe_importer) +install_nnc_library(caffe_importer) diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_dump_visitor.cpp b/contrib/nnc/passes/caffe_frontend/caffe_dump_visitor.cpp similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_dump_visitor.cpp rename to contrib/nnc/passes/caffe_frontend/caffe_dump_visitor.cpp diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_dump_visitor.h b/contrib/nnc/passes/caffe_frontend/caffe_dump_visitor.h similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_dump_visitor.h rename to contrib/nnc/passes/caffe_frontend/caffe_dump_visitor.h diff --git a/contrib/nnc/passes/caffe_frontend/caffe_frontend.cpp b/contrib/nnc/passes/caffe_frontend/caffe_frontend.cpp new file mode 100644 index 0000000..ccfae89 --- /dev/null +++ b/contrib/nnc/passes/caffe_frontend/caffe_frontend.cpp @@ -0,0 +1,47 @@ +#include +#include +#include + +#include "option/Options.h" +#include "pass/PassException.h" +#include "passes/caffe_frontend/CaffeFrontend.h" + +#include "caffe_importer.h" + +using namespace nncc::contrib::pass; +using namespace nncc::contrib::frontend::caffe; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace caffe +{ + +Pass &CaffeFrontend::getInstance() +{ + static CaffeFrontend instance; + return instance; +} + +PassData CaffeFrontend::run(PassData data) +{ + (void)data; + nncc::contrib::frontend::caffe::CaffeImporter importer{clopt::inputFile}; + + bool success = importer.import(); + + if (!success) + { + throw PassException("Could not load model: " + clopt::inputFile + "\n"); + } + + return reinterpret_cast(importer.createIR()); +} + +} // namespace caffe +} // namespace frontend +} // namespace contrib +} // namespace nncc diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_importer.cpp b/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp similarity index 96% rename from contrib/nnc/plugin/caffe_frontend/caffe_importer.cpp rename to contrib/nnc/passes/caffe_frontend/caffe_importer.cpp index 8b69f18..e2fa864 100644 --- a/contrib/nnc/plugin/caffe_frontend/caffe_importer.cpp +++ b/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp @@ -26,7 +26,7 @@ bool CaffeImporter::import() return util::readProtoFromBinaryFile(modelFilename.c_str(), net.get()); } -void* CaffeImporter::createIR() +void *CaffeImporter::createIR() { ModelVisitor irCreator; ModelWalker caffeWalker(&irCreator); diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_importer.h b/contrib/nnc/passes/caffe_frontend/caffe_importer.h similarity index 89% rename from contrib/nnc/plugin/caffe_frontend/caffe_importer.h rename to contrib/nnc/passes/caffe_frontend/caffe_importer.h index 3a461f9..b92258d 100644 --- a/contrib/nnc/plugin/caffe_frontend/caffe_importer.h +++ b/contrib/nnc/passes/caffe_frontend/caffe_importer.h @@ -6,7 +6,7 @@ #include "caffe/proto/caffe.pb.h" -#include "plugin/common_frontend/nn_importer.h" +#include "passes/common_frontend/nn_importer.h" namespace nncc { @@ -25,7 +25,7 @@ public: explicit CaffeImporter(std::string filename) : modelFilename(std::move(filename)) {}; bool import() override; - void* createIR() override; + void *createIR() override; void dump() override; private: diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_model_visitor.cpp b/contrib/nnc/passes/caffe_frontend/caffe_model_visitor.cpp similarity index 95% rename from contrib/nnc/plugin/caffe_frontend/caffe_model_visitor.cpp rename to contrib/nnc/passes/caffe_frontend/caffe_model_visitor.cpp index 2b9a944..dc3d773 100644 --- a/contrib/nnc/plugin/caffe_frontend/caffe_model_visitor.cpp +++ b/contrib/nnc/passes/caffe_frontend/caffe_model_visitor.cpp @@ -4,11 +4,13 @@ #include "core/modelIR/Shape.h" #include "core/modelIR/operations/variable_op.h" #include "core/modelIR/TensorUtil.h" -#include "support/PluginException.h" +#include "pass/PassException.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" #include "caffe_model_visitor.h" +using namespace nncc::contrib::pass; + namespace nncc { namespace contrib @@ -83,7 +85,7 @@ void ModelVisitor::visit(const LayerParameter& lp) } else { - throw PluginException("Encountered unsupported Caffe layer type"); + throw PassException("Encountered unsupported Caffe layer type"); } for (auto item : outputs) @@ -130,7 +132,7 @@ void ModelVisitor::processDeprecatedInput(const NetParameter& np) { if (np.input_dim_size() != 0 || np.input_shape_size() != 0) { - throw PluginException("Deprecated Caffe input types are not supported"); + throw PassException("Deprecated Caffe input types are not supported"); } } @@ -182,7 +184,7 @@ std::shared_ptr ModelVisitor::createTensor(const BlobProto &bp) } else { - throw PluginException("No data in Caffe BlobProto, investigate"); + throw PassException("No data in Caffe BlobProto, investigate"); } // Create untyped tensor. Note, tensor contents will be *copied* here. diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_model_visitor.h b/contrib/nnc/passes/caffe_frontend/caffe_model_visitor.h similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_model_visitor.h rename to contrib/nnc/passes/caffe_frontend/caffe_model_visitor.h diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_op_creator.cpp b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp similarity index 93% rename from contrib/nnc/plugin/caffe_frontend/caffe_op_creator.cpp rename to contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp index 3a47b97..d10d01d 100644 --- a/contrib/nnc/plugin/caffe_frontend/caffe_op_creator.cpp +++ b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp @@ -15,9 +15,12 @@ #include "core/modelIR/Index.h" #include "core/modelIR/ShapeRange.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" +#include "pass/PassException.h" #include "caffe_op_creator.h" +using namespace nncc::contrib::pass; + namespace nncc { namespace contrib @@ -35,7 +38,7 @@ static inline bool has2DStride(const OptsType& opts) { if (opts.has_stride_h() != opts.has_stride_w()) { - throw PluginException("Conv or Pool layer has only 1 out of 2 2D strides, investigate"); + throw PassException("Conv or Pool layer has only 1 out of 2 2D strides, investigate"); } // We already checked that both 2D strides are both present or both are not return opts.has_stride_h(); @@ -201,7 +204,7 @@ __attribute__ ((unused)) static Shape getPoolWindowShape(const PoolingParameter { if (opts.has_kernel_h() != opts.has_kernel_w()) { - throw PluginException("Pool layer has only 1 out of 2 kernel dimensions, investigate"); + throw PassException("Pool layer has only 1 out of 2 kernel dimensions, investigate"); } if (opts.has_kernel_h()) @@ -214,7 +217,7 @@ __attribute__ ((unused)) static Shape getPoolWindowShape(const PoolingParameter } else { - throw PluginException("Pooling layer doesn't have kernel size data, investigate"); + throw PassException("Pooling layer doesn't have kernel size data, investigate"); } } @@ -227,7 +230,7 @@ __attribute__ ((unused)) static ops::PoolOp::PoolingType getPoolingType(const Po else if (opts.pool() == PoolingParameter::AVE) return PoolingType::AVG; else - throw PluginException("Unsupported pooling type: " + + throw PassException("Unsupported pooling type: " + PoolingParameter::PoolMethod_Name(opts.pool())); } @@ -252,7 +255,7 @@ __attribute__ ((unused)) static int getAxisValue(const OptsType& opts) } else if (axis != 1 && axis != -1) { - throw PluginException("Softmax/Concat layer axis param is not 1 or -1, which implies" + throw PassException("Softmax/Concat layer axis param is not 1 or -1, which implies" "unsupported NN architecture."); } } @@ -364,12 +367,12 @@ std::vector OpCreator::createFullyConnected(InputOps &inputs, InputP { if (opts.has_axis() && opts.axis() != 1) { - throw PluginException("InnerProduct layer axis param is not supported yet"); + throw PassException("InnerProduct layer axis param is not supported yet"); } if (opts.has_transpose() && opts.transpose()) { - throw PluginException("InnerProduct layer transpose param is not supported yet"); + throw PassException("InnerProduct layer transpose param is not supported yet"); } // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize] @@ -403,7 +406,7 @@ std::vector OpCreator::createPool(InputOps inputs, InputParams param if (opts.has_global_pooling() && opts.global_pooling()) { - throw PluginException("Pooling layer global_pooling param is not supported yet"); + throw PassException("Pooling layer global_pooling param is not supported yet"); } Shape windowShape = util::getPoolWindowShape(opts); @@ -437,12 +440,12 @@ std::vector OpCreator::createReshape(InputOps inputs, InputParams pa if (opts.has_axis() || opts.has_num_axes()) { - throw PluginException("Reshape layer axis and num_axes params are not supported yet"); + throw PassException("Reshape layer axis and num_axes params are not supported yet"); } if (!opts.has_shape()) { - throw PluginException("Reshape layer doesn't have shape parameter"); + throw PassException("Reshape layer doesn't have shape parameter"); } Shape newShape = common::ShapeHelper::createShape(opts.shape().dim(), opts.shape().dim_size()); @@ -450,7 +453,7 @@ std::vector OpCreator::createReshape(InputOps inputs, InputParams pa for (unsigned int i = 0; i < newShape.rank(); ++i) { if (newShape.dim(i) == 0) - throw PluginException("Reshape layer zero shape values are not supported yet"); + throw PassException("Reshape layer zero shape values are not supported yet"); } outputs[0]->getOperation()->setOutputShape(0, newShape); @@ -464,7 +467,7 @@ std::vector OpCreator::createRelu(InputOps inputs, InputParams param if (opts.has_negative_slope()) { - throw PluginException("ReLU layer negative_slope param is not supported yet."); + throw PassException("ReLU layer negative_slope param is not supported yet."); } return createOp(inputs); diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_op_creator.h b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.h similarity index 98% rename from contrib/nnc/plugin/caffe_frontend/caffe_op_creator.h rename to contrib/nnc/passes/caffe_frontend/caffe_op_creator.h index 4fa06e7..c6f4970 100644 --- a/contrib/nnc/plugin/caffe_frontend/caffe_op_creator.h +++ b/contrib/nnc/passes/caffe_frontend/caffe_op_creator.h @@ -5,7 +5,6 @@ #include #include -#include "support/PluginException.h" #include "core/modelIR/graph.h" #include "core/modelIR/ir_node.h" #include "core/modelIR/TensorVariant.h" diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_visitor.h b/contrib/nnc/passes/caffe_frontend/caffe_visitor.h similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_visitor.h rename to contrib/nnc/passes/caffe_frontend/caffe_visitor.h diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_walker.cpp b/contrib/nnc/passes/caffe_frontend/caffe_walker.cpp similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_walker.cpp rename to contrib/nnc/passes/caffe_frontend/caffe_walker.cpp diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_walker.h b/contrib/nnc/passes/caffe_frontend/caffe_walker.h similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/caffe_walker.h rename to contrib/nnc/passes/caffe_frontend/caffe_walker.h diff --git a/contrib/nnc/plugin/caffe_frontend/proto_reader.cpp b/contrib/nnc/passes/caffe_frontend/proto_reader.cpp similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/proto_reader.cpp rename to contrib/nnc/passes/caffe_frontend/proto_reader.cpp diff --git a/contrib/nnc/plugin/caffe_frontend/proto_reader.h b/contrib/nnc/passes/caffe_frontend/proto_reader.h similarity index 100% rename from contrib/nnc/plugin/caffe_frontend/proto_reader.h rename to contrib/nnc/passes/caffe_frontend/proto_reader.h diff --git a/contrib/nnc/plugin/common_frontend/CMakeLists.txt b/contrib/nnc/passes/common_frontend/CMakeLists.txt similarity index 100% rename from contrib/nnc/plugin/common_frontend/CMakeLists.txt rename to contrib/nnc/passes/common_frontend/CMakeLists.txt diff --git a/contrib/nnc/plugin/common_frontend/model_allocation.cpp b/contrib/nnc/passes/common_frontend/model_allocation.cpp similarity index 94% rename from contrib/nnc/plugin/common_frontend/model_allocation.cpp rename to contrib/nnc/passes/common_frontend/model_allocation.cpp index 5983368..ff062c2 100644 --- a/contrib/nnc/plugin/common_frontend/model_allocation.cpp +++ b/contrib/nnc/passes/common_frontend/model_allocation.cpp @@ -3,7 +3,7 @@ #include #include -#include "plugin/common_frontend/model_allocation.h" +#include "passes/common_frontend/model_allocation.h" using namespace nncc::contrib::frontend::common; diff --git a/contrib/nnc/passes/common_frontend/shape_helper.cpp b/contrib/nnc/passes/common_frontend/shape_helper.cpp new file mode 100644 index 0000000..f6ed037 --- /dev/null +++ b/contrib/nnc/passes/common_frontend/shape_helper.cpp @@ -0,0 +1,39 @@ +#include + +#include "passes/common_frontend/shape_helper.h" +#include "pass/PassException.h" + +using namespace nncc::contrib::pass; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace common +{ + +Shape &ShapeHelper::cutOffBatchDim(Shape &shape) +{ + if (shape.dim(0) != 1) + { + throw PassException{"While attempting to cut off tensor batch dimension (first one)," + "found that it is not 1. Check the model being imported, if the first" + "dimension of the input is not 1, then it might be not batch, and the" + "code needs some restructuring"}; + } + + for (unsigned int i = 0; i < shape.rank() - 1; ++i) + { + shape.dim(i) = shape.dim(i + 1); + } + shape.resize(shape.rank() - 1); + + return shape; +} + +} // namespace common +} // namespace frontend +} // namespace contrib +} // namespace nncc diff --git a/contrib/nnc/plugin/interpreter/CMakeLists.txt b/contrib/nnc/passes/interpreter/CMakeLists.txt similarity index 75% rename from contrib/nnc/plugin/interpreter/CMakeLists.txt rename to contrib/nnc/passes/interpreter/CMakeLists.txt index acf5d30..ec33943 100644 --- a/contrib/nnc/plugin/interpreter/CMakeLists.txt +++ b/contrib/nnc/passes/interpreter/CMakeLists.txt @@ -1,7 +1,6 @@ file(GLOB_RECURSE interp_src ./*.cpp ./*.h) add_library(nnc_interpreter SHARED ${interp_src}) target_link_libraries(nnc_interpreter PRIVATE nnc_core nnc_support) -set_target_properties(nnc_interpreter PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NNC_BINARY_DIR}) if(NNC_HDF5_SUPPORTED) target_include_directories(nnc_interpreter PRIVATE ${HDF5_INCLUDE_DIRS}) @@ -9,4 +8,4 @@ if(NNC_HDF5_SUPPORTED) endif(NNC_HDF5_SUPPORTED) # install interpreter library -install_nnc_plugin(nnc_interpreter) \ No newline at end of file +install_nnc_library(nnc_interpreter) \ No newline at end of file diff --git a/contrib/nnc/plugin/interpreter/Interpreter.cpp b/contrib/nnc/passes/interpreter/Interpreter.cpp similarity index 99% rename from contrib/nnc/plugin/interpreter/Interpreter.cpp rename to contrib/nnc/passes/interpreter/Interpreter.cpp index 4e42038..de16ea6 100644 --- a/contrib/nnc/plugin/interpreter/Interpreter.cpp +++ b/contrib/nnc/passes/interpreter/Interpreter.cpp @@ -1,7 +1,7 @@ #include #include -#include "plugin/interpreter/Interpreter.h" +#include "passes/interpreter/Interpreter.h" #include "core/modelIR/operations/fully_connected_op.h" #include "core/modelIR/operations/softmax_op.h" diff --git a/contrib/nnc/plugin/interpreter/interpreter_plugin.cpp b/contrib/nnc/passes/interpreter/interpreter_pass.cpp similarity index 82% rename from contrib/nnc/plugin/interpreter/interpreter_plugin.cpp rename to contrib/nnc/passes/interpreter/interpreter_pass.cpp index bbde44b..1b09957 100644 --- a/contrib/nnc/plugin/interpreter/interpreter_plugin.cpp +++ b/contrib/nnc/passes/interpreter/interpreter_pass.cpp @@ -10,11 +10,14 @@ #include #endif // NNC_HDF5_SUPPORTED -#include "support/PluginInstance.h" - #include "core/modelIR/Shape.h" -#include "plugin/interpreter/Interpreter.h" +#include "pass/Pass.h" +#include "pass/PassData.h" +#include "pass/PassException.h" + +#include "passes/interpreter/Interpreter.h" +#include "passes/interpreter/InterpreterPass.h" #include "core/modelIR/ShapeInference.h" #include "core/modelIR/graph.h" @@ -22,7 +25,6 @@ #include "core/modelIR/ShapeRange.h" #include "core/modelIR/Tensor.h" -#include "interpreter_plugin.h" namespace nncc { @@ -32,18 +34,16 @@ namespace backend { namespace interpreter { -namespace plugin -{ using namespace nncc::contrib; +using namespace nncc::contrib::pass; using namespace nncc::contrib::core::data; using namespace nncc::contrib::core::IR::model; using nncc::contrib::core::data::Shape; -using nncc::contrib::plugin::BackendPlugin; using nncc::contrib::backend::interpreter::core::NNInterpreter; -BackendPlugin &InterpreterPlugin::getInstance() { - static InterpreterPlugin instance; +Pass &InterpreterPass::getInstance() { + static InterpreterPass instance; return instance; } @@ -84,8 +84,11 @@ static void writeTensorToHDF5File(TensorVariant *tensor, std::string tensorName, } #endif // NNC_HDF5_SUPPORTED -void *InterpreterPlugin::execute(void *data) { +PassData InterpreterPass::run(PassData data) +{ auto g = static_cast(data); + assert(g); + ShapeInference shapeInference; NNInterpreter interpreter; @@ -95,12 +98,12 @@ void *InterpreterPlugin::execute(void *data) { // Check nodes auto inputNode = g->getInput(clopt::interInNode); if (inputNode == nullptr) { - throw PluginException("input node <" + clopt::interInNode +"> not found" ); + throw PassException("input node <" + clopt::interInNode +"> not found" ); } auto outputNode = g->getOutput(clopt::interOutNode); if (outputNode == nullptr) { - throw PluginException("output node <" + clopt::interOutNode +"> not found" ); + throw PassException("output node <" + clopt::interOutNode +"> not found" ); } auto input = loadInput(inputNode->getOperation()->getOutputShape(0)); @@ -119,7 +122,7 @@ void *InterpreterPlugin::execute(void *data) { return _out; } -TensorVariant InterpreterPlugin::loadInput(const Shape &shape) +TensorVariant InterpreterPass::loadInput(const Shape &shape) { auto f = fopen(clopt::interInputData.c_str(), "rb"); assert(f && "Cannot open file"); @@ -133,7 +136,7 @@ TensorVariant InterpreterPlugin::loadInput(const Shape &shape) if (len != tensorSize) { std::stringstream info; info << "Wrong input file size <" << clopt::interInputData << "> = " << len << ". Should be :" << tensorSize; - throw PluginException(info.str()); + throw PassException(info.str()); } rewind(f); @@ -147,18 +150,12 @@ TensorVariant InterpreterPlugin::loadInput(const Shape &shape) return TensorVariant(shape, std::shared_ptr(data, [](const char* d) { delete[] d; }), TensorVariant::DTYPE::FLOAT, sizeof(float)); } -InterpreterPlugin::~InterpreterPlugin() +InterpreterPass::~InterpreterPass() { delete _out; } -} // namespace plugin } // namespace interpreter } // namespace backend } // namespace contrib } // namespace nncc - -extern "C" nncc::contrib::backend::interpreter::plugin::Plugin *get_instance() { - return &nncc::contrib::backend::interpreter::plugin::InterpreterPlugin::getInstance(); -} - diff --git a/contrib/nnc/plugin/interpreter/ops/Bias.cpp b/contrib/nnc/passes/interpreter/ops/Bias.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Bias.cpp rename to contrib/nnc/passes/interpreter/ops/Bias.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Bias.h b/contrib/nnc/passes/interpreter/ops/Bias.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Bias.h rename to contrib/nnc/passes/interpreter/ops/Bias.h diff --git a/contrib/nnc/plugin/interpreter/ops/Concat.cpp b/contrib/nnc/passes/interpreter/ops/Concat.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Concat.cpp rename to contrib/nnc/passes/interpreter/ops/Concat.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Concat.h b/contrib/nnc/passes/interpreter/ops/Concat.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Concat.h rename to contrib/nnc/passes/interpreter/ops/Concat.h diff --git a/contrib/nnc/plugin/interpreter/ops/Depthwise_conv_2D.cpp b/contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Depthwise_conv_2D.cpp rename to contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Depthwise_conv_2D.h b/contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Depthwise_conv_2D.h rename to contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.h diff --git a/contrib/nnc/plugin/interpreter/ops/Elementwise.cpp b/contrib/nnc/passes/interpreter/ops/Elementwise.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Elementwise.cpp rename to contrib/nnc/passes/interpreter/ops/Elementwise.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Elementwise.h b/contrib/nnc/passes/interpreter/ops/Elementwise.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Elementwise.h rename to contrib/nnc/passes/interpreter/ops/Elementwise.h diff --git a/contrib/nnc/plugin/interpreter/ops/Fill.cpp b/contrib/nnc/passes/interpreter/ops/Fill.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Fill.cpp rename to contrib/nnc/passes/interpreter/ops/Fill.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Fill.h b/contrib/nnc/passes/interpreter/ops/Fill.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Fill.h rename to contrib/nnc/passes/interpreter/ops/Fill.h diff --git a/contrib/nnc/plugin/interpreter/ops/FullyConnected.cpp b/contrib/nnc/passes/interpreter/ops/FullyConnected.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/FullyConnected.cpp rename to contrib/nnc/passes/interpreter/ops/FullyConnected.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/FullyConnected.h b/contrib/nnc/passes/interpreter/ops/FullyConnected.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/FullyConnected.h rename to contrib/nnc/passes/interpreter/ops/FullyConnected.h diff --git a/contrib/nnc/plugin/interpreter/ops/OperationImpl.h b/contrib/nnc/passes/interpreter/ops/OperationImpl.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/OperationImpl.h rename to contrib/nnc/passes/interpreter/ops/OperationImpl.h diff --git a/contrib/nnc/plugin/interpreter/ops/Pool.cpp b/contrib/nnc/passes/interpreter/ops/Pool.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Pool.cpp rename to contrib/nnc/passes/interpreter/ops/Pool.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Pool.h b/contrib/nnc/passes/interpreter/ops/Pool.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Pool.h rename to contrib/nnc/passes/interpreter/ops/Pool.h diff --git a/contrib/nnc/plugin/interpreter/ops/Reduce.cpp b/contrib/nnc/passes/interpreter/ops/Reduce.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Reduce.cpp rename to contrib/nnc/passes/interpreter/ops/Reduce.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Reduce.h b/contrib/nnc/passes/interpreter/ops/Reduce.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Reduce.h rename to contrib/nnc/passes/interpreter/ops/Reduce.h diff --git a/contrib/nnc/plugin/interpreter/ops/Reshape.cpp b/contrib/nnc/passes/interpreter/ops/Reshape.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Reshape.cpp rename to contrib/nnc/passes/interpreter/ops/Reshape.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Reshape.h b/contrib/nnc/passes/interpreter/ops/Reshape.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Reshape.h rename to contrib/nnc/passes/interpreter/ops/Reshape.h diff --git a/contrib/nnc/plugin/interpreter/ops/Softmax.cpp b/contrib/nnc/passes/interpreter/ops/Softmax.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Softmax.cpp rename to contrib/nnc/passes/interpreter/ops/Softmax.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/Softmax.h b/contrib/nnc/passes/interpreter/ops/Softmax.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/Softmax.h rename to contrib/nnc/passes/interpreter/ops/Softmax.h diff --git a/contrib/nnc/plugin/interpreter/ops/common.cpp b/contrib/nnc/passes/interpreter/ops/common.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/common.cpp rename to contrib/nnc/passes/interpreter/ops/common.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/common.h b/contrib/nnc/passes/interpreter/ops/common.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/common.h rename to contrib/nnc/passes/interpreter/ops/common.h diff --git a/contrib/nnc/plugin/interpreter/ops/conv_2D.cpp b/contrib/nnc/passes/interpreter/ops/conv_2D.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/conv_2D.cpp rename to contrib/nnc/passes/interpreter/ops/conv_2D.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/conv_2D.h b/contrib/nnc/passes/interpreter/ops/conv_2D.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/conv_2D.h rename to contrib/nnc/passes/interpreter/ops/conv_2D.h diff --git a/contrib/nnc/plugin/interpreter/ops/conv_FFT.cpp b/contrib/nnc/passes/interpreter/ops/conv_FFT.cpp similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/conv_FFT.cpp rename to contrib/nnc/passes/interpreter/ops/conv_FFT.cpp diff --git a/contrib/nnc/plugin/interpreter/ops/conv_FFT.h b/contrib/nnc/passes/interpreter/ops/conv_FFT.h similarity index 100% rename from contrib/nnc/plugin/interpreter/ops/conv_FFT.h rename to contrib/nnc/passes/interpreter/ops/conv_FFT.h diff --git a/contrib/nnc/plugin/soft_backend/base_generator.cpp b/contrib/nnc/passes/soft_backend/BaseGenerator.cpp similarity index 83% rename from contrib/nnc/plugin/soft_backend/base_generator.cpp rename to contrib/nnc/passes/soft_backend/BaseGenerator.cpp index dd08bde..8004e27 100644 --- a/contrib/nnc/plugin/soft_backend/base_generator.cpp +++ b/contrib/nnc/passes/soft_backend/BaseGenerator.cpp @@ -1,10 +1,11 @@ -#include "base_generator.h" +#include "passes/soft_backend/BaseGenerator.h" #include "model_analyzer.h" #include "serializer.h" -#include "support/PluginException.h" #include "core/modelIR/ShapeInference.h" #include "option/Options.h" - +#include "pass/Pass.h" +#include "pass/PassData.h" +#include "pass/PassException.h" #include "param_constants.def" #include @@ -19,6 +20,7 @@ using namespace std; using namespace nncc::contrib; +using namespace nncc::contrib::pass; using namespace nncc::contrib::core::IR::model; namespace nncc @@ -38,7 +40,7 @@ unique_ptr getStream(const string &path) unique_ptr ofs(new ofstream(path)); if (ofs->fail()) { - throw PluginException("Can not open code output file: " + path); + throw PassException("Can not open code output file: " + path); } return ofs; } @@ -48,7 +50,7 @@ void createDir(const string &path) int res = mkdir(path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); if (res != 0 && errno != EEXIST) { - throw PluginException("Failed to create output directory"); + throw PassException("Failed to create output directory"); } } @@ -79,19 +81,21 @@ void BaseCodeGenerator::materializeModelParams(ostream &out, const Serializer &s out.write(header, HEADER_LEN); if (out.fail()) { - throw PluginException("Failed to write model parameters header"); + throw PassException("Failed to write model parameters header"); } auto ¶ms = s.getBuffer(); out.write(params.data(), params.size()); if (out.fail()) { - throw PluginException("Failed to write model Parameters"); + throw PassException("Failed to write model Parameters"); } } -void *BaseCodeGenerator::execute(void *data) +PassData BaseCodeGenerator::run(PassData data) { - Graph *g = reinterpret_cast(data); + auto g = static_cast(data); + assert(g); + // inference shapes core::IR::model::ShapeInference si; g->accept(&si); diff --git a/contrib/nnc/plugin/soft_backend/c_generator.cpp b/contrib/nnc/passes/soft_backend/CGenerator.cpp similarity index 79% rename from contrib/nnc/plugin/soft_backend/c_generator.cpp rename to contrib/nnc/passes/soft_backend/CGenerator.cpp index 55d5e91..9daf35e 100644 --- a/contrib/nnc/plugin/soft_backend/c_generator.cpp +++ b/contrib/nnc/passes/soft_backend/CGenerator.cpp @@ -1,4 +1,4 @@ -#include "c_generator.h" +#include "passes/soft_backend/CGenerator.h" #include "model_analyzer.h" using namespace std; @@ -29,13 +29,13 @@ void CCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, cons // TODO emit C code to out stream } +Pass &CCodeGenerator::getInstance() +{ + static CCodeGenerator instance; + return instance; +} + } // namespace soft } // namespace backend } // namespace contrib } // namespace nncc - -extern "C" nncc::contrib::plugin::Plugin *get_instance() -{ - static nncc::contrib::backend::soft::CCodeGenerator cCodeGenerator; - return &cCodeGenerator; -} diff --git a/contrib/nnc/plugin/soft_backend/CMakeLists.txt b/contrib/nnc/passes/soft_backend/CMakeLists.txt similarity index 75% rename from contrib/nnc/plugin/soft_backend/CMakeLists.txt rename to contrib/nnc/passes/soft_backend/CMakeLists.txt index 22b5998..90ff8df 100644 --- a/contrib/nnc/plugin/soft_backend/CMakeLists.txt +++ b/contrib/nnc/passes/soft_backend/CMakeLists.txt @@ -1,6 +1,6 @@ -set(SOFT_BACKEND_COMMON_SOURCES base_generator.cpp model_analyzer.cpp serializer.cpp) +set(SOFT_BACKEND_COMMON_SOURCES BaseGenerator.cpp model_analyzer.cpp serializer.cpp) set(SOFT_BACKEND_CPP_SOURCES cpp_generator.cpp) -set(SOFT_BACKEND_C_SOURCES c_generator.cpp) +set(SOFT_BACKEND_C_SOURCES CGenerator.cpp) set(DEF_CONV ${NNC_ROOT_SRC_DIR}/utils/def2src.cpp) file(GLOB_RECURSE SOFT_DEF_SOURCES "*.def") @@ -14,12 +14,9 @@ set_property(TARGET soft_backend_common PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(soft_backend_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(soft_backend_common PRIVATE nnc_support) target_link_libraries(soft_backend_common PRIVATE nnc_core) -# This is included because right now common functional is built into nnc_driver -target_link_libraries(soft_backend_common PRIVATE nnc_driver) function(make_soft_backend NAME) add_library(${NAME} SHARED ${ARGN} ${SOFT_GENERATED_SOURCES}) - set_target_properties(${NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NNC_BINARY_DIR}) target_include_directories(${NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(${NAME} PRIVATE soft_backend_common) @@ -27,7 +24,7 @@ function(make_soft_backend NAME) target_link_libraries(${NAME} PRIVATE nnc_core) # install soft backend c++ library - install_nnc_plugin(${NAME}) + install_nnc_library(${NAME}) endfunction(make_soft_backend) make_soft_backend(soft_backend_cpp ${SOFT_BACKEND_CPP_SOURCES}) diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_add_bias.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_add_bias.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_add_bias.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_add_bias.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_capped_relu.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_capped_relu.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_capped_relu.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_capped_relu.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_common_funcs.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_common_funcs.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_common_funcs.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_common_funcs.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_concat.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_concat.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_concat.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_concat.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_conv.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_conv.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_conv.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_conv.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_depthwise_conv.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_depthwise_conv.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_depthwise_conv.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_depthwise_conv.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_fully_connected.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_fully_connected.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_fully_connected.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_fully_connected.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_header_types.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_header_types.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_header_types.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_header_types.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_operations.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_operations.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_pool.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_pool.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_pool.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_pool.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_relu.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_relu.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_relu.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_relu.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/cpp_softmax.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_softmax.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/cpp_softmax.def rename to contrib/nnc/passes/soft_backend/code_snippets/cpp_softmax.def diff --git a/contrib/nnc/plugin/soft_backend/code_snippets/eigen.def b/contrib/nnc/passes/soft_backend/code_snippets/eigen.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/code_snippets/eigen.def rename to contrib/nnc/passes/soft_backend/code_snippets/eigen.def diff --git a/contrib/nnc/plugin/soft_backend/cpp_generator.cpp b/contrib/nnc/passes/soft_backend/cpp_generator.cpp similarity index 97% rename from contrib/nnc/plugin/soft_backend/cpp_generator.cpp rename to contrib/nnc/passes/soft_backend/cpp_generator.cpp index 2ad887f..710069d 100644 --- a/contrib/nnc/plugin/soft_backend/cpp_generator.cpp +++ b/contrib/nnc/passes/soft_backend/cpp_generator.cpp @@ -1,7 +1,6 @@ -#include "cpp_generator.h" +#include "passes/soft_backend/CPPGenerator.h" #include "model_analyzer.h" #include "serializer.h" -#include "support/PluginException.h" #include "option/Options.h" using namespace std; @@ -283,13 +282,13 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co out << "}"; } +Pass &CPPCodeGenerator::getInstance() +{ + static CPPCodeGenerator cppCodeGenerator; + return cppCodeGenerator; +} + } // namespace soft } // namespace backend } // namespace contrib } // namespace nncc - -extern "C" nncc::contrib::plugin::Plugin *get_instance() -{ - static nncc::contrib::backend::soft::CPPCodeGenerator cppCodeGenerator; - return &cppCodeGenerator; -} diff --git a/contrib/nnc/plugin/soft_backend/model_analyzer.cpp b/contrib/nnc/passes/soft_backend/model_analyzer.cpp similarity index 100% rename from contrib/nnc/plugin/soft_backend/model_analyzer.cpp rename to contrib/nnc/passes/soft_backend/model_analyzer.cpp diff --git a/contrib/nnc/plugin/soft_backend/model_analyzer.h b/contrib/nnc/passes/soft_backend/model_analyzer.h similarity index 100% rename from contrib/nnc/plugin/soft_backend/model_analyzer.h rename to contrib/nnc/passes/soft_backend/model_analyzer.h diff --git a/contrib/nnc/plugin/soft_backend/param_constants.def b/contrib/nnc/passes/soft_backend/param_constants.def similarity index 100% rename from contrib/nnc/plugin/soft_backend/param_constants.def rename to contrib/nnc/passes/soft_backend/param_constants.def diff --git a/contrib/nnc/plugin/soft_backend/serializer.cpp b/contrib/nnc/passes/soft_backend/serializer.cpp similarity index 100% rename from contrib/nnc/plugin/soft_backend/serializer.cpp rename to contrib/nnc/passes/soft_backend/serializer.cpp diff --git a/contrib/nnc/plugin/soft_backend/serializer.h b/contrib/nnc/passes/soft_backend/serializer.h similarity index 100% rename from contrib/nnc/plugin/soft_backend/serializer.h rename to contrib/nnc/passes/soft_backend/serializer.h diff --git a/contrib/nnc/plugin/tflite_frontend/CMakeLists.txt b/contrib/nnc/passes/tflite_frontend/CMakeLists.txt similarity index 89% rename from contrib/nnc/plugin/tflite_frontend/CMakeLists.txt rename to contrib/nnc/passes/tflite_frontend/CMakeLists.txt index 959d07a..ee256f3 100644 --- a/contrib/nnc/plugin/tflite_frontend/CMakeLists.txt +++ b/contrib/nnc/passes/tflite_frontend/CMakeLists.txt @@ -27,12 +27,11 @@ set(tflite_importer_sources tflite_walker.cpp tflite_ir_visitor.cpp tflite_op_creator.cpp tflite_v3_importer.cpp - tflite_plugin.cpp) + tflite_frontend.cpp) file(GLOB tflite_importer_headers *.h) set(tflite_import tflite_import) add_library(${tflite_import} SHARED ${tflite_importer_sources} ${tflite_importer_headers}) -set_target_properties(${tflite_import} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NNC_BINARY_DIR}) target_link_libraries(${tflite_import} PUBLIC tflite_schema) target_link_libraries(${tflite_import} PUBLIC flatbuffers) @@ -41,4 +40,4 @@ target_link_libraries(${tflite_import} PUBLIC nnc_support) target_link_libraries(${tflite_import} PUBLIC nnc_core) # install tflite frontend library -install_nnc_plugin(tflite_import) +install_nnc_library(tflite_import) diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema.fbs similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema.fbs rename to contrib/nnc/passes/tflite_frontend/schema/schema.fbs diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema.meta b/contrib/nnc/passes/tflite_frontend/schema/schema.meta similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema.meta rename to contrib/nnc/passes/tflite_frontend/schema/schema.meta diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v0.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema_v0.fbs similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v0.fbs rename to contrib/nnc/passes/tflite_frontend/schema/schema_v0.fbs diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v0.meta b/contrib/nnc/passes/tflite_frontend/schema/schema_v0.meta similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v0.meta rename to contrib/nnc/passes/tflite_frontend/schema/schema_v0.meta diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v1.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema_v1.fbs similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v1.fbs rename to contrib/nnc/passes/tflite_frontend/schema/schema_v1.fbs diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v1.meta b/contrib/nnc/passes/tflite_frontend/schema/schema_v1.meta similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v1.meta rename to contrib/nnc/passes/tflite_frontend/schema/schema_v1.meta diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v2.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema_v2.fbs similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v2.fbs rename to contrib/nnc/passes/tflite_frontend/schema/schema_v2.fbs diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v2.meta b/contrib/nnc/passes/tflite_frontend/schema/schema_v2.meta similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v2.meta rename to contrib/nnc/passes/tflite_frontend/schema/schema_v2.meta diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v3.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema_v3.fbs similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v3.fbs rename to contrib/nnc/passes/tflite_frontend/schema/schema_v3.fbs diff --git a/contrib/nnc/plugin/tflite_frontend/schema/schema_v3.meta b/contrib/nnc/passes/tflite_frontend/schema/schema_v3.meta similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema/schema_v3.meta rename to contrib/nnc/passes/tflite_frontend/schema/schema_v3.meta diff --git a/contrib/nnc/plugin/tflite_frontend/schema_v3.h b/contrib/nnc/passes/tflite_frontend/schema_v3.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/schema_v3.h rename to contrib/nnc/passes/tflite_frontend/schema_v3.h diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_dump_visitor.cpp b/contrib/nnc/passes/tflite_frontend/tflite_dump_visitor.cpp similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_dump_visitor.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_dump_visitor.cpp diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_dump_visitor.h b/contrib/nnc/passes/tflite_frontend/tflite_dump_visitor.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_dump_visitor.h rename to contrib/nnc/passes/tflite_frontend/tflite_dump_visitor.h diff --git a/contrib/nnc/passes/tflite_frontend/tflite_frontend.cpp b/contrib/nnc/passes/tflite_frontend/tflite_frontend.cpp new file mode 100644 index 0000000..ad43b85 --- /dev/null +++ b/contrib/nnc/passes/tflite_frontend/tflite_frontend.cpp @@ -0,0 +1,47 @@ +#include +#include +#include + +#include "pass/Pass.h" +#include "pass/PassException.h" +#include "passes/tflite_frontend/TfliteFrontend.h" +#include "option/Options.h" + +#include "tflite_v3_importer.h" + +using namespace nncc::contrib; +using namespace nncc::contrib::pass; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +Pass &TFLiteFrontend::getInstance() +{ + static TFLiteFrontend instance; + return instance; +} + +PassData TFLiteFrontend::run(PassData data) +{ + nncc::contrib::frontend::tflite::v3::TfliteImporter importer{clopt::inputFile}; + + bool success = importer.import(); + + if (!success) + { + throw PassException("Could not load model: " + clopt::inputFile + "\n"); + } + + return reinterpret_cast(importer.createIR()); +} + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_importer.inline.cpp b/contrib/nnc/passes/tflite_frontend/tflite_importer.inline.cpp similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_importer.inline.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_importer.inline.cpp diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_importer.inline.h b/contrib/nnc/passes/tflite_frontend/tflite_importer.inline.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_importer.inline.h rename to contrib/nnc/passes/tflite_frontend/tflite_importer.inline.h diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_ir_visitor.cpp b/contrib/nnc/passes/tflite_frontend/tflite_ir_visitor.cpp similarity index 95% rename from contrib/nnc/plugin/tflite_frontend/tflite_ir_visitor.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_ir_visitor.cpp index 36617ea..3a664d6 100644 --- a/contrib/nnc/plugin/tflite_frontend/tflite_ir_visitor.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_ir_visitor.cpp @@ -2,17 +2,19 @@ #include #include "schema_v3.h" -#include "support/PluginException.h" +#include "pass/PassException.h" #include "core/modelIR/Shape.h" #include "core/modelIR/Index.h" #include "core/modelIR/IndexRange.h" #include "core/modelIR/TensorUtil.h" #include "core/modelIR/operations/variable_op.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" #include "tflite_ir_visitor.h" #include "tflite_op_creator.h" +using namespace nncc::contrib::pass; + namespace nncc { namespace contrib @@ -101,7 +103,7 @@ void IrVisitor::visit(const Operator *op) outputs = opCreator->createSoftmax(inputs, params, op->builtin_options_as()); break; default: - throw PluginException( + throw PassException( std::string("Encountered unsupported TFLite operator: ") + EnumNamesBuiltinOperator()[opcode]); } @@ -136,8 +138,8 @@ std::vector IrVisitor::createOpInputs(const Operator *op) } catch (const std::out_of_range &e) { - throw PluginException("Found a TFLite operator with an input tensor for which " - "a corresponding Model IR node that outputs it was not created."); + throw PassException("Found a TFLite operator with an input tensor for which " + "a corresponding Model IR node that outputs it was not created."); } return inputsForOp; @@ -212,7 +214,7 @@ std::shared_ptr IrVisitor::createTensor(const Tensor *t, const Buffer type = IrTensor::DTYPE::INT; break; default: - throw PluginException( + throw PassException( std::string("Encountered unsupported tensor type ") + EnumNamesTensorType()[t->type()]); } diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_ir_visitor.h b/contrib/nnc/passes/tflite_frontend/tflite_ir_visitor.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_ir_visitor.h rename to contrib/nnc/passes/tflite_frontend/tflite_ir_visitor.h diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_op_creator.cpp b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp similarity index 96% rename from contrib/nnc/plugin/tflite_frontend/tflite_op_creator.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp index 12168a7..bbba2f6 100644 --- a/contrib/nnc/plugin/tflite_frontend/tflite_op_creator.cpp +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp @@ -10,6 +10,9 @@ #include "core/modelIR/operations/pool_op.h" #include "core/modelIR/operations/bias_add_op.h" #include "core/modelIR/operations/reshape_op.h" +#include "pass/PassException.h" + +using namespace nncc::contrib::pass; namespace nncc { @@ -121,8 +124,8 @@ INode::Ref OpCreator::addFusedActivation(INode::Ref input, ActivationFunctionTyp activation = graph->create("", 6); break; default: - throw PluginException(std::string("Encountered unsupported NN activation type: ") + - EnumNamesActivationFunctionType()[activationType]); + throw PassException(std::string("Encountered unsupported NN activation type: ") + + EnumNamesActivationFunctionType()[activationType]); } assert(input->getOperation()->getNumOutputs() == 1); diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_op_creator.h b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h similarity index 97% rename from contrib/nnc/plugin/tflite_frontend/tflite_op_creator.h rename to contrib/nnc/passes/tflite_frontend/tflite_op_creator.h index b038396..1e75c1a 100644 --- a/contrib/nnc/plugin/tflite_frontend/tflite_op_creator.h +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h @@ -6,7 +6,6 @@ #include #include -#include "support/PluginException.h" #include "core/modelIR/graph.h" #include "core/modelIR/ir_node.h" #include "core/modelIR/TensorVariant.h" @@ -15,7 +14,7 @@ #include "core/modelIR/operations/common.h" #include "schema_v3.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" namespace nncc { diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_v3_importer.cpp b/contrib/nnc/passes/tflite_frontend/tflite_v3_importer.cpp similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_v3_importer.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_v3_importer.cpp diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_v3_importer.h b/contrib/nnc/passes/tflite_frontend/tflite_v3_importer.h similarity index 81% rename from contrib/nnc/plugin/tflite_frontend/tflite_v3_importer.h rename to contrib/nnc/passes/tflite_frontend/tflite_v3_importer.h index 7d255a8..037a9a9 100644 --- a/contrib/nnc/plugin/tflite_frontend/tflite_v3_importer.h +++ b/contrib/nnc/passes/tflite_frontend/tflite_v3_importer.h @@ -5,8 +5,8 @@ #include #include "schema_v3.h" -#include "plugin/common_frontend/nn_importer.h" -#include "plugin/common_frontend/model_allocation.h" +#include "passes/common_frontend/nn_importer.h" +#include "passes/common_frontend/model_allocation.h" namespace nncc { diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_visitor.h b/contrib/nnc/passes/tflite_frontend/tflite_visitor.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_visitor.h rename to contrib/nnc/passes/tflite_frontend/tflite_visitor.h diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_walker.cpp b/contrib/nnc/passes/tflite_frontend/tflite_walker.cpp similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_walker.cpp rename to contrib/nnc/passes/tflite_frontend/tflite_walker.cpp diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_walker.h b/contrib/nnc/passes/tflite_frontend/tflite_walker.h similarity index 100% rename from contrib/nnc/plugin/tflite_frontend/tflite_walker.h rename to contrib/nnc/passes/tflite_frontend/tflite_walker.h diff --git a/contrib/nnc/plugin/caffe_frontend/caffe_plugin.cpp b/contrib/nnc/plugin/caffe_frontend/caffe_plugin.cpp deleted file mode 100644 index 04d8c11..0000000 --- a/contrib/nnc/plugin/caffe_frontend/caffe_plugin.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include -#include - -#include "support/PluginInstance.h" -#include "support/PluginException.h" -#include "option/Options.h" - -#include "caffe_importer.h" - -namespace -{ -using namespace nncc::contrib; -using namespace nncc::contrib::plugin; - -class ImporterPlugin : public FrontendPlugin -{ -public: - ImporterPlugin &operator=(const ImporterPlugin &) = delete; - ImporterPlugin(const ImporterPlugin &) = delete; - - static FrontendPlugin &getInstance(); - void *execute(void *data) override; - -private: - ImporterPlugin() = default; - ~ImporterPlugin() override = default; -}; - -FrontendPlugin &ImporterPlugin::getInstance() -{ - static ImporterPlugin instance; - return instance; -} - -void *ImporterPlugin::execute(void *) -{ - nncc::contrib::frontend::caffe::CaffeImporter importer{clopt::inputFile}; - - bool success = importer.import(); - - if (!success) - { - throw nncc::contrib::PluginException("Could not load model: " + clopt::inputFile + "\n"); - } - - return importer.createIR(); -} - -} // anonymous namespace - -extern "C" Plugin *get_instance() -{ - return &ImporterPlugin::getInstance(); -} diff --git a/contrib/nnc/plugin/common_frontend/shape_helper.cpp b/contrib/nnc/plugin/common_frontend/shape_helper.cpp deleted file mode 100644 index f716095..0000000 --- a/contrib/nnc/plugin/common_frontend/shape_helper.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include - -#include "plugin/common_frontend/shape_helper.h" -#include "support/PluginException.h" - -namespace nncc -{ -namespace contrib -{ -namespace frontend -{ -namespace common -{ - -Shape &ShapeHelper::cutOffBatchDim(Shape &shape) -{ - if (shape.dim(0) != 1) - { - throw PluginException{"While attempting to cut off tensor batch dimension (first one)," - "found that it is not 1. Check the model being imported, if the first" - "dimension of the input is not 1, then it might be not batch, and the" - "code needs some restructuring"}; - } - - for (unsigned int i = 0; i < shape.rank() - 1; ++i) - { - shape.dim(i) = shape.dim(i + 1); - } - shape.resize(shape.rank() - 1); - - return shape; -} - -} // namespace common -} // namespace frontend -} // namespace contrib -} // namespace nncc diff --git a/contrib/nnc/plugin/interpreter/interpreter_plugin.h b/contrib/nnc/plugin/interpreter/interpreter_plugin.h deleted file mode 100644 index 456489a..0000000 --- a/contrib/nnc/plugin/interpreter/interpreter_plugin.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef _NNC_BACKEND_INTERPRETER_PLUGIN_ -#define _NNC_BACKEND_INTERPRETER_PLUGIN_ - -#include - -#include "support/PluginInstance.h" -#include "support/PluginException.h" - -#include "core/modelIR/TensorVariant.h" -#include "core/modelIR/Shape.h" - -namespace nncc -{ -namespace contrib -{ -namespace backend -{ -namespace interpreter -{ -namespace plugin -{ - -using namespace nncc::contrib; -using namespace nncc::contrib::plugin; - -class InterpreterPlugin : public BackendPlugin { - public: - static BackendPlugin &getInstance(); - void *execute(void *data) override; - - virtual ~InterpreterPlugin(); - -private: - nncc::contrib::core::ADT::TensorVariant loadInput(const nncc::contrib::core::data::Shape &); - nncc::contrib::core::ADT::TensorVariant *_out; -}; - -} // namespace plugin -} // namespace interpreter -} // namespace backend -} // namespace contrib -} // namespace nncc - -#endif //_NNC_BACKEND_INTERPRETER_PLUGIN_ diff --git a/contrib/nnc/plugin/tflite_frontend/tflite_plugin.cpp b/contrib/nnc/plugin/tflite_frontend/tflite_plugin.cpp deleted file mode 100644 index cd17f0f..0000000 --- a/contrib/nnc/plugin/tflite_frontend/tflite_plugin.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include -#include - -#include "support/PluginInstance.h" -#include "support/PluginException.h" -#include "option/Options.h" - -#include "tflite_v3_importer.h" - -namespace -{ -using namespace nncc::contrib; -using namespace nncc::contrib::plugin; - -class ImporterPlugin : public FrontendPlugin -{ -public: - ImporterPlugin &operator=(const ImporterPlugin &) = delete; - ImporterPlugin(const ImporterPlugin &) = delete; - - static FrontendPlugin &getInstance(); - void *execute(void *data) override; - -private: - ImporterPlugin() = default; - ~ImporterPlugin() override = default; -}; - -FrontendPlugin &ImporterPlugin::getInstance() -{ - static ImporterPlugin instance; - return instance; -} - -void *ImporterPlugin::execute(void *) -{ - nncc::contrib::frontend::tflite::v3::TfliteImporter importer{clopt::inputFile}; - - bool success = importer.import(); - - if (!success) - { - throw nncc::contrib::PluginException("Could not load model: " + clopt::inputFile + "\n"); - }; - - return importer.createIR(); -} - -} // anonymous namespace - -extern "C" Plugin *get_instance() -{ - return &ImporterPlugin::getInstance(); -} diff --git a/contrib/nnc/support/CLOptionChecker.cpp b/contrib/nnc/support/CLOptionChecker.cpp index bc4a6e7..00d39a2 100644 --- a/contrib/nnc/support/CLOptionChecker.cpp +++ b/contrib/nnc/support/CLOptionChecker.cpp @@ -15,32 +15,6 @@ namespace contrib { namespace clopt { -void checkPluginsPath(const Option &plugin_dir) -{ - auto dir = opendir(plugin_dir.c_str()); - - if (dir) - { - closedir(dir); - return; - } - - auto err = errno; - - switch (err) - { - case ENOENT: - throw BadOption("No such plugins directory"); - case ENOTDIR: - throw BadOption("Value for plugins path is not directory"); - case EACCES: - throw BadOption("Has no permission to open plugins directory"); - default: - throw BadOption("Can not open plugins directory"); - } - -} // checkPluginsPath - void checkInFile(const Option &in_file) { if ( in_file.empty() ) diff --git a/contrib/nnc/support/CMakeLists.txt b/contrib/nnc/support/CMakeLists.txt index fdce6ff..edd06fc 100644 --- a/contrib/nnc/support/CMakeLists.txt +++ b/contrib/nnc/support/CMakeLists.txt @@ -1,12 +1,10 @@ set(SUPPORT_SOURCES CommandLine.cpp CLOptionChecker.cpp - Debug.cpp - PluginManager.cpp - PluginProxy.cpp) + Debug.cpp) add_library(nnc_support SHARED ${SUPPORT_SOURCES}) set_target_properties(nnc_support PROPERTIES LINKER_LANGUAGE CXX) target_link_libraries(nnc_support PRIVATE dl) -install_common_library(nnc_support) +install_nnc_library(nnc_support) diff --git a/contrib/nnc/support/PluginManager.cpp b/contrib/nnc/support/PluginManager.cpp deleted file mode 100644 index 88ce9d5..0000000 --- a/contrib/nnc/support/PluginManager.cpp +++ /dev/null @@ -1,73 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "support/PluginProxy.h" -#include "support/PluginManager.h" -#include "support/PluginInstance.h" -#include "option/Options.h" - -#include "support/PluginException.h" - -//#include "debug.h" -//#define DEBUG_AREA "plugin" - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -PluginManager::PluginManager(const std::string &plugin_path) -{ - _plugin_path = plugin_path; - loadPlugin(); - -} // PluginManager - -PluginManager::~PluginManager() noexcept(false) -{ - // if exception is already being thrown then we can't throw yet - // another exception from destructor because std::terminate will - // be called and terminates program so that we don't unload plugin - // because this function can throw exception - if ( !std::uncaught_exception() ) - { - unloadPlugin(); - } - -} // ~PluginManager - -void PluginManager::loadPlugin() -{ - // TODO: we need to move debug to support directory - //NNC_DEBUG(dbgs() << "Current plugin path is <" << _plugin_path << ">" << std::endl); - _plugin_proxy = PluginProxy::create(_plugin_path); - -} // loadPlugin - -void PluginManager::unloadPlugin() -{ - _plugin_proxy->remove(); - -} // unloadPlugin - -Plugin *PluginManager::getPlugin() -{ - return _plugin_proxy->getPluginInstance(); - -} // getPlugin - -std::ostream &operator<<(std::ostream &st, const PluginManager &pm) -{ - st << pm._plugin_path << std::endl; - return st; -} - -} // namespace plugin -} // namespace contrib -} // namespace nncc diff --git a/contrib/nnc/support/PluginProxy.cpp b/contrib/nnc/support/PluginProxy.cpp deleted file mode 100644 index 192cebc..0000000 --- a/contrib/nnc/support/PluginProxy.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include -#include -#include - -#include "support/PluginProxy.h" -#include "support/shared_library.h" - -namespace nncc -{ -namespace contrib -{ -namespace plugin -{ - -const std::string PluginProxy::getInstanceFuncName = "get_instance"; - -PluginProxy::PluginProxy(const std::string &pluginPath) - : _getInstance(nullptr), _lib(nullptr), _pluginInstance(nullptr) -{ - _lib = std::make_shared>(pluginPath); - - // get plugin name by path - auto i = pluginPath.find_last_of('/'); - _pluginName = ( i == std::string::npos ) ? pluginPath : pluginPath.substr(i + 1); - -} // PluginProxy - -std::shared_ptr PluginProxy::create(const std::string &pluginPath) -{ - auto proxy = std::shared_ptr(new PluginProxy(pluginPath)); - - // get plugin function from shared library - proxy->_getInstance = (get_instance_t)proxy->_lib->findFunc(getInstanceFuncName); - - // call plugin function - proxy->_pluginInstance = proxy->_getInstance(); - - if ( !proxy->_pluginInstance ) - throw PluginException("this shared library is not NNC plugin"); - - return proxy; - -} // create - -void PluginProxy::remove() -{ - _lib->unloadLibrary(); - -} // remove - -Plugin *PluginProxy::getPluginInstance() -{ - assert(_pluginInstance); - return _pluginInstance; - -} // getPluginInstance - -std::ostream &operator<<(std::ostream &st, const PluginProxy &pl) -{ - st << *(pl._lib); - return st; -} - -} // namespace plugin -} // namespace contrib -} // namespace nncc diff --git a/contrib/nnc/tests/interpreter/graph_creator.cpp b/contrib/nnc/tests/interpreter/graph_creator.cpp index 14631d6..110bbe8 100644 --- a/contrib/nnc/tests/interpreter/graph_creator.cpp +++ b/contrib/nnc/tests/interpreter/graph_creator.cpp @@ -14,7 +14,7 @@ #include "core/modelIR/operations/softmax_op.h" #include "core/modelIR/ShapeInference.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" #include "op_info_generated.h" #include "graph_creator.h" diff --git a/contrib/nnc/tests/interpreter/op_info_util.h b/contrib/nnc/tests/interpreter/op_info_util.h index 09d24da..9c23bc1 100644 --- a/contrib/nnc/tests/interpreter/op_info_util.h +++ b/contrib/nnc/tests/interpreter/op_info_util.h @@ -11,7 +11,7 @@ #include "core/modelIR/ShapeInference.h" #include "op_info_generated.h" -#include "plugin/common_frontend/shape_helper.h" +#include "passes/common_frontend/shape_helper.h" #include "graph_creator.h" using namespace nncc::contrib::frontend::common; diff --git a/contrib/nnc/tests/interpreter/op_test.cpp b/contrib/nnc/tests/interpreter/op_test.cpp index 51f25ac..bf9d7b6 100644 --- a/contrib/nnc/tests/interpreter/op_test.cpp +++ b/contrib/nnc/tests/interpreter/op_test.cpp @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "op_info_generated.h" -#include "plugin/interpreter/Interpreter.h" +#include "passes/interpreter/Interpreter.h" #include "core/modelIR/graph.h" #include "op_info_util.h" #include "graph_creator.h" diff --git a/contrib/nnc/tests/soft_backend/compile_cpp.cpp b/contrib/nnc/tests/soft_backend/compile_cpp.cpp index a38f396..beeeed2 100644 --- a/contrib/nnc/tests/soft_backend/compile_cpp.cpp +++ b/contrib/nnc/tests/soft_backend/compile_cpp.cpp @@ -18,7 +18,7 @@ #include "core/modelIR/operations/variable_op.h" #include "core/modelIR/ShapeInference.h" -#include "cpp_generator.h" +#include "passes/soft_backend/CPPGenerator.h" // This header generated and contains array with test_main.def contents #include "test_main.generated.h" @@ -78,7 +78,7 @@ int main(int argc, const char *argv[]) Graph g; fillGraph(g); - nncc::contrib::backend::soft::CPPCodeGenerator().execute(&g); + nncc::contrib::backend::soft::CPPCodeGenerator::getInstance().run(&g); string basePath = outputDir + "/" + artifactName; diff --git a/contrib/nnc/unittests/CMakeLists.txt b/contrib/nnc/unittests/CMakeLists.txt index 22c72f8..59b4541 100644 --- a/contrib/nnc/unittests/CMakeLists.txt +++ b/contrib/nnc/unittests/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(module) +add_subdirectory(pass) add_subdirectory(core) add_subdirectory(soft_backend) add_subdirectory(support) diff --git a/contrib/nnc/unittests/module/CMakeLists.txt b/contrib/nnc/unittests/module/CMakeLists.txt deleted file mode 100644 index 91034cf..0000000 --- a/contrib/nnc/unittests/module/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -file(GLOB_RECURSE TEST_SOURCES "*.cpp") - -# Plugin module tests -add_nncc_test(nnc_module_test ${OPTIONS_SRC} ${TEST_SOURCES} ${HEADERS}) -if (TARGET nnc_module_test) - nncc_target_link_libraries(nnc_module_test nnc_support dl) - - # Set macro in nnc_module_test with some_parser absolute path - target_compile_definitions(nnc_module_test PRIVATE - CMAKE_SAMPLE_PLUGIN_ABS_PATH=$ - CMAKE_SAMPLE_PLUGIN_2_ABS_PATH=$ - CMAKE_SAMPLE_PLUGIN_DIR_ABS_PATH=$) -endif() diff --git a/contrib/nnc/unittests/module/PluginManager.cpp b/contrib/nnc/unittests/module/PluginManager.cpp deleted file mode 100644 index 135037c..0000000 --- a/contrib/nnc/unittests/module/PluginManager.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include - -#include "support/CommandLine.h" -#include "support/PluginManager.h" - -#include "gtest/gtest.h" - -#define STRING(s) _STRING(s) -#define _STRING(s) #s - -using namespace nncc::contrib; -using namespace nncc::contrib::plugin; - - -// Test PluginManager loading with unexisting path -TEST(CONTRIB_NNC, PluginManagerMissingDir) -{ - try - { - PluginManager pluginManager("AAA"); - FAIL(); - } - catch ( const PluginException &e ) - { - - } -} - -// Test PluginManager work with correct configuration -TEST(CONTRIB_NNC, PluginManager) -{ - PluginManager pluginManager(STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); - - // Test operator '<<' - std::ostringstream os; - os << pluginManager; - - std::string plugin(STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH) "\n"); - EXPECT_TRUE(os.str() == plugin ); -} - -// Test PluginManager that destructor doesn't throw exception -TEST(SUPPORT_NNC, verifyPluginManagerForLoadPlugin) -{ - std::string plugin_path{STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)}; - - try - { - // load the library - PluginManager pluginManager(plugin_path); - } - catch ( const PluginException &e ) - { - FAIL(); - } - - try - { - // load the library - PluginManager pluginManager(plugin_path); - - throw PluginException("test exception"); - } - catch (const PluginException &e) - { - ASSERT_EQ(e.what(), "test exception"); - } -} diff --git a/contrib/nnc/unittests/module/PluginProxy.cpp b/contrib/nnc/unittests/module/PluginProxy.cpp deleted file mode 100644 index 2fa5069..0000000 --- a/contrib/nnc/unittests/module/PluginProxy.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "support/PluginProxy.h" -#include "support/PluginException.h" - -#include "gtest/gtest.h" - -#define STRING(s) _STRING(s) -#define _STRING(s) #s - -using namespace nncc::contrib::plugin; -using namespace nncc::contrib; - -void tryCreatePluginProxy(std::string path) -{ - try - { - std::shared_ptr pp = PluginProxy::create(path); - FAIL(); - } - catch ( PluginException &e ) {} -} - -TEST(CONTRIB_NNC, PluginProxy) -{ - // Create PluginProxy from path with '/' and without, to visit both execution pathes - tryCreatePluginProxy("/some/path/pluginName"); - tryCreatePluginProxy("pluginName"); - - // Create PluginProxy with sample plugin - std::shared_ptr pp = PluginProxy::create(STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); - ASSERT_EQ(pp->getPluginPath(), STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); -#ifdef __APPLE__ - ASSERT_EQ(pp->getPluginName(), "libsome_parser.dylib"); -#else /* __APPLE__ */ - ASSERT_EQ(pp->getPluginName(), "libsome_parser.so"); -#endif /* __APPLE__ */ - ASSERT_NE(pp->getPluginInstance(), nullptr); - - // Operator '<<' - std::ostringstream os; - os << *pp; - ASSERT_EQ(os.str(), STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); -} diff --git a/contrib/nnc/unittests/module/shared_library.cpp b/contrib/nnc/unittests/module/shared_library.cpp deleted file mode 100644 index a62e2e8..0000000 --- a/contrib/nnc/unittests/module/shared_library.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "support/shared_library.h" -#include "support/PluginException.h" - -#include "gtest/gtest.h" - -#define STRING(s) _STRING(s) -#define _STRING(s) #s - -using namespace nncc::contrib; -using namespace nncc::contrib::plugin; - -TEST(CONTRIB_NNC, SharedLibrary) -{ - // missing so - missing function - SharedLibrary slMissing("/"); - ASSERT_THROW(slMissing.findFunc("missing_func_name"), PluginException); - - // existing so - missing function - SharedLibrary sl(STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); - ASSERT_THROW(sl.findFunc("missing_func_name"), PluginException); - - // existing so - existing function - typedef int (*fp_t)(); - fp_t getSomeBeef = (fp_t)sl.findFunc("getSomeBeef"); - ASSERT_NE(getSomeBeef, nullptr); - ASSERT_EQ(getSomeBeef(), 0xBEEF); - - // Operator '<<' - std::ostringstream os; - os << sl; - ASSERT_EQ(os.str(), STRING(CMAKE_SAMPLE_PLUGIN_ABS_PATH)); - -} diff --git a/contrib/nnc/unittests/pass/CMakeLists.txt b/contrib/nnc/unittests/pass/CMakeLists.txt new file mode 100644 index 0000000..35fb5d5 --- /dev/null +++ b/contrib/nnc/unittests/pass/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE TEST_SOURCES "*.cpp") + +add_nncc_test(nnc_pass_test ${TEST_SOURCES}) +if (TARGET nnc_pass_test) + nncc_target_link_libraries(nnc_pass_test nnc_support nnc_core) +endif() diff --git a/contrib/nnc/unittests/pass/PassExceptionTest.cpp b/contrib/nnc/unittests/pass/PassExceptionTest.cpp new file mode 100644 index 0000000..86fe4c3 --- /dev/null +++ b/contrib/nnc/unittests/pass/PassExceptionTest.cpp @@ -0,0 +1,37 @@ +#include "pass/PassException.h" + +#include "gtest/gtest.h" + +using namespace nncc::contrib::pass; + +const char *ErrorMsg = "error constructor"; + +void passErr1() { throw PassException(ErrorMsg); } + +void passErr2() +{ + try + { + passErr1(); + } + catch (PassException &e) + { + throw; + } +} + +TEST(CONTRIB_PASS, PassException) +{ + try + { + passErr2(); + } + catch (PassException &e) + { + ASSERT_TRUE(ErrorMsg == e.reason()); + return; + } + + // should not happen + FAIL(); +} diff --git a/contrib/nnc/unittests/pass/PassManagerTest.cpp b/contrib/nnc/unittests/pass/PassManagerTest.cpp new file mode 100644 index 0000000..5a365f2 --- /dev/null +++ b/contrib/nnc/unittests/pass/PassManagerTest.cpp @@ -0,0 +1,57 @@ +#include + +#include "core/modelIR/graph.h" +#include "support/CommandLine.h" +#include "pass/Pass.h" +#include "pass/PassData.h" +#include "pass/PassException.h" + +#include "gtest/gtest.h" + +using namespace nncc::contrib; +using namespace nncc::contrib::pass; +using namespace nncc::contrib::core::IR::model; + +class DummyPass1 : public Pass +{ +public: + PassData run(PassData data) override + { + auto graph = static_cast(data); + + if ( !graph ) + { + throw PassException(); + } + + return graph; + } +}; + +class DummyPass2 : public Pass +{ +public: + PassData run(PassData data) override + { + auto tv = static_cast(data); + + if ( !tv ) + { + throw PassException(); + } + + return nullptr; + } +}; + +TEST(CONTRIB_PASS, PassManager) +{ + DummyPass1 pass1; + DummyPass2 pass2; + + Graph g; + auto res = pass1.run(&g); + ASSERT_NE(static_cast(res), nullptr); + + ASSERT_THROW(pass2.run(res), PassException); +} diff --git a/contrib/nnc/unittests/soft_backend/cpp_operations.cpp b/contrib/nnc/unittests/soft_backend/cpp_operations.cpp index 78a604d..66d4ae2 100644 --- a/contrib/nnc/unittests/soft_backend/cpp_operations.cpp +++ b/contrib/nnc/unittests/soft_backend/cpp_operations.cpp @@ -47,7 +47,7 @@ #include "core/modelIR/ShapeRange.h" #include "core/modelIR/ShapeInference.h" -#include "plugin/interpreter/Interpreter.h" +#include "passes/interpreter/Interpreter.h" #include "gtest/gtest.h" diff --git a/contrib/nnc/unittests/soft_backend/generator.cpp b/contrib/nnc/unittests/soft_backend/generator.cpp index f983e32..d070389 100644 --- a/contrib/nnc/unittests/soft_backend/generator.cpp +++ b/contrib/nnc/unittests/soft_backend/generator.cpp @@ -1,8 +1,6 @@ -#include "cpp_generator.h" +#include "passes/soft_backend/CPPGenerator.h" #include "core/modelIR/operations/relu_op.h" -#include "support/PluginException.h" - #include #include @@ -83,20 +81,14 @@ TEST(Generator, check_generator_call) deleteDir(TEST_DIR); } assert(!isFileExists(TEST_DIR) && "remove output dir"); - { - CPPCodeGenerator gen; - gen.execute(&g); - } + CPPCodeGenerator::getInstance().run(&g); checkOutputExists(BASE_NAME); // test that generator creates output files in existing empty dir deleteFile(BASE_NAME ".h"); deleteFile(BASE_NAME ".cpp"); deleteFile(BASE_NAME ".params"); - { - CPPCodeGenerator gen; - gen.execute(&g); - } + CPPCodeGenerator::getInstance().run(&g); checkOutputExists(BASE_NAME); // test that generator rewrites existing files @@ -105,10 +97,7 @@ TEST(Generator, check_generator_call) int res = stat(BASE_NAME ".h", &sBefore); assert(res == 0); assert(sBefore.st_size == 0); - { - CPPCodeGenerator gen; - gen.execute(&g); - } + CPPCodeGenerator::getInstance().run(&g); res = stat(BASE_NAME ".h", &sAfter); assert(res == 0); diff --git a/contrib/nnc/unittests/support/PluginException.cpp b/contrib/nnc/unittests/support/PluginException.cpp deleted file mode 100644 index d55168b..0000000 --- a/contrib/nnc/unittests/support/PluginException.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "support/PluginException.h" - -#include "gtest/gtest.h" - -using namespace nncc::contrib; - -std::string pluginErrorMsg = "error constructor"; - -void pluginErr1() { throw PluginException(pluginErrorMsg); } - -void pluginErr2() -{ - try - { - pluginErr1(); - } - catch (PluginException &e) - { - throw; - } -} - -TEST(CONTRIB_PLUGIN, PluginException) -{ - try - { - pluginErr2(); - } - catch (PluginException &e) - { - ASSERT_TRUE(pluginErrorMsg == e.what()); - return; - } - - // should not happen - ASSERT_TRUE(false); -} -- 2.7.4