2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "TrainingCompiler.h"
19 #include "StaticDerivativeShapeInferer.h"
20 #include "TrainableOperationConverter.h"
21 #include "pass/LossInsertionPass.h"
22 #include "../CompilerHelpers.h"
23 #include "../ExecutorFactory.h"
24 #include "../pass/ConstantOutputPass.h"
25 #include "../pass/OddOutputPass.h"
26 #include "../pass/PassRunner.h"
27 #include "../pass/UnusedOperandEliminationPass.h"
28 #include "../ShapeValidator.h"
29 #include "../../dumper/dot/DotDumper.h"
30 #include "../../exec/train/TrainableExecutors.h"
31 #include "../../ir/OperationDumper.h"
32 #include "../../ir/verifier/Verifier.h"
34 #include <compiler/StaticShapeInferer.h>
35 #include <compiler/train/LoweredTrainableGraph.h>
36 #include <ir/train/TrainableGraph.h>
37 #include <exec/train/optimizer/SGD.h>
39 #include <misc/polymorphic_downcast.h>
40 #include <misc/string_helpers.h>
49 TrainingCompiler::TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg,
50 std::vector<std::unique_ptr<CompilerOptions>> &copts,
51 const TrainingInfo &training_info)
52 : _model{nnpkg->primary_model()}, _options{copts[0].get()}, _training_info{training_info}
54 if (nnpkg->model_count() > 1)
55 throw std::runtime_error("TrainingCompiler does not support multiple models yet");
57 if (nnpkg->primary_model()->subgraphs_count() > 1)
58 throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet");
// Compile the training model into executable artifacts.
//
// Pipeline: option validation -> mandatory/optimization passes on the
// inference graphs -> conversion to TrainableGraphs -> loss insertion ->
// batch-size input-shape patching -> lowering (backend assignment) ->
// derivative registration & static shape inference -> shape validation ->
// optimizer creation -> executor generation.
//
// @return CompilerArtifact holding the trainable executors and tracing context
// @throws std::runtime_error on invalid options, unsupported models,
//         unsupported input shapes, or an unknown optimizer code
std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
{
  /***************************************************
   * Prepare compilation phase
   ***************************************************/
  if (!_options)
    throw std::runtime_error{"Empty compile option"};

  // Mode check: profiling requires the heterogeneous scheduler and the
  // Dataflow executor; minmax recording requires the Linear executor.
  // TODO handle option for each model
  if (_options->he_profiling_mode)
  {
    if (!_options->he_scheduler)
      throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");

    if (_options->executor != "Dataflow")
      throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
  }

  if (!_options->minmax_filepath.empty())
  {
    if (_options->executor != "Linear")
      throw std::runtime_error("Recording minmax works only with Linear executor");
  }

  _options->forceInternalOptions();
  _options->verboseOptions();

  auto custom_kernel_builder = _model->getKernelBuilder();

  // Run mandatory passes and graph-level optimizations on every (inference)
  // subgraph before conversion to a trainable graph.
  _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
    auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);

    // Mandatory passes
    compiler::pass::PassRunner{}
      .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
      .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
      .run();

    // Optimizations
    compiler::pass::PassRunner{}
      .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg))
      .run();
  });

  std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>>
    trainable_subgraphs;

  if (_model->hasOnly<ir::Graph>())
  {
    // Create trainable subgraphs by copy and converting inference model
    _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) {
      const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph);
      // Create TrainableGraph by copying Graph
      auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg);

      // Convert operations to trainable operations; each op is replaced
      // in place and must keep its original index.
      auto converter = TrainableOperationConverter{*trainable_subg, &_training_info};
      subg.operations().iterate(
        [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) {
          auto trainable_op = converter(op);
          auto gen_index = trainable_subg->replaceOperation(op_index, std::move(trainable_op));
          UNUSED_RELEASE(gen_index);
          assert(gen_index == op_index);
        });
      trainable_subgraphs[subg_index] = std::move(trainable_subg);
    });
  }
  else
  {
    // TODO Support models that have TrainableGraphs
    throw std::runtime_error("TrainingCompiler: Invalid model");
  }

  // Apply pass for trainable subgraphs (inserts the loss operation)
  for (auto &&pair : trainable_subgraphs)
  {
    auto trainable_subg = pair.second;
    auto subg_index = pair.first;

    compiler::pass::PassRunner{}
      .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info,
                                                               subg_index))
      .run();
  }

  // Change input shape according to batch_size: every model input must
  // currently have 1 as its first dimension, which is rewritten to the
  // training batch size.
  for (auto &&pair : trainable_subgraphs)
  {
    auto trainable_subg = pair.second;

    for (const auto &ind : trainable_subg->getInputs())
    {
      auto &input = trainable_subg->operands().at(ind);
      auto new_shape = input.info().shape();
      // TODO Consider batch size index
      if (new_shape.dim(0) != 1)
        throw std::runtime_error("the first dim is not 1. It is not supported yet.");
      new_shape.dim(0) = _training_info.batchSize();
      input.info().shape(new_shape);
    }
  }

  /***************************************************
   * Backend independent analysis & optimization phase
   ***************************************************/
  // TODO Handle dump level for each model
  auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
  onert::dumper::dot::DotDumper dot_dumper(dump_level);

  auto tracing_ctx = std::make_unique<util::TracingCtx>();

  // Lower: Assign backend
  std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>>
    lowered_subgs;

  for (auto &&pair : trainable_subgraphs)
  {
    auto &subg_index = pair.first;
    auto trainable_subg = pair.second;

    // Lower: Assign backend
    lowered_subgs[subg_index] =
      std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options);
    // Set tracing_ctx for copied graph
    if (tracing_ctx != nullptr)
      tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
  }

  // Dump each lowered subgraph for debugging (no-op unless dumping is enabled)
  for (const auto &pair : lowered_subgs)
  {
    const auto &subg_index = pair.first;
    const auto &lowered_subg = pair.second;
    dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
  }

  // Set derivatives as default tensor info: every non-constant operand gets
  // a derivative operand with identical info, registered under the same index.
  for (const auto &pair : lowered_subgs)
  {
    auto lowered_subg = pair.second.get();
    auto &tgraph = lowered_subg->trainable_graph();
    tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
      if (!obj.isConstant())
      {
        auto deriv = std::make_unique<ir::Operand>(obj);
        const auto gen_index = tgraph.addDerivative(index, std::move(deriv));
        assert(gen_index == index);
        UNUSED_RELEASE(gen_index);
      }
    });
  }

  // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
  // recursively
  std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
    createStaticShapeInferers(lowered_subgs);

  const auto primary_subg_idx = ir::SubgraphIndex{0};
  inferers.at(primary_subg_idx)->infer();

  for (const auto &pair_inferer : inferers)
  {
    const auto inferer = pair_inferer.second.get();
    inferer->dump();
  }

  // NOTE StaticDerivativeShapeInferer is allocated for each subgraph,
  // so it does not support models that have controlflow operations yet.
  for (auto &&pair : lowered_subgs)
  {
    auto &lowered_subg = pair.second;
    auto inferer = std::make_unique<StaticDerivativeShapeInferer>(lowered_subg.get());
    inferer->infer();
    inferer->dump();
  }

  // Shape validation
  for (const auto &pair : lowered_subgs)
  {
    auto &lowered_subg = pair.second;
    compiler::ShapeValidator{lowered_subg->graph()}();
  }

  // TODO Validate shapes of derivative tensors

  // Create the optimizer shared by all executors (only SGD is supported).
  // TODO Set properties of optimizer
  std::shared_ptr<exec::train::optimizer::Optimizer> optimizer;
  const auto &optim_info = _training_info.optimizerInfo();
  if (optim_info.optim_code == exec::train::optimizer::OptimizerCode::SGD)
    optimizer = std::make_shared<exec::train::optimizer::SGD>(optim_info.learning_rate);
  else
    throw std::runtime_error("Invalid optimizer type, " +
                             exec::train::optimizer::toString(optim_info.optim_code));

  /*************************************************************
   * Backend independent analysis & optimization phase finished
   *************************************************************/
  // Generate executors; each lowered subgraph is moved into the factory,
  // so `lowered_subgs` entries are consumed here.
  auto executors = std::make_shared<exec::train::TrainableExecutors>();
  for (auto &&pair : lowered_subgs)
  {
    auto const model_index = ir::ModelIndex{0};
    auto const subg_index = pair.first;
    auto &lowered_subg = pair.second;
    auto const indexed_ranks = lowered_subg->indexed_ranks();

    ir::OperationDumper dumper("Executor generation of Subgraph " +
                               std::to_string(subg_index.value()));
    lowered_subg->graph().operations().iterate(
      [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });

    ExecutorFactoryArgs args;
    args.tracing_ctx = tracing_ctx.get();
    args.options = _options;
    args.model_index = model_index;
    args.custom_kernel_builder = custom_kernel_builder;
    auto executor = std::unique_ptr<exec::IExecutor>{
      ExecutorFactory::get().create(std::move(lowered_subg), executors, args, optimizer)};
    executor->setIndexedRanks(indexed_ranks);
    executors->emplace(model_index, subg_index, std::move(executor));
  }

  /********************************
   * Code generation phase finished
   ********************************/
  return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
}
298 } // namespace compiler