Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / train / TrainableExecutor.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
18 #define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
19
20 #include "exec/IExecutor.h"
21
22 #include "../ExecutionObservee.h"
23 #include "../../compiler/train/TensorRegistries.h"
24
25 #include "backend/train/TrainableBackendContext.h"
26 #include "compiler/train/TrainableCodeMap.h"
27 #include "compiler/train/LoweredTrainableGraph.h"
28 #include "ir/Index.h"
29 #include "util/TracingCtx.h"
30
31 namespace onert
32 {
33 namespace exec
34 {
35 namespace train
36 {
37
38 class TrainableExecutor : public IExecutor
39 {
40 public:
41   /**
42    * @brief Construct a new TrainableExecutor object
43    * @param lowered_graph LoweredTrainableGraph object
44    * @param tensor_builders Tensor builders that are currently used
45    * @param code_map @c ir::Operation and its code map
46    */
47   TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
48                     backend::train::TrainableBackendContexts &&backend_contexts,
49                     const compiler::train::TensorRegistries &tensor_regs,
50                     compiler::train::TrainableCodeMap &&code_map,
51                     const std::vector<ir::OperationIndex> &order,
52                     const util::TracingCtx *tracing_ctx);
53
54 public:
55   const ir::Graph &graph() const final { return _trainable_graph.graph(); }
56
57   void execute(const IODescription &desc) override { forward(desc, false); };
58
59   void execute(const std::vector<backend::IPortableTensor *> &inputs,
60                const std::vector<backend::IPortableTensor *> &outputs) override;
61
62   void forward(const IODescription &desc, bool training);
63   void backward(const IODescription &desc, uint32_t training_step);
64
65   // Used only in Dataflow and Parallel Executors
66   void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
67   {
68     _indexed_ranks = std::move(ranks);
69   };
70
71   void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); };
72
73   const std::vector<backend::builtin::IOTensor *> &getInputTensors() const override
74   {
75     return _input_tensors;
76   }
77
78   const std::vector<backend::builtin::IOTensor *> &getOutputTensors() const override
79   {
80     return _output_tensors;
81   }
82
83   float getLoss(const ir::IOIndex &pred_io_ind) const;
84
85   backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; }
86
87 private:
88   void forwardImpl(bool training);
89   void backwardImpl(uint32_t training_step);
90
91 private:
92   std::vector<compiler::train::TrainableCodeAndInfo> _code;
93   ExecutionObservee _subject;
94   std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
95   std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
96   backend::train::TrainableBackendContexts _backend_contexts;
97   const ir::train::TrainableGraph &_trainable_graph;
98   compiler::train::TensorRegistries _tensor_regs;
99   std::vector<backend::builtin::IOTensor *> _input_tensors;
100   std::vector<backend::builtin::IOTensor *> _output_tensors;
101   std::mutex _mutex;
102   const util::TracingCtx *_tracing_ctx;
103 };
104
105 } // namespace train
106 } // namespace exec
107 } // namespace onert
108
109 #endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_