/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ExecutorFactory.h"

#include "Linear.h"
#include "../backend/builtin/BackendContext.h"
#include "../backend/builtin/Config.h"
#include "../backend/builtin/UserTensor.h"
#include "../dumper/text/GraphDumper.h"
#include "../exec/DataflowExecutor.h"
#include "../exec/ExecTime.h"
#include "../exec/ExecutionObservers.h"
#include "../exec/LinearExecutor.h"
#ifdef MINMAX_H5DUMPER
#include "../exec/MinMaxRecorder.h"
#endif
#include "../exec/ParallelExecutor.h"
#include "../ir/OperationCloner.h"

#include <backend/IPortableTensor.h>
#include <compiler/BackendManager.h>
#include <compiler/ExecutionBuilder.h>
#include <util/TracingCtx.h>

#include <functional>
#include <memory>

#ifdef ONERT_TRAIN
#include "../backend/builtin/train/BackendContext.h"
#include "../exec/train/TrainableExecutor.h"

#include <backend/train/TrainableBackendContext.h>
#include <backend/train/ITrainableBackend.h>
#endif // ONERT_TRAIN

namespace onert
{
namespace
{

class SyncFunction final : public exec::IFunction
{
public:
  virtual ~SyncFunction() = default;
  SyncFunction(std::unique_ptr<exec::IFunction> fn, const std::shared_ptr<backend::IConfig> config)
    : _fn{std::move(fn)}, _config{config}
  {
    assert(_fn);
    assert(_config);
  }

  void run() override
  {
    _fn->run();
    _config->sync();
  }

  void prepare() override { _fn->prepare(); }

private:
  std::unique_ptr<exec::IFunction> _fn;
  std::shared_ptr<backend::IConfig> _config;
};
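
// NOTE When options->he_profiling_mode is set, generated function sequences are wrapped with
// SyncFunction (fn_seq->wrap<SyncFunction>(...) below) so that a backend finishes its
// asynchronous work before the next kernel's timing begins.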

using DeallocList = std::vector<backend::ITensor *>;
// Deallocation after execution of an operation used by Linear Executor
class DeallocFunction final : public exec::IFunction
{
public:
  DeallocFunction(const DeallocList &tensors) : _dealloc_list{tensors} {}

  void run() override
  {
    for (auto &&tensor : _dealloc_list)
    {
      if (!tensor->is_dynamic())
        continue;
      tensor->deallocBuffer();
    }
  }

private:
  DeallocList _dealloc_list;
};
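
// NOTE createLinearExecutor appends a DeallocFunction to an operation's function sequence when
// the execution simulation below shows that some tensors see their last use at that operation.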

// TODO Unify initializeSubgraphIOTensors
void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
                                 const backend::BackendContexts &backend_contexts,
                                 const ir::OperandIndexSequence &indices)
{
  // TODO Store builtin backend in BackendContext
  std::shared_ptr<backend::builtin::TensorRegistry> builtin_tensor_reg;
  for (const auto &e : backend_contexts)
  {
    auto backend = e.first;
    auto &context = e.second;
    if (backend->config()->id() == backend::builtin::Config::ID)
    {
      builtin_tensor_reg =
        std::dynamic_pointer_cast<backend::builtin::TensorRegistry>(context->tensor_registry);
    }
  }
  assert(builtin_tensor_reg);

  for (auto &&ind : indices)
  {
    const auto &operand = lowered_graph.graph().operands().at(ind);
    auto tensor = std::make_unique<backend::builtin::IOTensor>(
      operand.info(),
      ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */
    );

    // Add tensor to builtin TensorRegistry.
    builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
  }
}
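
// NOTE (assumption) IOTensor here acts as a placeholder for a model input/output; the actual
// user buffer is expected to be bound to it at execution time (cf. UserTensor included above).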

#ifdef ONERT_TRAIN
void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
                                 const backend::train::TrainableBackendContexts &backend_contexts,
                                 const ir::OperandIndexSequence &indices)
{
  std::shared_ptr<backend::builtin::train::TensorRegistry> builtin_tensor_reg;
  for (const auto &e : backend_contexts)
  {
    auto backend = e.first;
    auto &context = e.second;
    if (backend->config()->id() == backend::builtin::Config::ID)
    {
      builtin_tensor_reg = std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(
        context->tensor_registry());
    }
  }
  assert(builtin_tensor_reg);

  for (auto &&ind : indices)
  {
    const auto &operand = lowered_graph.graph().operands().at(ind);
    auto tensor = std::make_unique<backend::builtin::IOTensor>(
      operand.info(),
      ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */
    );

    // Add tensor to builtin TensorRegistry.
    builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
  }
}
#endif // ONERT_TRAIN

backend::BackendContexts
createBackendContexts(compiler::ILoweredGraph &lgraph, bool linear_executor,
                      std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder)
{
  backend::BackendContexts contexts;
  auto &backend_manager = compiler::BackendManager::get();

  std::unordered_map<const backend::Backend *, backend::ContextData> context_data_map;

  // Generate partial graphs for each backend
  for (auto &&backend : backend_manager.getAll())
  {
    auto &data = context_data_map[backend];
    auto graph = std::make_unique<ir::Graph>();
    graph->setLayout(lgraph.graph().layout());
    data.graph = std::move(graph);
  }

  auto &whole_graph = lgraph.graph();
  // Separate operands into partial graphs
  whole_graph.operands().iterate([&](const ir::OperandIndex &operand_ind, ir::Operand &operand) {
    auto &operand_li = lgraph.lower_info().operand;
    const auto &def_factors = operand_li.at(operand_ind).def_factors();
    if (def_factors.size() == 0) // Ignore unused tensor
      return;
    const auto &def_factor = def_factors.getOnlyElement();
    const auto backend = def_factor.backend();
    auto &partial_graph = *context_data_map[backend].graph;
    auto &operand_layouts = context_data_map[backend].operand_layouts;
    assert(operand_layouts.find(operand_ind) == operand_layouts.end());
    operand_layouts[operand_ind] = def_factor.layout();

    // Copy the operand and insert it to the partial graph
    auto new_operand = std::make_unique<ir::Operand>(operand);
    new_operand->clearDefUse();
    operand.releaseData(); // Deref data of LoweredGraph
    auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand));
    UNUSED_RELEASE(new_operand_ind);
    assert(new_operand_ind == operand_ind);
  });

  // Separate operations into partial graphs
  whole_graph.operations().iterate(
    [&](const ir::OperationIndex &op_ind, const ir::IOperation &operation) {
      auto &op_li = lgraph.lower_info().operation;
      auto backend = op_li.at(op_ind).backend();
      auto &partial_graph = *context_data_map[backend].graph;
      auto &external_operands = context_data_map[backend].external_operands;
      auto &operand_layouts = context_data_map[backend].operand_layouts;

      // Add missing operands (externals)
      auto io_list = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED |
                     ir::Remove::UNDEFINED;
      for (auto &&operand_ind : io_list)
      {
        if (partial_graph.operands().exist(operand_ind))
          continue;

        // Copy the operand and insert it to the partial graph
        const auto &operand = whole_graph.operands().at(operand_ind);
        auto new_operand = std::make_unique<ir::Operand>(operand);
        new_operand->clearDefUse();
        auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand));
        UNUSED_RELEASE(new_operand_ind);
        assert(new_operand_ind == operand_ind);

        const auto layout =
          lgraph.lower_info().operand.at(operand_ind).def_factors().getOnlyElement().layout();
        assert(operand_layouts.find(operand_ind) == operand_layouts.end());
        operand_layouts[operand_ind] = layout;
        external_operands.add(operand_ind);
      }

      auto new_op_ind = partial_graph.addOperation(op_ind, clone(operation));
      UNUSED_RELEASE(new_op_ind);
      assert(new_op_ind == op_ind);
    });

  auto whole_op_order = lgraph.graph().topolSortOperations();
  for (auto &&pair : context_data_map)
  {
    auto backend = pair.first;
    auto &data = pair.second;
    // Handle graph input/outputs or external tensors
    data.graph->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) {
      if (whole_graph.getInputs().contains(ind) || whole_graph.getOutputs().contains(ind))
        data.external_operands.add(ind);
      // Inputs are either "graph input" or "no def op and non-constant"
      if (whole_graph.getInputs().contains(ind) ||
          (!operand.getDef().valid() && !operand.isConstant()))
        data.graph->addInput(ind);
      // Outputs are either "graph output" or "no uses"
      if (whole_graph.getOutputs().contains(ind) || operand.getUses().size() == 0)
        data.graph->addOutput(ind);
    });
    dumper::text::dumpGraph(*data.graph);

    std::copy_if(whole_op_order.begin(), whole_op_order.end(), std::back_inserter(data.op_order),
                 [&](const auto &ind) { return data.graph->operations().exist(ind); });
    data.is_linear_executor = linear_executor;
    data.custom_kernel_builder = custom_kernel_builder;
    contexts.emplace(backend, backend->newContext(std::move(data)));
  }
  return contexts;
}
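
// For illustration: with two hypothetical backends {cpu, acl_cl}, every operand/operation goes
// into the partial graph of the backend chosen during lowering; operands consumed across the
// boundary are recorded in external_operands and are later resolved as migrant tensors in
// prepareMigrantTensors().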

template <typename Context>
std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext(
  const std::unordered_map<const backend::Backend *, std::unique_ptr<Context>> &tbackend_contexts)
{
  std::deque<std::pair<const backend::Backend *, Context *>> ordered_contexts;

  for (auto &&pair : tbackend_contexts)
  {
    // NOTE The builtin backend must be processed last.
    // This is because the Permute layer is the only operation that may have different ITensor
    // objects for its input and output, and it requires all other backends' tensors to be
    // ready for use.
    if (pair.first->config()->id() == "builtin")
      ordered_contexts.emplace_back(pair.first, pair.second.get());
    else
      ordered_contexts.emplace_front(pair.first, pair.second.get());
  }

  return ordered_contexts;
}

} // namespace
} // namespace onert

namespace onert
{
namespace compiler
{

ExecutorFactory &ExecutorFactory::get()
{
  static ExecutorFactory singleton;
  return singleton;
}

ExecutorFactory::ExecutorFactory()
{
  _map["Linear"] = createLinearExecutor;
  _map["Dataflow"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
                               std::placeholders::_3, false);
  _map["Parallel"] = std::bind(createDataflowExecutor, std::placeholders::_1, std::placeholders::_2,
                               std::placeholders::_3, true);
}

exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
                                         const std::shared_ptr<exec::IExecutors> &executors,
                                         const ExecutorFactoryArgs &args)
{
  assert(args.options != nullptr);
  return _map.at(args.options->executor)(std::move(lowered_graph), executors, args);
}
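
// e.g. args.options->executor == "Parallel" dispatches through the table built in the
// constructor above and ends up in createDataflowExecutor(..., parallel == true).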

void ExecutorFactory::prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
                                            const backend::BackendContexts &backend_contexts)
{
  TensorRegistries tensor_regs{backend_contexts, true};

  lowered_graph.graph().operations().iterate(
    [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
      auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
      auto &backend_ctx = backend_contexts.at(lower_info->backend());
      for (auto &&ind :
           (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
      {
        // If an operation's input/output tensor does not have its own tensor object,
        // it must be using a migrant tensor, so find the tensor in the other tensor registries
        // and register it to the current tensor registry if it is portable
        if (!backend_ctx->tensor_registry->getITensor(ind))
        {
          auto tensor = tensor_regs.getITensor(ind);
          assert(tensor); // The tensor must have been registered
          auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor);
          if (ptensor)
            backend_ctx->tensor_registry->setMigrantTensor(ind, ptensor);
        }
      }
    });
}
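
// NOTE A migrant tensor remains owned by the registry of the backend that defined it;
// setMigrantTensor only registers a reference for the consuming backend, so no copy is made.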

void ExecutorFactory::prepareBuiltinBackend(const TensorRegistries &tensor_regs,
                                            const std::shared_ptr<exec::IExecutors> &executors,
                                            const backend::BackendContexts &backend_contexts,
                                            const ir::ModelIndex &index)
{
  for (auto &&pair : backend_contexts)
  {
    auto builtin_context = dynamic_cast<backend::builtin::BackendContext *>(pair.second.get());
    if (builtin_context != nullptr)
    {
      auto builtin_kernel_gen = builtin_context->kernel_gen;
      builtin_kernel_gen->setTensorRegistries(tensor_regs);
      builtin_kernel_gen->setExecutors(executors);
      builtin_kernel_gen->setModelIndex(index);
    }
  }
}

std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_contexts)
{
  std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ordered_contexts;
  for (auto &&pair : backend_contexts)
  {
    // NOTE The builtin backend must be processed last.
    // This is because the Permute layer is the only operation that may have different ITensor
    // objects for its input and output, and it requires all other backends' tensors to be
    // ready for use.
    if (pair.first->config()->id() == "builtin")
      ordered_contexts.emplace_back(pair.first, pair.second.get());
    else
      ordered_contexts.emplace_front(pair.first, pair.second.get());
  }
  return ordered_contexts;
}
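
// TODO Unify this with the file-local orderBackendContext template above; the two bodies are
//      identical except for the context type.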

exec::IExecutor *
ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
                                      const std::shared_ptr<exec::IExecutors> &executors,
                                      const ExecutorFactoryArgs &args)
{
  const auto options = args.options;
  const auto &model_index = args.model_index;
  const auto tracing_ctx = args.tracing_ctx;
  auto custom_kernel_builder = args.custom_kernel_builder;
  auto &graph = lowered_graph->graph();

  backend::BackendContexts backend_contexts =
    createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);

  TensorRegistries tensor_regs{backend_contexts, true};

  initializeSubgraphIOTensors(
    *lowered_graph, backend_contexts,
    (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
      ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);

  auto order = Linear::linearize(*lowered_graph);
  Linear::dump(*lowered_graph, order);

  for (auto &&pair : backend_contexts)
  {
    pair.second->genTensors();
  }

  prepareMigrantTensors(*lowered_graph, backend_contexts);

  // Give some runtime objects to builtin KernelGenerator
  prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);

  ExecutionBuilder builder;

  // Adjust the order of backends for the upcoming iteration
  auto ordered_contexts = orderBackendContext(backend_contexts);

  // Simulate the execution for deallocation of tensors
  std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map;
  {
    ir::OperandIndexMap<uint32_t> uses_map;
    ir::OperandIndexSequence constants;

    auto model_io =
      (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;

    graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
      uses_map[ind] = obj.getUses().size();

      if (obj.isConstant())
        constants.append(ind);
    });

    // A trick to consider constants as an exception
    for (const auto &ind : constants)
    {
      uses_map[ind]++;
    }

    for (const auto &op_ind : order)
    {
      const auto &op = graph.operations().at(op_ind);
      auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
      auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;

      for (const auto &ind : op_inputs)
      {
        const auto &operand = graph.operands().at(ind);
        assert(uses_map.find(ind) != uses_map.end());
        assert(uses_map[ind] > 0);
        uses_map[ind]--;
        if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind))
        {
          dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind));
        }
      }
    }

    // Dispose and validate
    for (const auto &ind : constants)
    {
      --uses_map[ind];
    }

    assert(
      std::all_of(uses_map.begin(), uses_map.end(),
                  [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
  }
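
  // For example, if operand #3 sees its last use at operation #7 and is neither a variable nor
  // a model input/output, dealloc_list_map[#7] holds its ITensor, and the DeallocFunction
  // appended below frees that buffer right after operation #7 runs.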

  // Generate kernels
  for (auto &&pair : ordered_contexts)
  {
    auto codes = pair.second->genKernels();
    for (auto &&pair : codes)
    {
      auto &op_ind = pair.first;
      auto &fn_seq = pair.second;
      auto &op = lowered_graph->graph().operations().at(op_ind);
      auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
      if (options->he_profiling_mode)
        fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
      if (!dealloc_list_map[op_ind].empty())
        fn_seq->append(std::make_unique<DeallocFunction>(dealloc_list_map[op_ind]));
      builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)});
    }
  }

  auto code_map = builder.releaseCodeMap();

  auto exec = new exec::LinearExecutor{std::move(lowered_graph),
                                       std::move(backend_contexts),
                                       tensor_regs,
                                       std::move(code_map),
                                       order,
                                       tracing_ctx};

  if (!options->trace_filepath.empty())
  {
    std::unique_ptr<exec::IExecutionObserver> ctp =
      std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
    exec->addObserver(std::move(ctp));
  }
#ifdef MINMAX_H5DUMPER
  if (!options->minmax_filepath.empty())
    exec->addObserver(std::make_unique<exec::MinMaxRecorder>(
      options->minmax_filepath, exec->graph(), exec->getBackendContexts()));
#endif

  return exec;
}

exec::IExecutor *
ExecutorFactory::createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
                                        const std::shared_ptr<exec::IExecutors> &executors,
                                        const ExecutorFactoryArgs &args, bool parallel)
{
  const auto options = args.options;
  const auto &model_index = args.model_index;
  const auto tracing_ctx = args.tracing_ctx;
  auto custom_kernel_builder = args.custom_kernel_builder;

  backend::BackendContexts backend_contexts =
    createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);

  TensorRegistries tensor_regs{backend_contexts, true};

  initializeSubgraphIOTensors(
    *lowered_graph, backend_contexts,
    (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
      ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);

  for (auto &&pair : backend_contexts)
  {
    pair.second->genTensors();
  }

  prepareMigrantTensors(*lowered_graph, backend_contexts);

  // Give some runtime objects to builtin KernelGenerator
  prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);

  ExecutionBuilder builder;

  // Adjust the order of backends for the upcoming iteration
  auto ordered_contexts = orderBackendContext(backend_contexts);

  // Generate kernels
  for (auto &&pair : ordered_contexts)
  {
    auto codes = pair.second->genKernels();
    for (auto &&pair : codes)
    {
      auto &op_ind = pair.first;
      auto &fn_seq = pair.second;
      auto &op = lowered_graph->graph().operations().at(op_ind);
      auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
      if (options->he_profiling_mode)
        fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
      builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)});
    }
  }

  auto code_map = builder.releaseCodeMap();

  exec::ExecutorBase *exec = nullptr;
  if (parallel)
  {
    exec = new exec::ParallelExecutor{std::move(lowered_graph), std::move(backend_contexts),
                                      tensor_regs, std::move(code_map), tracing_ctx};
  }
  else
  {
    auto dataflow_exec =
      new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs,
                                 std::move(code_map), tracing_ctx};
    if (options->he_profiling_mode)
    {
      std::vector<const backend::Backend *> backends;
      // NOTE backend_contexts has just been moved into the executor, so enumerate the
      //      still-valid ordered_contexts instead
      for (const auto &pair : ordered_contexts)
      {
        backends.push_back(pair.first);
      }
      auto et = std::make_shared<exec::ExecTime>(backends);
      std::unique_ptr<exec::IExecutionObserver> obs =
        std::make_unique<exec::ProfileObserver>(et, dataflow_exec->graph());
      dataflow_exec->addObserver(std::move(obs));
    }
    exec = dataflow_exec;
  }

  if (!options->trace_filepath.empty())
  {
    std::unique_ptr<exec::IExecutionObserver> ctp =
      std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
    exec->addObserver(std::move(ctp));
  }

  return exec;
}
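
// NOTE With he_profiling_mode on, the dataflow executor above also gets a ProfileObserver that
// records per-kernel execution times into ExecTime; presumably the heterogeneous scheduler
// consumes these measurements on later runs.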

#ifdef ONERT_TRAIN
exec::IExecutor *
ExecutorFactory::create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
                        const std::shared_ptr<exec::IExecutors> &executors,
                        const ExecutorFactoryArgs &args,
                        const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer)
{
  assert(args.options != nullptr);

  if (args.options->executor != "Linear")
    throw std::runtime_error("ExecutorFactory: TrainableExecutor supports only 'Linear' now");

  return createTrainableExecutor(std::move(lowered_graph), executors, args, optimizer);
}

void ExecutorFactory::prepareMigrantTensors(
  compiler::ILoweredGraph &lowered_graph,
  const backend::train::TrainableBackendContexts &backend_contexts)
{
  train::TensorRegistries tensor_regs{backend_contexts, true};

  lowered_graph.graph().operations().iterate(
    [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
      auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
      auto &backend_ctx = backend_contexts.at(lower_info->backend());
      for (auto &&ind :
           (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
      {
        // If an operation's input/output tensor does not have its own tensor object,
        // it must be using a migrant tensor, so find the tensor in the other tensor registries
        // and register it to the current tensor registry if it is portable
        if (!backend_ctx->tensor_registry()->getITensor(ind))
        {
          auto tensor = tensor_regs.getITensor(ind);
          assert(tensor); // The tensor must have been registered
          auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor);
          if (ptensor)
            backend_ctx->tensor_registry()->setMigrantTensor(ind, ptensor);
        }
      }
    });
}

exec::IExecutor *ExecutorFactory::createTrainableExecutor(
  std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
  const std::shared_ptr<exec::IExecutors> &, const ExecutorFactoryArgs &args,
  const std::shared_ptr<exec::train::optimizer::Optimizer> &optimizer)
{
  const auto options = args.options;
  const auto tracing_ctx = args.tracing_ctx;
  auto custom_kernel_builder = args.custom_kernel_builder;

  auto &graph = lowered_graph->graph();

  lowered_graph->trainable_graph().operations().iterate([](const onert::ir::OperationIndex &,
                                                           const onert::ir::IOperation &op) {
    try
    {
      UNUSED_RELEASE(dynamic_cast<const ir::train::ITrainableOperation &>(op));
    }
    catch (const std::bad_cast &)
    {
      throw std::runtime_error("ExecutorFactory: " + op.name() +
                               " is not a trainable operation yet");
    }
  });
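
  // NOTE The reference form of dynamic_cast throws std::bad_cast on failure (the pointer form
  // returns nullptr instead), so the try/catch above is what rejects non-trainable operations.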

  // TODO Create context only once instead of replacing
  backend::train::TrainableBackendContexts tbackend_contexts;
  backend::BackendContexts base_backend_contexts =
    createBackendContexts(*lowered_graph, true, custom_kernel_builder);

  // Replace each BackendContext with a TrainableBackendContext
  for (auto &&pair : base_backend_contexts)
  {
    auto ctx = pair.second.get();
    const auto &data = ctx->data();

    // Create partial and trainable graphs
    auto tgraph = std::make_unique<ir::train::TrainableGraph>(*data.graph);
    data.graph->operations().iterate(
      [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &) {
        const auto &orig_tgraph = lowered_graph->trainable_graph();
        const auto &trainable_op = orig_tgraph.operation(op_index);
        auto gen_index = tgraph->replaceOperation(op_index, trainable_op.clone());
        UNUSED_RELEASE(gen_index);
        assert(gen_index == op_index);
      });
    data.graph->operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
      const auto &orig_tgraph = lowered_graph->trainable_graph();
      if (orig_tgraph.derivatives().exist(index))
      {
        const auto &deriv = orig_tgraph.derivatives().at(index);
        auto new_deriv = std::make_unique<ir::Operand>(deriv);
        auto gen_index = tgraph->addDerivative(index, std::move(new_deriv));
        UNUSED_RELEASE(gen_index);
        assert(gen_index == index);
      }
    });

    // Remove outputs of whole graph from external_operands
    auto external_operands = data.external_operands;
    for (const auto &index : lowered_graph->trainable_graph().getOutputs())
    {
      if (external_operands.contains(index))
        external_operands.remove(index);
    }

    // Set trainable context data
    backend::train::TrainableContextData tdata;
    tdata.tgraph = std::move(tgraph);
    tdata.op_order = std::move(data.op_order);
    tdata.external_operands = std::move(external_operands);
    tdata.operand_layouts = std::move(data.operand_layouts);
    tdata.custom_kernel_builder = std::move(data.custom_kernel_builder);
    tdata.is_linear_executor = data.is_linear_executor;
    tdata.optimizer = optimizer;

    // TODO Remove dynamic_cast
    // NOTE The pointer form of dynamic_cast returns nullptr on failure instead of throwing,
    //      so check the result explicitly
    const auto backend = pair.first;
    const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend);
    if (tbackend == nullptr)
    {
      throw std::runtime_error("ExecutorFactory: Invalid backend - TrainableExecutor does not "
                               "support non-trainable backends");
    }
    tbackend_contexts.emplace(backend, tbackend->newContext(std::move(tdata)));
  }
  base_backend_contexts.clear();

  train::TensorRegistries tensor_regs{tbackend_contexts, true};

  initializeSubgraphIOTensors(
    *lowered_graph, tbackend_contexts,
    (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
      ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);

  auto order = Linear::linearize(*lowered_graph);
  Linear::dump(*lowered_graph, order);

  for (auto &&pair : tbackend_contexts)
  {
    pair.second->genTensors();
  }

  for (auto &&pair : tbackend_contexts)
  {
    auto tctx = pair.second.get();
    tctx->genTrainingTensors();
  }

  prepareMigrantTensors(*lowered_graph, tbackend_contexts);

  // Give some runtime objects to builtin KernelGenerator
  for (auto &&pair : tbackend_contexts)
  {
    auto builtin_context =
      dynamic_cast<backend::builtin::train::BackendContext *>(pair.second.get());
    if (builtin_context != nullptr)
    {
      auto builtin_kernel_gen = builtin_context->kernel_gen;
      builtin_kernel_gen->setTensorRegistries(tensor_regs);
      builtin_kernel_gen->setWholeGraphOutputs(lowered_graph->trainable_graph().getOutputs());
    }
  }

  // Adjust the order of backends for the upcoming iteration
  auto ordered_contexts =
    onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts);

  // TODO Remove this simulation
  // Simulate the execution for deallocation of tensors
  std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map;
  {
    ir::OperandIndexMap<uint32_t> uses_map;
    ir::OperandIndexSequence constants;

    auto model_io =
      (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;

    graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
      uses_map[ind] = obj.getUses().size();

      if (obj.isConstant())
        constants.append(ind);
    });

    // A trick to consider constants as an exception
    for (const auto &ind : constants)
    {
      uses_map[ind]++;
    }

    for (const auto &op_ind : order)
    {
      const auto &op = graph.operations().at(op_ind);
      auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
      auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;

      for (const auto &ind : op_inputs)
      {
        const auto &operand = graph.operands().at(ind);
        assert(uses_map.find(ind) != uses_map.end());
        assert(uses_map[ind] > 0);
        uses_map[ind]--;
        if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind))
        {
          dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind));
        }
      }
    }

    // Dispose and validate
    for (const auto &ind : constants)
    {
      --uses_map[ind];
    }

    assert(
      std::all_of(uses_map.begin(), uses_map.end(),
                  [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
  }

  // Check derivative tensors
  {
    // TODO Support multiple subgraphs
    // Check that the derivative tensors corresponding to inputs of the model are nullptr
    // NOTE The derivative tensors corresponding to inputs of the model are inputs of
    //      PermuteLayers, and they are nullptr because they are meaningless.
    assert(std::all_of(lowered_graph->trainable_graph().getInputs().begin(),
                       lowered_graph->trainable_graph().getInputs().end(),
                       [&](const auto &input_idx) {
                         return tensor_regs.getDerivativeITensor(input_idx) == nullptr;
                       }));

    // Check that the derivative tensors corresponding to outputs of the model exist
    assert(std::all_of(lowered_graph->trainable_graph().getOutputs().begin(),
                       lowered_graph->trainable_graph().getOutputs().end(),
                       [&](const auto &output_idx) {
                         return tensor_regs.getDerivativeITensor(output_idx) != nullptr;
                       }));
  }

  train::TrainableCodeMap code_map;
  // Generate kernels
  for (auto &&pair : ordered_contexts)
  {
    auto codes = pair.second->genKernels();
    for (auto &&pair : codes)
    {
      auto &op_ind = pair.first;
      auto &tn_seq = pair.second;
      auto &op = lowered_graph->trainable_graph().operation(op_ind);
      auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);

      assert(code_map.find(op_ind) == code_map.end());
      code_map.insert(
        {op_ind, train::TrainableCodeAndInfo{op_ind, &op, lower_info, std::move(tn_seq)}});
    }
  }

  if (order.size() != code_map.size())
  {
    throw std::runtime_error("ExecutorFactory: Some kernels are not generated");
  }

  auto exec = new exec::train::TrainableExecutor{std::move(lowered_graph),
                                                 std::move(tbackend_contexts),
                                                 tensor_regs,
                                                 std::move(code_map),
                                                 order,
                                                 tracing_ctx};

  if (!options->trace_filepath.empty())
  {
    std::unique_ptr<exec::IExecutionObserver> ctp =
      std::make_unique<exec::TracingObserver>(options->trace_filepath, exec->graph(), tracing_ctx);
    exec->addObserver(std::move(ctp));
  }
  // TODO Support MINMAX_H5DUMPER

  return exec;
}
#endif // ONERT_TRAIN

} // namespace compiler
} // namespace onert