Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / train / ops / LossLayer.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "LossLayer.h"
18 #include "OperationUtils.h"
19
20 #include <cker/train/operation/Loss.h>
21
22 namespace onert
23 {
24 namespace backend
25 {
26 namespace train
27 {
28 namespace ops
29 {
30
31 LossLayer::LossLayer()
32   : _y_pred(nullptr), _y_true(nullptr), _output(nullptr), _deriv_y_pred(nullptr),
33     _loss_type(LossType::kMSE)
34 {
35   // DO NOTHING
36 }
37
38 void LossLayer::configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
39                           IPortableTensor *output, IPortableTensor *deriv_y_pred,
40                           LossType loss_type)
41 {
42   assert(y_pred != nullptr);
43   assert(y_true != nullptr);
44   assert(output != nullptr);
45   assert(deriv_y_pred != nullptr);
46   switch (loss_type)
47   {
48     case LossType::kMSE:
49       break;
50     default:
51       throw std::runtime_error("LossLayer: unsupported loss type");
52   }
53
54   _y_pred = y_pred;
55   _y_true = y_true;
56   _output = output;
57   _deriv_y_pred = deriv_y_pred;
58   _loss_type = loss_type;
59 }
60
61 void LossLayer::forward(bool)
62 {
63   // TODO Implement this
64   switch (_loss_type)
65   {
66     case LossType::kMSE:
67       if (_y_pred->data_type() == OperandType::FLOAT32)
68       {
69         nnfw::cker::train::MSE(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
70                                getBuffer<float>(_y_true), getShape(_output),
71                                getBuffer<float>(_output));
72       }
73       break;
74     default:
75       throw std::runtime_error("LossLayer: unsupported loss type");
76   }
77 }
78
79 void LossLayer::backward()
80 {
81   switch (_loss_type)
82   {
83     case LossType::kMSE:
84       if (_y_pred->data_type() == OperandType::FLOAT32)
85       {
86         nnfw::cker::train::MSEGrad(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
87                                    getBuffer<float>(_y_true), getShape(_deriv_y_pred),
88                                    getBuffer<float>(_deriv_y_pred));
89       }
90       break;
91     default:
92       throw std::runtime_error("LossLayer: unsupported loss type");
93   }
94 }
95
96 } // namespace ops
97 } // namespace train
98 } // namespace backend
99 } // namespace onert