Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / train / TrainableFnSequence.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "exec/train/TrainableFnSequence.h"
18
19 namespace onert
20 {
21 namespace exec
22 {
23 namespace train
24 {
25
26 void TrainableFnSequence::forward(bool training)
27 {
28   for (const auto &function : _functions)
29   {
30     function->forward(training);
31   }
32 }
33
34 void TrainableFnSequence::backward(uint32_t training_step)
35 {
36   for (auto it = _functions.rbegin(); it != _functions.rend(); ++it)
37   {
38     (*it)->backward();
39   }
40
41   for (const auto &applier : _appliers)
42   {
43     applier->applyGradient(training_step);
44   }
45 }
46
47 void TrainableFnSequence::append(std::unique_ptr<ITrainableFunction> &&function)
48 {
49   _functions.push_back(std::move(function));
50 }
51
52 void TrainableFnSequence::append(std::unique_ptr<IGradientApplier> &&applier)
53 {
54   _appliers.push_back(std::move(applier));
55 }
56
57 void TrainableFnSequence::iterate(const std::function<void(ITrainableFunction &)> &fn)
58 {
59   for (const auto &func : _functions)
60   {
61     fn(*func);
62   }
63 }
64
65 } // namespace train
66 } // namespace exec
67 } // namespace onert