/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __API_NNFW_API_INTERNAL_H__
18 #define __API_NNFW_API_INTERNAL_H__
21 #include "nnfw_experimental.h"
23 #include <util/TracingCtx.h>
34 class CustomKernelRegistry;
48 struct CompilerArtifact;
49 class CompilerOptions;
50 } // namespace compiler
53 class QuantizeManager;
/**
 * @brief Enum class to express the session's state
 *
 * State transition diagram (partially reconstructed — parts of the original
 * ASCII diagram were lost in this chunk; confirm against the full file):
 *
 *            +--------------+
 *            |   PREPARED   | --------+
 *            +--------------+  run |  |
 *  +-----> +--------------+       |  | run_async
 *  |       | FINISHED_RUN | <-----+  |
 *  |       +--------------+          |
 *  | await                run_async  |
 *  |       +--------------+          |
 *  +------ |   RUNNING    | <--------+
 *          +--------------+
 */
// NOTE(review): the "enum class State {" opener and closing "};" were lost in
// this chunk; reconstructed from the enumerator list and from the member
// "State _state{State::INITIALIZED};" below.
enum class State
{
  INITIALIZED,       //< Session is initialized and nothing has done to it
  MODEL_LOADED,      //< Model is loaded
  PREPARED,          //< Prepared(compiled) for execution
  RUNNING,           //< Execution is in progress (only for asynchronous execution)
  FINISHED_RUN,      //< Executed at least once
  PREPARED_TRAINING, //< Prepared for training
  FINISHED_TRAINING  //< Trained at least once
};
108 * @brief Factory method. It creates and initialize nnfw_session
110 * @note Use factory instead of constructor to get status
112 static NNFW_STATUS create(nnfw_session **session);
119 NNFW_STATUS load_model_from_nnpackage(const char *package_file_path);
120 NNFW_STATUS prepare();
121 NNFW_STATUS prepare_pipeline(const char *map_file_path);
124 NNFW_STATUS run_async();
127 NNFW_STATUS set_input(uint32_t index, NNFW_TYPE type, const void *buffer, size_t length);
128 NNFW_STATUS set_output(uint32_t index, NNFW_TYPE type, void *buffer, size_t length);
130 NNFW_STATUS input_size(uint32_t *number);
131 NNFW_STATUS output_size(uint32_t *number);
133 NNFW_STATUS set_input_layout(uint32_t index, NNFW_LAYOUT layout);
134 NNFW_STATUS set_output_layout(uint32_t index, NNFW_LAYOUT layout);
136 NNFW_STATUS apply_tensorinfo(uint32_t index, nnfw_tensorinfo ti); // Will be deprecated
137 NNFW_STATUS set_input_tensorinfo(uint32_t index, const nnfw_tensorinfo *ti);
139 NNFW_STATUS input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
140 NNFW_STATUS output_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
142 NNFW_STATUS set_available_backends(const char *backends);
143 NNFW_STATUS set_op_backend(const char *op, const char *backend);
149 NNFW_STATUS set_config(const char *key, const char *value);
150 NNFW_STATUS get_config(const char *key, char *value, size_t value_size);
151 NNFW_STATUS load_circle_from_buffer(uint8_t *buffer, size_t size);
152 NNFW_STATUS load_model_from_modelfile(const char *file_path);
157 NNFW_STATUS push_pipeline_input(std::vector<void *> *inputs, std::vector<uint32_t> *lengths);
158 NNFW_STATUS pop_pipeline_output(std::vector<void *> *outputs);
160 NNFW_STATUS register_custom_operation(const std::string &id, nnfw_custom_eval eval_func);
161 NNFW_STATUS input_tensorindex(const char *tensorname, uint32_t *index);
162 NNFW_STATUS output_tensorindex(const char *tensorname, uint32_t *index);
164 * @brief Set backends with string-encoded mapping from operation index to backend type
167 NNFW_STATUS set_backends_per_operation(const char *backend_settings);
#ifdef ONERT_TRAIN
// NOTE(review): this "#ifdef ONERT_TRAIN" opener was reconstructed from the
// surviving "#endif // ONERT_TRAIN" below — the opener line was lost from
// this chunk; confirm its exact position against the full file.

/// @brief Prepare the loaded model for training using @p info
NNFW_STATUS train_prepare(const nnfw_train_info *info);
/// @brief Query tensor info of training input @p index into @p ti
NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
/// @brief Query tensor info of expected (label) tensor @p index into @p ti
NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
/// @brief Bind @p input as the training input buffer at @p index
NNFW_STATUS train_set_input(uint32_t index, const void *input,
                            const nnfw_tensorinfo *input_tensorinfo);
/// @brief Bind @p expected as the expected (label) buffer at @p index
NNFW_STATUS train_set_expected(uint32_t index, const void *expected,
                               const nnfw_tensorinfo *expected_tensorinfo);
/// @brief Run one training iteration; @p update_weights selects whether weights are updated
NNFW_STATUS train_run(bool update_weights);
/// @brief Get the loss value for output @p index into @p loss
NNFW_STATUS train_get_loss(uint32_t index, float *loss);
/// @brief Export the current (trained) model as a circle file to @p path
NNFW_STATUS train_export_circle(const char *path);
#endif // ONERT_TRAIN
182 NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype);
183 NNFW_STATUS set_quantized_model_path(const char *path);
184 NNFW_STATUS quantize();
187 const onert::ir::IGraph *primary_subgraph();
188 uint32_t getInputSize();
189 uint32_t getOutputSize();
191 bool isStateInitialized();
192 bool isStateModelLoaded();
193 bool isStatePrepared();
194 bool isStateRunning();
195 bool isStateFinishedRun();
196 bool isStatePreparedOrFinishedRun();
198 bool isStatePreparedTraining();
199 bool isStateFinishedTraining();
200 bool isStatePreparedOrFinishedTraining();
201 #endif // ONERT_TRAIN
204 State _state{State::INITIALIZED};
205 std::shared_ptr<onert::ir::NNPkg> _nnpkg;
206 std::vector<std::unique_ptr<onert::compiler::CompilerOptions>> _coptions;
207 std::shared_ptr<onert::compiler::CompilerArtifact> _compiler_artifact;
208 std::unique_ptr<onert::exec::Execution> _execution;
209 std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry;
210 std::vector<std::thread> _threads;
212 uint32_t _training_step{0};
213 #endif // ONERT_TRAIN
214 std::unique_ptr<onert::odc::QuantizeManager> _quant_manager;
217 #endif // __API_NNFW_API_INTERNAL_H__