Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] runtime/onert/api/src/nnfw_api_internal.h
index 8e2c2fb..6279176 100644
--- a/runtime/onert/api/src/nnfw_api_internal.h
+++ b/runtime/onert/api/src/nnfw_api_internal.h
@@ -39,7 +39,7 @@ class Execution;
 } // namespace exec
 namespace ir
 {
-class Graph;
+struct IGraph;
 class Model;
 class NNPkg;
 } // namespace ir
@@ -48,6 +48,10 @@ namespace compiler
 struct CompilerArtifact;
 class CompilerOptions;
 } // namespace compiler
+namespace odc
+{
+class QuantizeManager;
+} // namespace odc
 } // namespace onert
 
 struct nnfw_session
@@ -90,11 +94,13 @@ private:
    */
   enum class State
   {
-    INITIALIZED,  //< Session is initialized and nothing has done to it
-    MODEL_LOADED, //< Model is loaded
-    PREPARED,     //< Prepared(compiled) for execution
-    RUNNING,      //< Execution is in progress (only for asynchronous execution)
-    FINISHED_RUN  //< Executed at least once
+    INITIALIZED,       //< Session is initialized and nothing has been done to it
+    MODEL_LOADED,      //< Model is loaded
+    PREPARED,          //< Prepared (compiled) for execution
+    RUNNING,           //< Execution is in progress (only for asynchronous execution)
+    FINISHED_RUN,      //< Executed at least once
+    PREPARED_TRAINING, //< Prepared for training
+    FINISHED_TRAINING  //< Trained at least once
   };
 
 public:
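
A minimal sketch (not the upstream implementation) of how the training entry points added further below are expected to drive the two new states; the isStateModelLoaded() guard is assumed here for illustration, mirroring the existing isState*() helpers:

  // Sketch only, under the assumptions stated above.
  NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)
  {
    if (!isStateModelLoaded()) // assumed guard, analogous to isStateRunning()
      return NNFW_STATUS_INVALID_STATE;
    // ... build training-enabled executors from *info ...
    _state = State::PREPARED_TRAINING;
    return NNFW_STATUS_NO_ERROR;
  }

  NNFW_STATUS nnfw_session::train_run(bool update_weights)
  {
    if (!isStatePreparedOrFinishedTraining())
      return NNFW_STATUS_INVALID_STATE;
    // ... run one step; skip the weight update when update_weights is false ...
    if (update_weights)
      ++_training_step;
    _state = State::FINISHED_TRAINING;
    return NNFW_STATUS_NO_ERROR;
  }
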
@@ -160,8 +166,25 @@ public:
    */
   NNFW_STATUS set_backends_per_operation(const char *backend_settings);
 
+#ifdef ONERT_TRAIN
+  NNFW_STATUS train_prepare(const nnfw_train_info *info);
+  NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+  NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
+  NNFW_STATUS train_set_input(uint32_t index, const void *input,
+                              const nnfw_tensorinfo *input_tensorinfo);
+  NNFW_STATUS train_set_expected(uint32_t index, const void *expected,
+                                 const nnfw_tensorinfo *expected_tensorinfo);
+  NNFW_STATUS train_run(bool update_weights);
+  NNFW_STATUS train_get_loss(uint32_t index, float *loss);
+  NNFW_STATUS train_export_circle(const char *path);
+#endif // ONERT_TRAIN
+
+  NNFW_STATUS set_quantization_type(NNFW_QUANTIZE_TYPE qtype);
+  NNFW_STATUS set_quantized_model_path(const char *path);
+  NNFW_STATUS quantize();
+
 private:
-  const onert::ir::Graph *primary_subgraph();
+  const onert::ir::IGraph *primary_subgraph();
   uint32_t getInputSize();
   uint32_t getOutputSize();
 
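
Taken together, the declarations above suggest a caller-side training loop along these lines; apart from the method names, everything here is an assumption (in particular, whether a null tensorinfo falls back to the model's own shape), and session creation plus model loading are elided:

  // Hypothetical usage sketch of the new training interface.
  float input_batch[1 * 28 * 28];  // placeholder host buffers
  float expected_batch[1 * 10];
  const uint32_t num_steps = 100;

  nnfw_train_info info{}; // optimizer, loss, learning rate, ... (fields not shown in this diff)
  session->train_prepare(&info);

  for (uint32_t step = 0; step < num_steps; ++step)
  {
    session->train_set_input(0, input_batch, nullptr);       // nullptr: assumed to reuse the model's tensorinfo
    session->train_set_expected(0, expected_batch, nullptr);
    session->train_run(/*update_weights=*/true);

    float loss = 0.0f;
    session->train_get_loss(0, &loss); // loss of output index 0 for the last step
  }

  session->train_export_circle("trained.circle"); // persist the updated weights
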
@@ -171,6 +194,11 @@ private:
   bool isStateRunning();
   bool isStateFinishedRun();
   bool isStatePreparedOrFinishedRun();
+#ifdef ONERT_TRAIN
+  bool isStatePreparedTraining();
+  bool isStateFinishedTraining();
+  bool isStatePreparedOrFinishedTraining();
+#endif // ONERT_TRAIN
 
 private:
   State _state{State::INITIALIZED};
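
The three training-state guards are presumably simple comparisons against _state, in line with the existing helpers; a sketch under that assumption:

  bool nnfw_session::isStatePreparedTraining() { return _state == State::PREPARED_TRAINING; }

  bool nnfw_session::isStateFinishedTraining() { return _state == State::FINISHED_TRAINING; }

  bool nnfw_session::isStatePreparedOrFinishedTraining()
  {
    return isStatePreparedTraining() || isStateFinishedTraining();
  }
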
@@ -180,6 +208,10 @@ private:
   std::unique_ptr<onert::exec::Execution> _execution;
   std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry;
   std::vector<std::thread> _threads;
+#ifdef ONERT_TRAIN
+  uint32_t _training_step{0};
+#endif // ONERT_TRAIN
+  std::unique_ptr<onert::odc::QuantizeManager> _quant_manager;
 };
 
 #endif // __API_NNFW_API_INTERNAL_H__
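
For on-device quantization, the new methods together with the _quant_manager member added above point to a flow like the one below; only the three method names come from this diff, while the enumerator name and the internal delegation to onert::odc::QuantizeManager are assumptions:

  // Hypothetical usage sketch (the session is assumed to already have a model loaded).
  session->set_quantization_type(NNFW_QUANTIZE_TYPE_U8_ASYM); // enumerator name assumed
  session->set_quantized_model_path("model.q8.circle");       // output path handed to the quantizer
  session->quantize();                                        // presumably runs QuantizeManager on the loaded model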