Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / Fp32ToFp16Converter.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__
18 #define __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__
19
#include "compiler/LoweredGraph.h"

#include <string>
#include <unordered_map>
#include <unordered_set>
21
22 namespace onert
23 {
24
25 namespace compiler
26 {
27
28 class Fp32ToFp16Converter
29 {
30 public:
31   Fp32ToFp16Converter(compiler::LoweredGraph &lowered_graph);
32
33 public:
34   void run();
35
36 private:
37   using OpSeqIndexList = std::unordered_set<ir::OpSequenceIndex>;
38   using InputToOpSeqs = std::unordered_map<ir::OperandIndex, OpSeqIndexList>;
39   using OpSeqIndexToOpSeqIndexList = std::unordered_map<ir::OpSequenceIndex, OpSeqIndexList>;
40
41 private:
42   void appendOpSequences();
43   void optimize();
44   void convertOperands();
45   void convertDatas();
46   void printOpSequences(const std::string &pre_msg = std::string(),
47                         const std::string &post_msg = std::string());
48
49   bool checkOperandType(const ir::OperandIndex &op_ind) const;
50   bool checkOperandsOfOpSequence(const ir::OpSequence &op_seq) const;
51
52   void appendNewOpSeqForConvertFp32ToFp16(const ir::OpSequenceIndex &op_seq_ind,
53                                           ir::OpSequence &op_seq);
54   void appendNewOpSeqForConvertFp16ToFp32(const ir::OpSequenceIndex &op_seq_ind,
55                                           ir::OpSequence &op_seq);
56
57   ir::OperandIndex newCopiedOperand(const ir::OperandIndex &op_ind);
58   ir::OperationIndex newOperationConvertFp32ToFp16(const ir::OperandIndex &op_seq_input_ind,
59                                                    const ir::OperandIndex &new_op_ind);
60   ir::OperationIndex newOperationConvertFp16ToFp32(const ir::OperandIndex &op_seq_output_ind,
61                                                    const ir::OperandIndex &new_op_ind);
62   ir::OpSequenceIndex newOpSequence(const ir::OpSequenceIndex &op_seq_ind,
63                                     const ir::OperationIndex &node_index);
64
65   void setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
66                               const ir::OperandIndex &new_op_ind);
67   void setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
68                                  const ir::OpSequenceIndex &new_op_seq_ind);
69
70   void manipulateInput(const ir::OpSequenceIndex &op_seq_ind,
71                        const ir::OperandIndex &op_seq_input_ind,
72                        const ir::OperandIndex &new_op_ind);
73   void manipulateOutput(const ir::OpSequenceIndex &op_seq_ind,
74                         const ir::OperandIndex &op_seq_output_ind,
75                         const ir::OperandIndex &new_op_ind);
76
77   void removeContiguousConvertOpSequences();
78   InputToOpSeqs prepareInputToOpSeqs() const;
79   OpSeqIndexToOpSeqIndexList
80   findOpSequencesContiguous(const InputToOpSeqs &intput_to_op_seqs) const;
81   OpSeqIndexList getListOpSequences(const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete) const;
82   ir::OperandIndexSequence
83   findOperationsToDelete(const OpSeqIndexList &list_to_delete_op_seqs) const;
84   void manipulateContiguousOpSequences(const InputToOpSeqs &input_to_op_seqs,
85                                        const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete);
86   void deleteContiguousOpSequences(const OpSeqIndexList &list_to_delete_op_seqs,
87                                    const ir::OperandIndexSequence &list_to_delete_ops);
88
89   void convertOperandsOfOpSequence(ir::OpSequence &op_seq);
90
91 private:
92   compiler::LoweredGraph &_lowered_graph;
93   OpSeqIndexList _list_fp32_to_fp16;
94   OpSeqIndexList _list_fp16_to_fp32;
95 };
96
97 } // namespace compiler
98
99 } // namespace onert
100
101 #endif // __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__