Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / train / pass / LossInsertionPass.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "LossInsertionPass.h"
18
19 #include "ir/train/TrainableGraph.h"
20 #include "ir/train/operation/Loss.h"
21
22 namespace onert
23 {
24 namespace compiler
25 {
26 namespace train
27 {
28 namespace pass
29 {
30
31 void LossInsertionPass::run()
32 {
33   const auto &loss_info = _training_info->lossInfo();
34
35   ir::operation::Loss::Param param;
36   param.op_type = loss_info.type;
37
38   if (_trainable_graph.getOutputs().size() != 1)
39   {
40     throw std::runtime_error("LossInsertionPass: Not supported multiple outputs");
41   }
42
43   // TODO Consider SparseCategoricalCrossentropy y_true shape
44   //      SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
45
46   // TODO Implement Loop [0, getOutputs().size())
47   //      index: a loop index
48   const auto index = 0;
49   const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
50   const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
51   const auto &shape = y_pred.shape();
52   const auto &type_info = y_pred.typeInfo();
53   auto y_true_index = _trainable_graph.addOperand(shape, type_info);
54   ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
55
56   // TODO Consider Reduction
57   //      Some types of Reduction have the same shape y_true and output.
58
59   const ir::TypeInfo float_op(ir::DataType::FLOAT32);
60   auto output_index = _trainable_graph.addOperand(ir::Shape{1}, float_op);
61   ir::OperandIndexSequence outputs{output_index};
62
63   auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs, param);
64   auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op);
65
66   _trainable_graph.addOperation(std::move(trainable_loss_op));
67
68   _trainable_graph.addInput(y_true_index);
69
70   // TODO Add loss as many as output size
71   _trainable_graph.addLoss(output_index, ir::IOIndex{index});
72 }
73
74 } // namespace pass
75 } // namespace train
76 } // namespace compiler
77 } // namespace onert