} // namespace exec
namespace ir
{
-class Graph;
+struct IGraph;
class Model;
class NNPkg;
} // namespace ir
struct CompilerArtifact;
class CompilerOptions;
} // namespace compiler
+namespace odc
+{
+class QuantizeManager;
+} // namespace odc
} // namespace onert
struct nnfw_session
*/
enum class State
{
- INITIALIZED, //< Session is initialized and nothing has done to it
- MODEL_LOADED, //< Model is loaded
- PREPARED, //< Prepared(compiled) for execution
- RUNNING, //< Execution is in progress (only for asynchronous execution)
- FINISHED_RUN //< Executed at least once
+ INITIALIZED, //< Session is initialized and nothing has done to it
+ MODEL_LOADED, //< Model is loaded
+ PREPARED, //< Prepared(compiled) for execution
+ RUNNING, //< Execution is in progress (only for asynchronous execution)
+ FINISHED_RUN, //< Executed at least once
+ PREPARED_TRAINING, //< Prepared for training
+ FINISHED_TRAINING //< Trained at least once
};
public:
*/
NNFW_STATUS set_backends_per_operation(const char *backend_settings);
+#ifdef ONERT_TRAIN
+ NNFW_STATUS train_prepare(const nnfw_train_info *info);
+ NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+ NNFW_STATUS train_set_input(uint32_t index, const void *input,
+ const nnfw_tensorinfo *input_tensorinfo);
+ NNFW_STATUS train_set_expected(uint32_t index, const void *expected,
+ const nnfw_tensorinfo *expected_tensorinfo);
+ NNFW_STATUS train_run(bool update_weights);
+ NNFW_STATUS train_get_loss(uint32_t index, float *loss);
+ NNFW_STATUS train_export_circle(const char *path);
+#endif // ONERT_TRAIN
+
+ NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype);
+ NNFW_STATUS set_quantized_model_path(const char *path);
+ NNFW_STATUS quantize();
+
private:
- const onert::ir::Graph *primary_subgraph();
+ const onert::ir::IGraph *primary_subgraph();
uint32_t getInputSize();
uint32_t getOutputSize();
bool isStateRunning();
bool isStateFinishedRun();
bool isStatePreparedOrFinishedRun();
+#ifdef ONERT_TRAIN
+ bool isStatePreparedTraining();
+ bool isStateFinishedTraining();
+ bool isStatePreparedOrFinishedTraining();
+#endif // ONERT_TRAIN
private:
State _state{State::INITIALIZED};
std::unique_ptr<onert::exec::Execution> _execution;
std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry;
std::vector<std::thread> _threads;
+#ifdef ONERT_TRAIN
+ uint32_t _training_step{0};
+#endif // ONERT_TRAIN
+ std::unique_ptr<onert::odc::QuantizeManager> _quant_manager;
};
#endif // __API_NNFW_API_INTERNAL_H__