/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "TrainableExecutor.h"

#include "ruy/profiler/instrumentation.h"

#include <misc/polymorphic_downcast.h>

namespace onert
{
namespace exec
{
namespace train
{

TrainableExecutor::TrainableExecutor(
  std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
  backend::train::TrainableBackendContexts &&backend_contexts,
  const compiler::train::TensorRegistries &tensor_regs,
  compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order,
  const util::TracingCtx *tracing_ctx)
  : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
    _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
    _mutex(), _tracing_ctx(tracing_ctx)
{
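  // Gather the graph's input/output tensors as IOTensors so that user-provided
  // buffers can be bound to them later in forward().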
  auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
    assert(tensors.empty());
    for (auto &&ind : ind_seq)
    {
      backend::ITensor *tensor = tensor_regs.getITensor(ind);
      assert(tensor != nullptr);
      auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor);
      tensors.push_back(io_tensor);
    }
  };
  build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
  build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
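
  // Record the per-operation trainable code following the given topological order;
  // forward() walks this list front-to-back and backward() walks it back-to-front.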
  for (auto &&index : order)
  {
    auto &trainable_code = code_map.at(index);
    _code.emplace_back(std::move(trainable_code));
  }
}

void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &,
                                const std::vector<backend::IPortableTensor *> &)
{
  throw std::runtime_error("TrainableExecutor does not support multiple subgraphs yet");
}
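
// Note: the tensor-based IExecutor entry point above is used when an executor is
// invoked with explicit input/output tensors (e.g. from another subgraph); training
// does not support that yet, hence the exception.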

void TrainableExecutor::forward(const IODescription &desc, bool training)
{
  // Use a mutex for thread safety
  // TODO: If all backends used by this executor are thread-safe,
  //       the mutex is unnecessary (otherwise, keep it)
  std::lock_guard<std::mutex> lock(_mutex);

  // TODO Update IO tensors if desc has dynamic input

  assert(_input_tensors.size() == desc.inputs.size());
  for (uint32_t i = 0; i < _input_tensors.size(); ++i)
  {
    auto tensor = _input_tensors[i];

    // TODO Check if (desc.inputs[i] == nullptr)
    // TODO Better design for ITensor? (we need const_cast as ITensor is writable)
    tensor->setUserTensor(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)),
                          desc.inputs[i]->size);
  }
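
  // Set output(s): unlike inputs, a missing output buffer is a hard error because
  // results are written directly into the user's memory.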
  assert(_output_tensors.size() == desc.outputs.size());
  for (uint32_t i = 0; i < _output_tensors.size(); ++i)
  {
    auto tensor = _output_tensors[i];

    if (desc.outputs[i] == nullptr)
      throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."};
    tensor->setUserTensor(static_cast<uint8_t *>(desc.outputs[i]->buffer), desc.outputs[i]->size);
  }

  forwardImpl(training);

  // TODO Update output(s) desc if desc has dynamic input
}

void TrainableExecutor::forwardImpl(bool training)
{
  if (_tracing_ctx)
  {
    auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());

    _subject.notifySubgraphBegin(profiling_subg_index);
    for (auto &&code : _code)
    {
      const auto backend = code.lower_info->backend();
// TODO: Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
      ruy::profiler::ScopeLabel label(code.op->name());
#endif
      _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);

      auto &tn_seq = code.tn_seq;
      tn_seq->forward(training);

      _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
    }
    _subject.notifySubgraphEnd(profiling_subg_index);
  }
  else
  {
    for (auto &&code : _code)
    {
// TODO: Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
      ruy::profiler::ScopeLabel label(code.op->name());
#endif
      auto &tn_seq = code.tn_seq;
      tn_seq->forward(training);
    }
  }
}

void TrainableExecutor::backward(const IODescription &, uint32_t training_step)
{
  // Use a mutex for thread safety
  // TODO: If all backends used by this executor are thread-safe,
  //       the mutex is unnecessary (otherwise, keep it)
  std::lock_guard<std::mutex> lock(_mutex);

  backwardImpl(training_step);
}
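
// Gradients flow in reverse: backwardImpl() visits the code list back-to-front so
// each operation sees the gradients produced by its consumers before running its
// own backward step.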
void TrainableExecutor::backwardImpl(uint32_t training_step)
{
  if (_tracing_ctx)
  {
    auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());

    _subject.notifySubgraphBegin(profiling_subg_index);
    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
    {
      const auto &code = *it;
      const auto backend = code.lower_info->backend();
// TODO: Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
      ruy::profiler::ScopeLabel label(code.op->name());
#endif
      _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);

      auto &tn_seq = code.tn_seq;
      tn_seq->backward(training_step);

      _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
    }
    _subject.notifySubgraphEnd(profiling_subg_index);
  }
  else
  {
    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
    {
      const auto &code = *it;
// TODO: Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
      ruy::profiler::ScopeLabel label(code.op->name());
#endif
      auto &tn_seq = code.tn_seq;
      tn_seq->backward(training_step);
    }
  }
}

float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const
{
  const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind);
  if (loss_ind.undefined())
    throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."};
  backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind);
  auto loss_buf = reinterpret_cast<float *>(tensor->buffer());
  return *loss_buf;
}
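
// A minimal usage sketch (illustrative only; `executor`, `desc`, `num_steps`, and
// `pred_index` are hypothetical names, not part of this file):
//
//   for (uint32_t step = 0; step < num_steps; ++step)
//   {
//     executor.forward(desc, /*training=*/true); // bind IO buffers, run forward pass
//     executor.backward(desc, step);             // run backward pass for this step
//     float loss = executor.getLoss(pred_index); // read back the computed loss
//   }

} // namespace train
} // namespace exec
} // namespace onert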