Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / ExecutorFactory.h
index f8f9890..cc621bc 100644 (file)
 #include "TensorRegistries.h"
 
 #include "backend/ITensor.h"
+
+#ifdef ONERT_TRAIN
+#include "backend/train/TrainableBackendContext.h"
+#endif // ONERT_TRAIN
 #include "compiler/LoweredGraph.h"
+#ifdef ONERT_TRAIN
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "exec/train/optimizer/Optimizer.h"
+#endif // ONERT_TRAIN
 #include "exec/IExecutors.h"
 
 #include <deque>
@@ -31,6 +39,15 @@ namespace onert
 namespace compiler
 {
 
+// TODO Change to a better name
+struct ExecutorFactoryArgs
+{
+  const util::TracingCtx *tracing_ctx;
+  const compiler::CompilerOptions *options;
+  ir::ModelIndex model_index;
+  std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder;
+};
+
 class ExecutorFactory
 {
 public:
@@ -38,16 +55,22 @@ public:
 
 public:
   exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
-                          const util::TracingCtx *tracing_ctx,
-                          const compiler::CompilerOptions &options,
                           const std::shared_ptr<exec::IExecutors> &executors,
-                          const ir::ModelIndex &index);
+                          const ExecutorFactoryArgs &args);
+
+#ifdef ONERT_TRAIN
+  // TODO Unify create()
+  exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+                          const std::shared_ptr<exec::IExecutors> &executors,
+                          const ExecutorFactoryArgs &args,
+                          const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
 
 private:
   ExecutorFactory();
 
 private:
-  static void prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
+  static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
                                     const backend::BackendContexts &backend_contexts);
   static void prepareBuiltinBackend(const TensorRegistries &tensor_regs,
                                     const std::shared_ptr<exec::IExecutors> &executors,
@@ -56,22 +79,31 @@ private:
   static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
   orderBackendContext(const backend::BackendContexts &backend_contexts);
 
-  static exec::IExecutor *createLinearExecutor(
-    std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
-    const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
-    const ir::ModelIndex &index);
-  static exec::IExecutor *createDataflowExecutor(
-    std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
-    const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
-    const ir::ModelIndex &index, bool parallel);
+  static exec::IExecutor *
+  createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+                       const std::shared_ptr<exec::IExecutors> &executors,
+                       const ExecutorFactoryArgs &args);
+  static exec::IExecutor *
+  createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+                         const std::shared_ptr<exec::IExecutors> &executors,
+                         const ExecutorFactoryArgs &args, bool parallel);
+#ifdef ONERT_TRAIN
+  // TODO Unify prepareMigrantTensors
+  static void
+  prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+                        const backend::train::TrainableBackendContexts &backend_contexts);
+  static exec::IExecutor *
+  createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+                          const std::shared_ptr<exec::IExecutors> &executors,
+                          const ExecutorFactoryArgs &args,
+                          const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
 
 private:
   std::unordered_map<
-    std::string,
-    std::function<exec::IExecutor *(
-      std::unique_ptr<compiler::LoweredGraph>, const util::TracingCtx *tracing_ctx,
-      const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
-      const ir::ModelIndex &index)>>
+    std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>,
+                                                 const std::shared_ptr<exec::IExecutors> &executors,
+                                                 const ExecutorFactoryArgs &args)>>
     _map;
 };