/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnfw_api_internal.h"
#include "CustomKernelRegistry.h"
#include "compiler/Compiler.h"
#include "util/ConfigSource.h"
#include "exec/Execution.h"
#include "circle_loader.h"
#include "tflite_loader.h"
#include "json/json.h"
#include "ir/OpCode.h"

#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include <dirent.h>

#include <util/ConfigSource.h>
#include <misc/string_helpers.h>
/*
 * API does not accept string argument longer than max length below
 */
#define MAX_BACKEND_NAME_LENGTH 32
#define MAX_OP_NAME_LENGTH 64
#define MAX_PATH_LENGTH 1024
// Is null-terminating in length ?
// Returns true iff a '\0' appears within the first `length` bytes of `str`,
// i.e. the string fits inside the API's maximum-length contract.
static bool null_terminating(const char *str, uint32_t length)
{
  for (uint32_t i = 0; i < length; i++)
  {
    if (*(str + i) == '\0')
    {
      return true;
    }
  }
  // No terminator found within the allowed window: treat as too long / invalid.
  return false;
}
54 static onert::ir::Layout convertLayout(NNFW_LAYOUT layout)
56 if (layout == NNFW_LAYOUT_CHANNELS_LAST)
58 return onert::ir::Layout::NHWC;
60 else if (layout == NNFW_LAYOUT_CHANNELS_FIRST)
62 return onert::ir::Layout::NCHW;
64 return onert::ir::Layout::UNKNOWN;
67 nnfw_session::nnfw_session()
68 : _subgraphs{nullptr}, _execution{nullptr},
69 _kernel_registry{std::make_shared<onert::frontend::custom::KernelRegistry>()}
74 nnfw_session::~nnfw_session() = default;
76 NNFW_STATUS nnfw_session::load_circle_from_buffer(uint8_t *buffer, size_t size)
78 if (!isStateInitialized())
79 return NNFW_STATUS_INVALID_STATE;
82 return NNFW_STATUS_UNEXPECTED_NULL;
85 return NNFW_STATUS_ERROR;
87 _subgraphs = onert::circle_loader::loadModel(buffer, size);
88 _compiler = std::make_unique<onert::compiler::Compiler>(_subgraphs);
90 _state = State::MODEL_LOADED;
91 return NNFW_STATUS_NO_ERROR;
94 NNFW_STATUS nnfw_session::load_model_from_file(const char *package_dir)
96 if (!isStateInitialized())
97 return NNFW_STATUS_INVALID_STATE;
101 std::cerr << "package_dir is null." << std::endl;
102 return NNFW_STATUS_UNEXPECTED_NULL;
105 if (!null_terminating(package_dir, MAX_PATH_LENGTH))
107 std::cerr << "nnpackage path is too long" << std::endl;
108 return NNFW_STATUS_ERROR;
111 // TODO : add support for zipped package file load
113 if (!(dir = opendir(package_dir)))
115 std::cerr << "invalid nnpackge directory: " << package_dir << std::endl;
116 return NNFW_STATUS_ERROR;
122 std::string manifest_file_name(package_dir);
123 manifest_file_name += "/metadata/MANIFEST";
124 std::ifstream mfs(manifest_file_name);
126 // extract the filename of the first(index 0) model
127 // e.g. In MANIFEST file, { "models" : [ "firstmodel.tflite", "2nd.tflite" ] }
130 const Json::Value &models = root["models"];
131 const Json::Value &model_types = root["model-types"];
133 auto model_file_path = package_dir + std::string("/") + models[0].asString(); // first model
134 auto model_type = model_types[0].asString(); // first model's type
135 if (model_type == "tflite")
137 _subgraphs = onert::tflite_loader::loadModel(model_file_path.c_str());
139 else if (model_type == "circle")
141 _subgraphs = onert::circle_loader::loadModel(model_file_path.c_str());
145 std::cerr << "Unsupported model type in MANIFEST" << std::endl;
146 return NNFW_STATUS_ERROR;
148 _subgraphs->primary()->bindKernelBuilder(_kernel_registry->getBuilder());
150 catch (const std::exception &e)
152 std::cerr << "Error during model loading : " << e.what() << std::endl;
153 return NNFW_STATUS_ERROR;
156 _compiler = std::make_unique<onert::compiler::Compiler>(_subgraphs);
158 _state = State::MODEL_LOADED;
159 return NNFW_STATUS_NO_ERROR;
162 NNFW_STATUS nnfw_session::prepare()
164 // NOTE. If users want to run prepare() more than one time, this could be removed.
165 if (!isStateModelLoaded())
167 std::cerr << "Error during model prepare : ";
168 if (isStateInitialized())
170 std::cerr << "prepare should be run once";
174 std::cerr << "invalid state";
176 std::cerr << std::endl;
177 return NNFW_STATUS_INVALID_STATE;
180 if (!_subgraphs || !primary_subgraph() || primary_subgraph()->isBuildingPhase())
182 std::cerr << "Error during model prepare : "
183 << "prepare should be run after load_model" << std::endl;
184 return NNFW_STATUS_ERROR;
190 std::shared_ptr<onert::exec::ExecutorMap> executors = _compiler->compile();
191 _execution = std::make_shared<onert::exec::Execution>(executors);
193 catch (const std::exception &e)
195 std::cerr << "Error during model prepare : " << e.what() << std::endl;
196 return NNFW_STATUS_ERROR;
199 _state = State::PREPARED;
200 return NNFW_STATUS_NO_ERROR;
203 NNFW_STATUS nnfw_session::run()
205 if (!isStatePreparedOrFinishedRun())
207 std::cerr << "Error during nnfw_session::run : "
208 << "run should be run after prepare" << std::endl;
209 return NNFW_STATUS_INVALID_STATE;
214 _execution->execute();
216 catch (const std::exception &e)
218 std::cerr << "Error during nnfw_session::run : " << e.what() << std::endl;
219 return NNFW_STATUS_ERROR;
222 _state = State::FINISHED_RUN;
223 return NNFW_STATUS_NO_ERROR;
226 NNFW_STATUS nnfw_session::run_async()
228 if (!isStatePreparedOrFinishedRun())
230 std::cerr << "Error during nnfw_session::run_async : "
231 << "run_async should be run after prepare" << std::endl;
232 return NNFW_STATUS_INVALID_STATE;
235 _execution->startExecute();
237 _state = State::RUNNING;
238 return NNFW_STATUS_NO_ERROR;
241 NNFW_STATUS nnfw_session::await()
243 if (!isStateRunning())
245 std::cerr << "Error during nnfw_session::run_await : "
246 << "run_await should be run after run_async" << std::endl;
247 return NNFW_STATUS_ERROR;
250 _execution->waitFinish();
252 _state = State::FINISHED_RUN;
253 return NNFW_STATUS_NO_ERROR;
256 NNFW_STATUS nnfw_session::set_input(uint32_t index, NNFW_TYPE /*type*/, const void *buffer,
259 if (!isStatePreparedOrFinishedRun())
261 std::cerr << "Error during nnfw_session::set_input : invalid state" << std::endl;
262 return NNFW_STATUS_INVALID_STATE;
265 if (!buffer && length != 0)
268 << "Error during nnfw_session::set_input : given buffer is NULL but the length is not 0"
270 return NNFW_STATUS_ERROR;
275 _execution->setInput(onert::ir::IOIndex(index), buffer, length);
277 catch (const std::exception &e)
279 std::cerr << "Error during nnfw_session::set_input : " << e.what() << std::endl;
280 return NNFW_STATUS_ERROR;
282 return NNFW_STATUS_NO_ERROR;
285 NNFW_STATUS nnfw_session::set_output(uint32_t index, NNFW_TYPE /*type*/, void *buffer,
288 if (!isStatePreparedOrFinishedRun())
290 std::cerr << "Error during nnfw_session::set_output : invalid state" << std::endl;
291 return NNFW_STATUS_INVALID_STATE;
294 if (!buffer && length != 0)
297 << "Error during nnfw_session::set_output : given buffer is NULL but the length is not 0"
299 return NNFW_STATUS_ERROR;
304 _execution->setOutput(onert::ir::IOIndex(index), buffer, length);
306 catch (const std::exception &e)
308 std::cerr << "Error during nnfw_session::set_output : " << e.what() << std::endl;
309 return NNFW_STATUS_ERROR;
311 return NNFW_STATUS_NO_ERROR;
314 NNFW_STATUS nnfw_session::input_size(uint32_t *number)
316 if (isStateInitialized()) // Model is not loaded
317 return NNFW_STATUS_INVALID_STATE;
321 if (number == nullptr)
323 std::cerr << "Error during nnfw_session::input_size, number is null pointer." << std::endl;
324 return NNFW_STATUS_UNEXPECTED_NULL;
326 *number = primary_subgraph()->getInputs().size();
328 catch (const std::exception &e)
330 std::cerr << "Error during nnfw_session::input_size : " << e.what() << std::endl;
331 return NNFW_STATUS_ERROR;
333 return NNFW_STATUS_NO_ERROR;
336 NNFW_STATUS nnfw_session::output_size(uint32_t *number)
338 if (isStateInitialized()) // Model is not loaded
339 return NNFW_STATUS_INVALID_STATE;
343 if (number == nullptr)
345 std::cerr << "Error during nnfw_session::output_size, number is null pointer." << std::endl;
346 return NNFW_STATUS_UNEXPECTED_NULL;
348 *number = primary_subgraph()->getOutputs().size();
350 catch (const std::exception &e)
352 std::cerr << "Error during nnfw_session::output_size" << e.what() << std::endl;
353 return NNFW_STATUS_ERROR;
355 return NNFW_STATUS_NO_ERROR;
358 NNFW_STATUS nnfw_session::set_input_layout(uint32_t index, NNFW_LAYOUT layout)
362 if (layout != NNFW_LAYOUT_NONE && layout != NNFW_LAYOUT_CHANNELS_FIRST &&
363 layout != NNFW_LAYOUT_CHANNELS_LAST)
365 std::cerr << "Error during nnfw_session::set_input_layout, not supported layout" << std::endl;
366 return NNFW_STATUS_ERROR;
368 _execution->setInputLayout(onert::ir::IOIndex(index), convertLayout(layout));
370 catch (const std::exception &e)
372 std::cerr << "Error during nnfw_session::set_input_layout : " << e.what() << std::endl;
373 return NNFW_STATUS_ERROR;
375 return NNFW_STATUS_NO_ERROR;
378 NNFW_STATUS nnfw_session::set_output_layout(uint32_t index, NNFW_LAYOUT layout)
382 if (layout != NNFW_LAYOUT_NONE && layout != NNFW_LAYOUT_CHANNELS_FIRST &&
383 layout != NNFW_LAYOUT_CHANNELS_LAST)
385 std::cerr << "Error during nnfw_session::set_output_layout, not supported layout"
387 return NNFW_STATUS_ERROR;
389 _execution->setOutputLayout(onert::ir::IOIndex(index), convertLayout(layout));
391 catch (const std::exception &e)
393 std::cerr << "Error during nnfw_session::set_output_layout : " << e.what() << std::endl;
394 return NNFW_STATUS_ERROR;
396 return NNFW_STATUS_NO_ERROR;
399 static NNFW_TYPE datatype_to_nnfw_dtype(onert::ir::DataType dt)
401 using onert::ir::DataType;
404 case DataType::FLOAT32:
405 return NNFW_TYPE_TENSOR_FLOAT32;
406 case DataType::INT32:
407 return NNFW_TYPE_TENSOR_INT32;
408 case DataType::QUANT_UINT8_ASYMM:
409 return NNFW_TYPE_TENSOR_QUANT8_ASYMM;
410 case DataType::BOOL8:
411 return NNFW_TYPE_TENSOR_BOOL;
412 case DataType::UINT8:
413 return NNFW_TYPE_TENSOR_UINT8;
414 case DataType::INT64:
415 return NNFW_TYPE_TENSOR_INT64;
416 case DataType::UINT32:
417 case DataType::QUANT_INT8_SYMM:
419 throw std::runtime_error("Error: Model has type that runtime API does not support.");
423 NNFW_STATUS nnfw_session::apply_tensorinfo(uint32_t index, nnfw_tensorinfo ti)
427 if (isStateInitialized())
429 std::cerr << "Error during set_input_tensorinfo : should be run after load_model"
431 return NNFW_STATUS_INVALID_STATE;
434 if (ti.rank <= 0 || ti.rank > NNFW_MAX_RANK)
436 std::cerr << "unsupported rank: " << ti.rank << std::endl;
437 return NNFW_STATUS_ERROR;
440 for (int32_t i = 0; i < ti.rank; ++i)
444 std::cerr << "dim must be positive integer but was " << ti.dims[i] << std::endl;
445 return NNFW_STATUS_ERROR;
450 if (!isStatePreparedOrFinishedRun())
452 // In this case, if we apply input shape in primary_subgraph, it will propagate after
453 // compilation and excution
454 auto ind = primary_subgraph()->getInputs().at(index);
455 auto &input = primary_subgraph()->operands().at(ind);
457 onert::ir::Shape new_shape(ti.rank);
458 for (int32_t i = 0; i < ti.rank; i++)
459 new_shape.dim(i) = ti.dims[i];
461 // overwrite input shape with the shape from ti
462 input.info().shape(new_shape);
464 else // when called after nnfw_session::prepare()
466 onert::ir::Shape new_shape(ti.rank);
467 for (int32_t i = 0; i < ti.rank; i++)
468 new_shape.dim(i) = ti.dims[i];
470 _execution->changeInputShape(onert::ir::IOIndex(index), new_shape);
473 return NNFW_STATUS_NO_ERROR;
476 NNFW_STATUS nnfw_session::set_input_tensorinfo(uint32_t index, const nnfw_tensorinfo *ti)
478 nnfw_tensorinfo ti_copy = *ti;
479 return apply_tensorinfo(index, ti_copy);
482 NNFW_STATUS nnfw_session::input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti)
484 if (isStateInitialized())
485 return NNFW_STATUS_INVALID_STATE;
491 std::cerr << "Error during nnfw_session::input_tensorinfo, tensorinfo is null pointer."
493 return NNFW_STATUS_UNEXPECTED_NULL;
495 if (index >= primary_subgraph()->getInputs().size())
497 std::cerr << "Error during nnfw_session::input_tensorinfo, index is out of range."
499 return NNFW_STATUS_ERROR;
501 auto opidx = primary_subgraph()->getInputs().at(index);
502 auto shape = primary_subgraph()->operands().at(opidx).shape();
503 if (isStatePreparedOrFinishedRun())
504 shape = _execution->getInputShape(onert::ir::IOIndex{index});
505 ti->rank = shape.rank();
506 for (int j = 0; j < ti->rank; ++j)
508 ti->dims[j] = shape.dim(j);
510 ti->dtype = datatype_to_nnfw_dtype(primary_subgraph()->operands().at(opidx).typeInfo().type());
512 catch (const std::exception &e)
514 std::cerr << "Error during nnfw_session::input_tensorinfo : " << e.what() << std::endl;
515 return NNFW_STATUS_ERROR;
517 return NNFW_STATUS_NO_ERROR;
520 NNFW_STATUS nnfw_session::output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti)
522 if (isStateInitialized())
523 return NNFW_STATUS_INVALID_STATE;
527 std::cerr << "Error during nnfw_session::output_tensorinfo, tensorinfo is null pointer."
529 return NNFW_STATUS_UNEXPECTED_NULL;
532 if (index >= primary_subgraph()->getOutputs().size())
534 std::cerr << "Error during nnfw_session::output_tensorinfo, index is out of range."
536 return NNFW_STATUS_ERROR;
541 auto opidx = primary_subgraph()->getOutputs().at(index);
542 auto shape = primary_subgraph()->operands().at(opidx).shape();
543 // If it is called after `nnfw_run` then get the shape from Execution, not from the graph
544 if (isStateFinishedRun())
545 shape = _execution->getOutputShape(onert::ir::IOIndex{index});
546 ti->rank = shape.rank();
547 for (int j = 0; j < ti->rank; ++j)
549 ti->dims[j] = shape.dim(j);
551 ti->dtype = datatype_to_nnfw_dtype(primary_subgraph()->operands().at(opidx).typeInfo().type());
553 catch (const std::exception &e)
555 std::cerr << "Error during nnfw_session::output_tensorinfo : " << e.what() << std::endl;
556 return NNFW_STATUS_ERROR;
559 return NNFW_STATUS_NO_ERROR;
561 NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id,
562 nnfw_custom_eval eval_func)
564 _kernel_registry->registerKernel(id, eval_func);
565 return NNFW_STATUS_NO_ERROR;
568 static std::string get_op_backend_string(std::string op)
570 #define MAP_MACRO(CircleName, OneRTName) {#CircleName, #OneRTName},
572 static std::unordered_map<std::string, std::string> operation_map = {
578 auto n = operation_map.find(op);
580 if (n == operation_map.end())
582 // this return value is handled by a caller to return error code
583 return std::string("");
591 NNFW_STATUS nnfw_session::set_available_backends(const char *backends)
593 if (!isStateModelLoaded())
594 return NNFW_STATUS_INVALID_STATE;
599 return NNFW_STATUS_UNEXPECTED_NULL;
600 if (null_terminating(backends, MAX_BACKEND_NAME_LENGTH) == false)
601 return NNFW_STATUS_ERROR;
603 auto &options = _compiler->options();
605 using namespace onert::util;
607 options.backend_list = nnfw::misc::split(std::string{backends}, ';');
609 catch (const std::exception &e)
611 std::cerr << "Error during nnfw_session::set_available_backends : " << e.what() << std::endl;
612 return NNFW_STATUS_ERROR;
614 return NNFW_STATUS_NO_ERROR;
617 NNFW_STATUS nnfw_session::set_op_backend(const char *op, const char *backend)
619 if (!isStateModelLoaded())
620 return NNFW_STATUS_INVALID_STATE;
625 return NNFW_STATUS_UNEXPECTED_NULL;
626 if (!null_terminating(op, MAX_OP_NAME_LENGTH) ||
627 !null_terminating(backend, MAX_BACKEND_NAME_LENGTH))
628 return NNFW_STATUS_ERROR;
630 auto key = get_op_backend_string(op);
634 return NNFW_STATUS_ERROR;
637 auto &opcode_to_backend = _compiler->options().manual_scheduler_options.opcode_to_backend;
638 opcode_to_backend.emplace(onert::ir::toOpCode(key), backend);
640 catch (const std::exception &e)
642 std::cerr << "Error during nnfw_session::set_op_backend : " << e.what() << std::endl;
643 return NNFW_STATUS_ERROR;
645 return NNFW_STATUS_NO_ERROR;
648 NNFW_STATUS nnfw_session::set_config(const char *key, const char *value)
650 if (!isStateModelLoaded())
651 return NNFW_STATUS_INVALID_STATE;
654 return NNFW_STATUS_UNEXPECTED_NULL;
656 auto &options = _compiler->options();
658 using namespace onert::util;
660 const std::string skey = key;
662 if (skey == config::TRACE_FILEPATH)
664 options.trace_filepath = value;
666 else if (skey == config::GRAPH_DOT_DUMP)
668 options.graph_dump_level = toInt(value);
670 else if (skey == config::OP_SEQ_MAX_NODE)
672 options.op_seq_max_node = toInt(value);
674 else if (skey == config::EXECUTOR)
676 options.executor = value;
678 else if (skey == config::OP_BACKEND_ALLOPS)
680 options.manual_scheduler_options.backend_for_all = value;
682 else if (skey == config::USE_SCHEDULER)
684 options.he_scheduler = toBool(value);
686 else if (skey == config::PROFILING_MODE)
688 options.he_profiling_mode = toBool(value);
690 else if (skey == config::DISABLE_COMPILE)
692 options.disable_compile = toBool(value);
696 return NNFW_STATUS_ERROR;
698 return NNFW_STATUS_NO_ERROR;
701 onert::ir::Graph *nnfw_session::primary_subgraph()
706 return _subgraphs->primary().get();
711 // TODO Remove const_cast
712 // We assumed the graph will not change after compilation, but shape could change
713 return const_cast<onert::ir::Graph *>(&_execution->primary_subgraph());
717 NNFW_STATUS nnfw_session::get_config(const char *key, char *value, size_t value_size)
719 if (!isStateModelLoaded())
720 return NNFW_STATUS_INVALID_STATE;
723 return NNFW_STATUS_UNEXPECTED_NULL;
725 auto &options = _compiler->options();
727 auto check_boundary = [](size_t dest_size, std::string &src) {
728 if (dest_size < src.length() + 1 /* for '\0' */)
730 std::cerr << "buffer is small to copy config value." << std::endl;
736 if (key == onert::util::config::BACKENDS)
738 if (options.backend_list.size() == 0)
739 return NNFW_STATUS_NO_ERROR; // no setting backend is not an error of get_config_str()
741 auto str = nnfw::misc::join(options.backend_list.begin(), options.backend_list.end(), ";");
743 if (!check_boundary(value_size, str))
744 return NNFW_STATUS_ERROR;
746 strncpy(value, str.c_str(), value_size);
748 else if (key == onert::util::config::EXECUTOR)
750 if (!check_boundary(value_size, options.executor))
751 return NNFW_STATUS_ERROR;
753 strncpy(value, options.executor.c_str(), options.executor.length());
757 return NNFW_STATUS_ERROR;
760 return NNFW_STATUS_NO_ERROR;
763 bool nnfw_session::isStateInitialized()
765 if (_state == State::INITIALIZED)
778 bool nnfw_session::isStateModelLoaded()
780 if (_state == State::MODEL_LOADED)
785 assert(!primary_subgraph()->isBuildingPhase());
794 bool nnfw_session::isStatePrepared()
796 if (_state == State::PREPARED)
801 assert(!primary_subgraph()->isBuildingPhase());
810 bool nnfw_session::isStateRunning()
812 if (_state == State::RUNNING)
817 assert(!primary_subgraph()->isBuildingPhase());
823 bool nnfw_session::isStateFinishedRun()
825 if (_state == State::FINISHED_RUN)
830 assert(!primary_subgraph()->isBuildingPhase());
839 bool nnfw_session::isStatePreparedOrFinishedRun()
841 return isStatePrepared() || isStateFinishedRun();