Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / train / BackendContext.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "BackendContext.h"
18
19 #include "TensorBuilder.h"
20 #include "KernelGenerator.h"
21
22 #include <backend/basic/train/TrainableBackendContextHelpers.h>
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace train
29 {
30
31 backend::ITensorRegistry *BackendContext::genTensors()
32 {
33   return basic::train::genTensors(*this, _tensor_builder);
34 }
35
36 backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
37 {
38   const ir::train::TrainableGraph &tgraph = *trainable_graph();
39   auto tensor_builder = _tensor_builder;
40
41   tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
42     if (external_operands().contains(ind))
43       return;
44     // NOTE Assuming there is no layout changes (Always assume NHWC or UNKNOWN)
45     assert(tgraph.layout() != ir::Layout::NCHW);
46
47     // TODO Different shape of deriv tensor
48     ir::OperandInfo backend_info{obj.shape(), obj.typeInfo(), obj.info().memAllocType(),
49                                  obj.isConstant()};
50     tensor_builder->registerBackwardTensorInfo(ind, backend_info, ir::Layout::NHWC);
51   });
52
53   // TODO Plan tensor builds to reduce peak memory usage
54   tgraph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
55     if (tensor_builder->isRegisteredBackward(ind))
56       tensor_builder->notifyBackwardFirstUse(ind);
57   });
58
59   tensor_builder->allocateBackward();
60
61   return _tensor_registry.get();
62 }
63
64 FunctionMap BackendContext::genKernels()
65 {
66   train::FunctionMap ret;
67
68   for (const auto &op_ind : _tdata->op_order)
69   {
70     auto fn_seq = kernel_gen->generate(op_ind);
71     ret.emplace_back(op_ind, std::move(fn_seq));
72   }
73
74   // Initialize TrainableTensors
75   trainable_graph()->operands().iterate(
76     [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
77       if (external_operands().contains(ind) || !operand.isConstant())
78         return;
79
80       auto tensor = tensor_registry()->getNativeITensor(ind);
81       assert(tensor != nullptr);
82
83       VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
84
85       auto data = operand.shareData();
86       assert(data && data->base());
87       auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
88
89       if (trainable_tensor == nullptr)
90         throw std::runtime_error{"This tensor is not trainable tensor"};
91
92       trainable_tensor->fillBuffer(data);
93     });
94
95   // NOTE For memory optimization, we want to free some operand data
96   const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
97     .operands()
98     .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
99
100   // TODO Enable
101   // for (auto &&it : ret)
102   // {
103   //   auto &fn_seq = it.second;
104   //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
105   // }
106
107   return ret;
108 }
109
110 } // namespace train
111 } // namespace backend
112 } // namespace onert