From 88d2864284b28bd44c8419a99337f5377040f379 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vladimir=20Plazun/AI=20Tools=20Lab=20/SRR/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 5 Sep 2019 07:28:40 +0300 Subject: [PATCH] [custom op] Implement custom op registration using public API (#7181) * [custom op] Implement custom op registration using public API Adds public custom operation registration method Signed-off-by: Vladimir Plazun * format fix --- runtimes/include/nnfw_dev.h | 3 +++ runtimes/neurun/frontend/api/nnfw_dev.cc | 13 +++++++++++++ runtimes/neurun/frontend/api/wrapper/nnfw_api.cc | 12 +++++++++++- runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp | 5 +++++ 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/runtimes/include/nnfw_dev.h b/runtimes/include/nnfw_dev.h index b4e62fa..5886377 100644 --- a/runtimes/include/nnfw_dev.h +++ b/runtimes/include/nnfw_dev.h @@ -58,4 +58,7 @@ typedef struct nnfw_custom_eval eval_function; } custom_kernel_registration_info; +NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id, + custom_kernel_registration_info *info); + #endif // __NNFW_DEV_H__ diff --git a/runtimes/neurun/frontend/api/nnfw_dev.cc b/runtimes/neurun/frontend/api/nnfw_dev.cc index 3c8178b..642cb1f 100644 --- a/runtimes/neurun/frontend/api/nnfw_dev.cc +++ b/runtimes/neurun/frontend/api/nnfw_dev.cc @@ -163,3 +163,16 @@ NNFW_STATUS nnfw_output_tensorinfo(nnfw_session *session, uint32_t index, { return session->output_tensorinfo(index, tensor_info); } + +/* + * Register custom operation + * @param session session to register this operation + * @param id operation id + * @param info registration info ( eval function, etc. ) + * @return NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id, + custom_kernel_registration_info *info) +{ + return session->register_custom_operation(id, info->eval_function); +} diff --git a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc index 49474fd..6b8a53d 100644 --- a/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc +++ b/runtimes/neurun/frontend/api/wrapper/nnfw_api.cc @@ -23,7 +23,9 @@ #include #include -nnfw_session::nnfw_session() : _graph{nullptr}, _execution{nullptr} +nnfw_session::nnfw_session() + : _graph{nullptr}, _execution{nullptr}, + _kernel_registry{new neurun::backend::custom::KernelRegistry} { // DO NOTHING } @@ -53,6 +55,7 @@ NNFW_STATUS nnfw_session::load_model_from_file(const char *package_dir) auto model = nnfw::cpp14::make_unique(); _graph = std::make_shared(std::move(model)); + _graph->bindKernelRegistry(_kernel_registry); tflite_loader::Loader loader(*_graph); auto model_file_path = package_dir + std::string("/") + models[0].asString(); // first model loader.loadFromFile(model_file_path.c_str()); @@ -255,3 +258,10 @@ NNFW_STATUS nnfw_session::output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti) } return NNFW_STATUS_NO_ERROR; } + +NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id, + nnfw_custom_eval eval_func) +{ + _kernel_registry->registerKernel(id, eval_func); + return NNFW_STATUS_NO_ERROR; +} diff --git a/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp b/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp index 8efdd77..f616f81 100644 --- a/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp +++ b/runtimes/neurun/frontend/api/wrapper/nnfw_api.hpp @@ -18,9 +18,11 @@ #define __API_NNFW_INTERNAL_HPP__ #include "nnfw.h" +#include "nnfw_dev.h" #include "compiler/Compiler.h" #include "exec/Execution.h" #include "graph/Graph.h" +#include "backend/CustomKernelRegistry.h" struct nnfw_session { @@ -40,9 +42,12 @@ public: NNFW_STATUS input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); NNFW_STATUS output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti); + NNFW_STATUS register_custom_operation(const std::string &id, nnfw_custom_eval eval_func); + private: std::shared_ptr _graph; std::shared_ptr _execution; + std::shared_ptr _kernel_registry; }; #endif // __API_NNFW_INTERNAL_HPP__ -- 2.7.4