Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / backend / builtin / train / BackendContext.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "BackendContext.h"
18
19 #include "backend/basic/train/TrainableBackendContextHelpers.h"
20 #include "exec/FunctionSequence.h"
21
22 namespace onert
23 {
24 namespace backend
25 {
26 namespace builtin
27 {
28 namespace train
29 {
30
31 backend::ITensorRegistry *BackendContext::genTensors()
32 {
33   // For now, there is no need to generate tensors for forwarding.
34   // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
35   // `Permute`: Tensor generation is not required.
36   // `IF`, `WHILE`: Not supported yet
37   return tensor_registry().get();
38 }
39
40 backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
41 {
42   // For now, there is no need to generate tensors for backwarding.
43   return tensor_registry().get();
44 }
45
46 backend::train::FunctionMap BackendContext::genKernels()
47 {
48   backend::train::FunctionMap ret;
49
50   for (auto &&op_ind : _tdata->op_order)
51   {
52     auto tn_seq = kernel_gen->generate(op_ind);
53     ret.emplace_back(op_ind, std::move(tn_seq));
54   }
55
56   trainable_graph()->operands().iterate(
57     [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
58       if (!external_operands().contains(ind) && operand.isConstant())
59       {
60         throw std::runtime_error(
61           "BackendContext: builtin backend does not support updatable weights yet");
62       }
63     });
64
65   // TODO Enable prepare()
66   // for (auto &&it : ret)
67   // {
68   //   auto &fn_seq = it.second;
69   //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
70   // }
71
72   return ret;
73 }
74
75 } // namespace train
76 } // namespace builtin
77 } // namespace backend
78 } // namespace onert