nnfw_custom_eval eval_function;
} custom_kernel_registration_info;
+NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id,
+ custom_kernel_registration_info *info);
+
#endif // __NNFW_DEV_H__
{
return session->output_tensorinfo(index, tensor_info);
}
+
+/*
+ * Register custom operation
+ * @param session session to register this operation
+ * @param id operation id
+ * @param info registration info ( eval function, etc. )
+ * @return NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_register_custom_op_info(nnfw_session *session, const char *id,
+ custom_kernel_registration_info *info)
+{
+ return session->register_custom_operation(id, info->eval_function);
+}
#include <limits.h>
#include <stdint.h>
-nnfw_session::nnfw_session() : _graph{nullptr}, _execution{nullptr}
+nnfw_session::nnfw_session()
+ : _graph{nullptr}, _execution{nullptr},
+ _kernel_registry{new neurun::backend::custom::KernelRegistry}
{
// DO NOTHING
}
auto model = nnfw::cpp14::make_unique<neurun::model::Model>();
_graph = std::make_shared<neurun::graph::Graph>(std::move(model));
+ _graph->bindKernelRegistry(_kernel_registry);
tflite_loader::Loader loader(*_graph);
auto model_file_path = package_dir + std::string("/") + models[0].asString(); // first model
loader.loadFromFile(model_file_path.c_str());
}
return NNFW_STATUS_NO_ERROR;
}
+
+NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id,
+ nnfw_custom_eval eval_func)
+{
+ _kernel_registry->registerKernel(id, eval_func);
+ return NNFW_STATUS_NO_ERROR;
+}
#define __API_NNFW_INTERNAL_HPP__
#include "nnfw.h"
+#include "nnfw_dev.h"
#include "compiler/Compiler.h"
#include "exec/Execution.h"
#include "graph/Graph.h"
+#include "backend/CustomKernelRegistry.h"
struct nnfw_session
{
NNFW_STATUS input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
NNFW_STATUS output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS register_custom_operation(const std::string &id, nnfw_custom_eval eval_func);
+
private:
std::shared_ptr<neurun::graph::Graph> _graph;
std::shared_ptr<neurun::exec::Execution> _execution;
+ std::shared_ptr<neurun::backend::custom::KernelRegistry> _kernel_registry;
};
#endif // __API_NNFW_INTERNAL_HPP__