Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / train / TrainableExecutor.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TrainableExecutor.h"
18 #ifdef RUY_PROFILER
19 #include "ruy/profiler/instrumentation.h"
20 #endif
21
22 #include <misc/polymorphic_downcast.h>
23
24 namespace onert
25 {
26 namespace exec
27 {
28 namespace train
29 {
30
31 TrainableExecutor::TrainableExecutor(
32   std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
33   backend::train::TrainableBackendContexts &&backend_contexts,
34   const compiler::train::TensorRegistries &tensor_regs,
35   compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order,
36   const util::TracingCtx *tracing_ctx)
37   : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
38     _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
39     _mutex(), _tracing_ctx(tracing_ctx)
40 {
41   auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
42     assert(tensors.empty());
43     for (auto &&ind : ind_seq)
44     {
45       backend::ITensor *tensor = tensor_regs.getITensor(ind);
46       assert(tensor != nullptr);
47       auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor);
48       tensors.push_back(io_tensor);
49     }
50   };
51   build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
52   build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
53
54   for (auto &&index : order)
55   {
56     auto &trainable_code = code_map.at(index);
57     _code.emplace_back(std::move(trainable_code));
58   }
59 }
60
61 void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &,
62                                 const std::vector<backend::IPortableTensor *> &)
63 {
64   throw std::runtime_error("TrainableExecutor does not support multiple subgraphs yet");
65 }
66
67 void TrainableExecutor::forward(const IODescription &desc, bool training)
68 {
69   // For thread-safe, use mutex
70   // TODO: if all used backends on this executor are thread-safe,
71   //       do not need to use mutex (otherwise, use mutex)
72   std::lock_guard<std::mutex> lock(_mutex);
73
74   // TODO Update IO tensors if desc has dynamic input
75   // Set input(s)
76   assert(_input_tensors.size() == desc.inputs.size());
77   for (uint32_t i = 0; i < _input_tensors.size(); ++i)
78   {
79     auto tensor = _input_tensors[i];
80
81     // TODO Check if (desc.inputs[i] == nullptr)
82     // TODO Better design for ITensor? (we need const_cast as ITensor is writable)
83     tensor->setUserTensor(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)),
84                           desc.inputs[i]->size);
85   }
86
87   if (!training)
88   {
89     // Set output(s)
90     assert(_output_tensors.size() == desc.outputs.size());
91     for (uint32_t i = 0; i < _output_tensors.size(); ++i)
92     {
93       auto tensor = _output_tensors[i];
94
95       if (desc.outputs[i] == nullptr)
96         throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."};
97       tensor->setUserTensor(static_cast<uint8_t *>(desc.outputs[i]->buffer), desc.outputs[i]->size);
98     }
99   }
100
101   forwardImpl(training);
102
103   // TODO Update output(s) desc if desc has dynamic input
104 }
105
106 void TrainableExecutor::forwardImpl(bool training)
107 {
108   if (_tracing_ctx)
109   {
110     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
111
112     _subject.notifySubgraphBegin(profiling_subg_index);
113     for (auto &&code : _code)
114     {
115       const auto backend = code.lower_info->backend();
116 // TODO : Move ruy profiler into ExecutionObserver
117 #ifdef RUY_PROFILER
118       ruy::profiler::ScopeLabel label(code.op->name());
119 #endif
120       _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
121
122       auto &tn_seq = code.tn_seq;
123       tn_seq->forward(training);
124
125       _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
126     }
127     _subject.notifySubgraphEnd(profiling_subg_index);
128   }
129   else
130   {
131     for (auto &&code : _code)
132     {
133 // TODO : Move ruy profiler into ExecutionObserver
134 #ifdef RUY_PROFILER
135       ruy::profiler::ScopeLabel label(code.op->name());
136 #endif
137       auto &tn_seq = code.tn_seq;
138       tn_seq->forward(training);
139     }
140   }
141 }
142
143 void TrainableExecutor::backward(const IODescription &, uint32_t training_step)
144 {
145   // For thread-safe, use mutex
146   // TODO: if all used backends on this executor are thread-safe,
147   //       do not need to use mutex (otherwise, use mutex)
148   std::lock_guard<std::mutex> lock(_mutex);
149
150   backwardImpl(training_step);
151 }
152
153 void TrainableExecutor::backwardImpl(uint32_t training_step)
154 {
155   if (_tracing_ctx)
156   {
157     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
158
159     _subject.notifySubgraphBegin(profiling_subg_index);
160     for (auto it = _code.rbegin(); it != _code.rend(); ++it)
161     {
162       const auto &code = *it;
163       const auto backend = code.lower_info->backend();
164 // TODO : Move ruy profiler into ExecutionObserver
165 #ifdef RUY_PROFILER
166       ruy::profiler::ScopeLabel label(code.op->name());
167 #endif
168       _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
169
170       auto &tn_seq = code.tn_seq;
171       tn_seq->backward(training_step);
172
173       _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
174     }
175     _subject.notifySubgraphEnd(profiling_subg_index);
176   }
177   else
178   {
179     for (auto it = _code.rbegin(); it != _code.rend(); ++it)
180     {
181       const auto &code = *it;
182 // TODO : Move ruy profiler into ExecutionObserver
183 #ifdef RUY_PROFILER
184       ruy::profiler::ScopeLabel label(code.op->name());
185 #endif
186       auto &tn_seq = code.tn_seq;
187       tn_seq->backward(training_step);
188     }
189   }
190 }
191
192 float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const
193 {
194   const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind);
195   if (loss_ind.undefined())
196     throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."};
197   backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind);
198   auto loss_buf = reinterpret_cast<float *>(tensor->buffer());
199   return *loss_buf;
200 }
201
202 } // namespace train
203 } // namespace exec
204 } // namespace onert