2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "LossInsertionPass.h"
19 #include "ir/train/TrainableGraph.h"
20 #include "ir/train/operation/Loss.h"
31 void LossInsertionPass::run()
33 const auto &loss_info = _training_info->lossInfo();
35 ir::operation::Loss::Param param;
36 param.op_type = loss_info.type;
38 if (_trainable_graph.getOutputs().size() != 1)
40 throw std::runtime_error("LossInsertionPass: Not supported multiple outputs");
43 // TODO Consider SparseCategoricalCrossentropy y_true shape
44 // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
46 // TODO Implement Loop [0, getOutputs().size())
47 // index: a loop index
49 const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
50 const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
51 const auto &shape = y_pred.shape();
52 const auto &type_info = y_pred.typeInfo();
53 auto y_true_index = _trainable_graph.addOperand(shape, type_info);
54 ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
56 // TODO Consider Reduction
57 // Some types of Reduction have the same shape y_true and output.
59 const ir::TypeInfo float_op(ir::DataType::FLOAT32);
60 auto output_index = _trainable_graph.addOperand(ir::Shape{1}, float_op);
61 ir::OperandIndexSequence outputs{output_index};
63 auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs, param);
64 auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op);
66 _trainable_graph.addOperation(std::move(trainable_loss_op));
68 _trainable_graph.addInput(y_true_index);
70 // TODO Add loss as many as output size
71 _trainable_graph.addLoss(output_index, ir::IOIndex{index});
76 } // namespace compiler