2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
18 #define __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
20 #include "cker/Shape.h"
21 #include "cker/eigen/Utils.h"
31 inline void MSE(const Shape &y_pred_shape, const T *y_pred_data, const Shape &y_true_shape,
32 const T *y_true_data, const Shape &output_shape, T *output_data)
34 // TODO Consider Reduction
35 if (output_shape != Shape{1})
36 throw std::runtime_error("cker::MSE: output_shape != Shape{1}");
37 if (y_pred_shape != y_true_shape)
38 throw std::runtime_error("cker::MSE: y_pred_shape != y_true_shape");
40 const auto y_pred = MapAsMatrixWithLastDimAsRows(y_pred_data, y_pred_shape);
41 const auto y_true = MapAsMatrixWithLastDimAsRows(y_true_data, y_true_shape);
43 double squared_sum = 0.0f;
44 for (size_t c = 0; c < (size_t)y_pred.cols(); ++c)
46 for (size_t r = 0; r < (size_t)y_pred.rows(); ++r)
48 double error = y_pred.coeff(r, c) - y_true.coeff(r, c);
49 squared_sum += (error * error);
53 auto size = y_pred.cols() * y_pred.rows();
54 output_data[0] = (T)(squared_sum / size);
58 inline void MSEGrad(const Shape &y_pred_shape, const T *y_pred_data, const Shape &y_true_shape,
59 const T *y_true_data, const Shape &grad_shape, T *grad_data)
61 if (y_pred_shape != y_true_shape)
62 throw std::runtime_error("cker::MSEGrad: y_pred_shape != y_true_shape");
63 if (y_pred_shape != grad_shape)
64 throw std::runtime_error("cker::MSEGrad: y_pred_shape != grad_shape");
66 const int size = grad_shape.FlatSize();
67 for (int i = 0; i < size; ++i)
69 grad_data[i] = static_cast<T>(-2 * (y_true_data[i] - y_pred_data[i]) / size);
77 #endif // __NNFW_CKER_TRAIN_OPERATION_LOSS_H__