2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "LossLayer.h"
18 #include "OperationUtils.h"
20 #include <cker/train/operation/Loss.h>
// Default constructor.
// Initializes every tensor pointer member (_y_pred, _y_true, _output, _deriv_y_pred)
// to nullptr and selects LossType::kMSE as the initial loss type.
// NOTE(review): the constructor body is not visible in this view - presumably empty; confirm.
31 LossLayer::LossLayer()
32 : _y_pred(nullptr), _y_true(nullptr), _output(nullptr), _deriv_y_pred(nullptr),
33 _loss_type(LossType::kMSE)
// Binds the tensors and the loss type this layer works with.
//
// y_pred, y_true, output and deriv_y_pred must all be non-null. This is enforced with
// assert() only, so the checks are compiled out when NDEBUG is defined.
// Throws std::runtime_error("LossLayer: unsupported loss type") on the rejection path below.
// NOTE(review): the rest of the parameter list (the declaration of loss_type), the condition
// guarding the throw, and the assignments of the remaining members are not visible in this
// view - presumably loss_type is validated and _y_pred/_y_true/_output are stored; confirm.
38 void LossLayer::configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
39 IPortableTensor *output, IPortableTensor *deriv_y_pred,
// Precondition checks: every tensor argument must be provided.
42 assert(y_pred != nullptr);
43 assert(y_true != nullptr);
44 assert(output != nullptr);
45 assert(deriv_y_pred != nullptr);
// Rejection path for a loss type the layer does not support (guard not visible here).
51 throw std::runtime_error("LossLayer: unsupported loss type");
// Remember the gradient destination and the selected loss for forward()/backward().
57 _deriv_y_pred = deriv_y_pred;
58 _loss_type = loss_type;
// Forward pass of the loss layer.
//
// The bool parameter is unnamed and is not used by the visible code.
// When _y_pred has data type FLOAT32, the shapes and float buffers of _y_pred, _y_true and
// _output are handed to nnfw::cker::train::MSE (presumably writing the mean-squared-error
// loss into _output - confirm against the cker kernel).
// Throws std::runtime_error("LossLayer: unsupported loss type") on the rejection path.
// NOTE(review): the dispatch on the loss type is not visible in this view.
// NOTE(review): no visible branch handles a non-FLOAT32 _y_pred; it looks like such inputs
// are silently skipped - confirm whether that is intended or should raise an error.
61 void LossLayer::forward(bool)
// NOTE(review): this TODO looks stale, since an MSE path exists below - confirm what remains.
63 // TODO Implement this
// Only float tensors are handled by the visible code.
67 if (_y_pred->data_type() == OperandType::FLOAT32)
69 nnfw::cker::train::MSE(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
70 getBuffer<float>(_y_true), getShape(_output),
71 getBuffer<float>(_output));
// Rejection path for a loss type the layer does not support (guard not visible here).
75 throw std::runtime_error("LossLayer: unsupported loss type");
// Backward pass of the loss layer.
//
// When _y_pred has data type FLOAT32, the shapes and float buffers of _y_pred, _y_true and
// _deriv_y_pred are handed to nnfw::cker::train::MSEGrad (presumably writing the gradient of
// the loss with respect to the prediction into _deriv_y_pred - confirm against the cker kernel).
// Throws std::runtime_error("LossLayer: unsupported loss type") on the rejection path.
// NOTE(review): the dispatch on the loss type is not visible in this view.
// NOTE(review): no visible branch handles a non-FLOAT32 _y_pred; it looks like such inputs
// leave _deriv_y_pred untouched - confirm whether that is intended or should raise an error.
79 void LossLayer::backward()
// Only float tensors are handled by the visible code.
84 if (_y_pred->data_type() == OperandType::FLOAT32)
86 nnfw::cker::train::MSEGrad(getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true),
87 getBuffer<float>(_y_true), getShape(_deriv_y_pred),
88 getBuffer<float>(_deriv_y_pred));
// Rejection path for a loss type the layer does not support (guard not visible here).
92 throw std::runtime_error("LossLayer: unsupported loss type");
98 } // namespace backend