Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / train / TrainableExecutors.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TrainableExecutors.h"
18
19 #include "../../backend/builtin/IOTensor.h"
20
21 #include <misc/polymorphic_downcast.h>
22
23 namespace onert
24 {
25 namespace exec
26 {
27 namespace train
28 {
29
30 void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index,
31                                  std::unique_ptr<IExecutor> exec)
32 {
33   std::unique_ptr<TrainableExecutor> t_exec{
34     nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())};
35   _executors.emplace(subg_index, std::move(t_exec));
36 }
37
38 TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &,
39                                           const ir::SubgraphIndex &subg_index) const
40 {
41   return _executors.at(subg_index).get();
42 }
43
44 uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->getInputTensors().size(); }
45
46 uint32_t TrainableExecutors::outputSize() const
47 {
48   return entryExecutor()->getOutputTensors().size();
49 }
50
51 const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const
52 {
53   return entryExecutor()->getInputTensors().at(index.value())->orig_info();
54 }
55
56 const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const
57 {
58   return entryExecutor()->getOutputTensors().at(index.value())->orig_info();
59 }
60
61 void TrainableExecutors::execute(const IODescription &desc)
62 {
63   if (_executors.size() > 1)
64     throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
65   entryExecutor()->forward(desc, false);
66
67   // TODO Support multple executors
68 }
69
70 void TrainableExecutors::train(const IODescription &desc, uint32_t training_step)
71 {
72   if (_executors.size() > 1)
73     throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
74   entryExecutor()->forward(desc, true);
75   entryExecutor()->backward(desc, training_step);
76
77   // TODO Support multple executors
78 }
79
80 float TrainableExecutors::getLoss(const ir::IOIndex &index) const
81 {
82   if (_executors.size() > 1)
83     throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
84   return entryExecutor()->getLoss(index);
85 }
86
87 } // namespace train
88 } // namespace exec
89 } // namespace onert