Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / train / TrainableExecutors.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
18 #define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
19
20 #include "TrainableExecutor.h"
21 #include "exec/IExecutors.h"
22 #include "ir/NNPkg.h"
23
24 namespace onert
25 {
26 namespace exec
27 {
28 namespace train
29 {
30
31 /**
32  * @brief Class to gather executor set for trainable model NN package
33  */
34 class TrainableExecutors : public IExecutors
35 {
36 public:
37   /**
38    * @brief Construct a new TrainableExecutors object
39    */
40   TrainableExecutors(void) = default;
41   TrainableExecutors(const TrainableExecutors &) = delete;
42   TrainableExecutors(TrainableExecutors &&) = default;
43
44   /**
45    * @brief Destroy the TrainableExecutors object
46    */
47   ~TrainableExecutors() = default;
48
49 public:
50   TrainableExecutors &operator=(const TrainableExecutors &) = delete;
51   TrainableExecutors &operator=(TrainableExecutors &&) = default;
52
53 public:
54   void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
55                std::unique_ptr<IExecutor> exec) override;
56
57   TrainableExecutor *at(const ir::ModelIndex &model_index,
58                         const ir::SubgraphIndex &subg_index) const override;
59
60   TrainableExecutor *entryExecutor() const { return at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); }
61
62   uint32_t inputSize() const override;
63
64   uint32_t outputSize() const override;
65
66   const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override;
67
68   const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override;
69
70   void execute(const IODescription &desc) override;
71
72   /**
73    * @brief Train
74    *
75    * @param desc          IO information
76    * @param training_step The number of iterations of an training process.
77    *                      In other words, the number of gradient update.
78    */
79   void train(const IODescription &desc, uint32_t training_step);
80
81   float getLoss(const ir::IOIndex &index) const;
82
83 private:
84   // TODO Append model index to ModelIndex
85   std::unordered_map<ir::SubgraphIndex, std::unique_ptr<TrainableExecutor>> _executors;
86 };
87
88 } // namespace train
89 } // namespace exec
90 } // namespace onert
91
92 #endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__