#include "TensorRegistries.h"
#include "backend/ITensor.h"
+
+#ifdef ONERT_TRAIN
+#include "backend/train/TrainableBackendContext.h"
+#endif // ONERT_TRAIN
#include "compiler/LoweredGraph.h"
+#ifdef ONERT_TRAIN
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "exec/train/optimizer/Optimizer.h"
+#endif // ONERT_TRAIN
#include "exec/IExecutors.h"
#include <deque>
namespace compiler
{
+// TODO Change to a better name
+// Bundles the auxiliary inputs that ExecutorFactory::create() needs besides
+// the lowered graph and the executor map. Packing them here lets new inputs
+// be added without touching every create/createXxxExecutor signature.
+struct ExecutorFactoryArgs
+{
+ const util::TracingCtx *tracing_ctx; // non-owning; NOTE(review): nullability not shown here — confirm callers
+ const compiler::CompilerOptions *options; // non-owning pointer to the compile options (was passed by const& before this patch)
+ ir::ModelIndex model_index; // index of the model the created executor belongs to (was the `index` parameter)
+ std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder; // builder for user-supplied custom kernels
+};
+
+// Factory that creates exec::IExecutor instances from lowered graphs. This
+// patch (a) packs tracing ctx / compiler options / model index into
+// ExecutorFactoryArgs, and (b) adds ONERT_TRAIN-guarded overloads for
+// trainable graphs.
class ExecutorFactory
{
public:
public:
exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options,
const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index);
+ const ExecutorFactoryArgs &args);
+
+#ifdef ONERT_TRAIN
+ // TODO Unify create()
+ // Trainable variant: additionally takes the optimizer the created executor
+ // will use for parameter updates — NOTE(review): shared ownership via
+ // shared_ptr; confirm the executor really needs to co-own it.
+ exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
private:
ExecutorFactory();
private:
- static void prepareMigrantTensors(compiler::LoweredGraph &lowered_graph,
+ static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ // Widened from LoweredGraph to the ILoweredGraph interface so the same
+ // helper can serve both inference and trainable lowered graphs.
const backend::BackendContexts &backend_contexts);
static void prepareBuiltinBackend(const TensorRegistries &tensor_regs,
const std::shared_ptr<exec::IExecutors> &executors,
+// NOTE(review): context lines appear elided in this hunk — the declaration of
+// prepareBuiltinBackend above is cut short (missing parameters and ';').
static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
orderBackendContext(const backend::BackendContexts &backend_contexts);
- static exec::IExecutor *createLinearExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index);
- static exec::IExecutor *createDataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index, bool parallel);
+ // Per-kind creators, re-signatured to take ExecutorFactoryArgs (matching
+ // the std::function value type of _map below).
+ static exec::IExecutor *
+ createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args);
+ static exec::IExecutor *
+ createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args, bool parallel);
+#ifdef ONERT_TRAIN
+ // TODO Unify prepareMigrantTensors
+ // Overload of prepareMigrantTensors for trainable backend contexts; differs
+ // from the one above only in the contexts-container type.
+ static void
+ prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts);
+ static exec::IExecutor *
+ createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer);
+#endif // ONERT_TRAIN
private:
std::unordered_map<
- std::string,
- std::function<exec::IExecutor *(
- std::unique_ptr<compiler::LoweredGraph>, const util::TracingCtx *tracing_ctx,
- const compiler::CompilerOptions &options, const std::shared_ptr<exec::IExecutors> &executors,
- const ir::ModelIndex &index)>>
+ std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)>>
+ // Registry of creator callbacks; presumably keyed by the executor-kind
+ // string from CompilerOptions — TODO confirm against the .cc file.
_map;
};