Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / Fp32ToFp16Converter.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Fp32ToFp16Converter.h"
18 #include "ir/operation/ConvertFp32ToFp16.h"
19 #include "ir/operation/ConvertFp16ToFp32.h"
20 #include "util/logging.h"
21
22 #include <Half.h>
23
24 using float16 = Half;
25
26 namespace
27 {
28
29 const std::string kAclClBackendConfigId = "acl_cl";
30
31 void copyDataFromFp32ToFp16(const float *from, float16 *into, size_t num_elements)
32 {
33   for (size_t i = 0; i < num_elements; ++i)
34   {
35     into[i] = static_cast<float16>(from[i]);
36   }
37 }
38
39 } // namespace
40
41 namespace onert
42 {
43
44 namespace compiler
45 {
46
47 Fp32ToFp16Converter::Fp32ToFp16Converter(compiler::LoweredGraph &lowered_graph)
48     : _lowered_graph{lowered_graph}
49 {
50   VERBOSE(Fp32ToFp16Converter) << "Fp16 Enable on" << std::endl;
51 }
52
53 // For example, two OpSequences are there and each OpSequence has an Operation
54 //
55 //   OP#0      // model input
56 //    |
57 // [OPERATION] // OpSeq#0
58 //    |
59 //   OP#1
60 //    |
61 // [OPERATION] // OpSeq#1
62 //    |
63 //   OP#2      // model output
64 //
65 //
66 // AFTER `appendOpSequences()`,
67 // note that model_input and model_output are not changed.
68 //
69 //   OP#0
70 //    |
71 // [FP32TO16]  // OpSeq#2
72 //    |
73 //   OP#3
74 //    |
75 // [OPERATION] // OpSeq#0
76 //    |
77 //   OP#4
78 //    |
79 // [FP16TO32]  // OpSeq#3
80 //    |
81 //   OP#1
82 //    |
83 // [FP32TO16]  // OpSeq#4
84 //    |
85 //   OP#5
86 //    |
87 // [OPERATION] // OpSeq#1
88 //    |
89 //   OP#6
90 //    |
91 // [FP16TO32]  // OpSeq#5
92 //    |
93 //   OP#2
94 //
95 //
96 // AFTER `optimize()`,
97 //
98 //   OP#0
99 //    |
100 // [FP32TO16]  // OpSeq#2
101 //    |
102 //   OP#3
103 //    |
104 // [OPERATION] // OpSeq#0
105 //    |
106 //   OP#4
107 //    |
108 // [OPERATION] // OpSeq#1
109 //    |
110 //   OP#6
111 //    |
112 // [FP16TO32]  // OpSeq#5
113 //    |
114 //   OP#2
115 //
116 //
117 // AFTER `convertOperands()`,
118 //
119 //   OP#0      // model_input, not fp16
120 //    |
121 // [FP32TO16]  // OpSeq#2
122 //    |
123 //   OP#3      // fp16
124 //    |
125 // [OPERATION] // OpSeq#0
126 //    |
127 //   OP#4      // fp16
128 //    |
129 // [OPERATION] // OpSeq#1
130 //    |
131 //   OP#6      // fp16
132 //    |
133 // [FP16TO32]  // OpSeq#5
134 //    |
135 //   OP#2      // model_output, notfp16
136 //
137 //
138 // AFTER `convertDatas()`,
139 //
140 //   OP#0      // model_input, not fp16
141 //    |
142 // [FP32TO16]  // OpSeq#2
143 //    |
144 //   OP#3      // fp16
145 //    |
146 // [OPERATION] // OpSeq#0, constants are fp16
147 //    |
148 //   OP#4      // fp16
149 //    |
150 // [OPERATION] // OpSeq#1, constants are fp16
151 //    |
152 //   OP#6      // fp16
153 //    |
154 // [FP16TO32]  // OpSeq#5
155 //    |
156 //   OP#2      // model_output, notfp16
157 //
158 void Fp32ToFp16Converter::run()
159 {
160   // Append new OpSequence which includes ConvertFp32ToFp16
161   //   and append new OpSequence which includes ConvertFp16ToFp32
162   appendOpSequences();
163
164   // Remove unnecessary converting operations
165   optimize();
166
167   // Convert operands' data types from fp32 to fp16
168   convertOperands();
169
170   // Convert Datas
171   convertDatas();
172
173   // Print the result
174   printOpSequences("FINAL OpSequences");
175 }
176
177 void Fp32ToFp16Converter::appendOpSequences()
178 {
179   _lowered_graph.op_seqs().iterate(
180       [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
181         const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
182         assert(lower_info != nullptr);
183
184         // For now, the only acl_cl supports fully fp16 type
185         // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat
186         // operations.
187         //      To do this, we could check the support by `operation by operation`. After that, we
188         //      would partition an op_seq if it contains unsupported operations.
189         if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
190           return;
191
192         // OpSeq's input set should be included in the first operation's input set or
193         // OpSeq's output set should be included in the last operation's output set
194         assert(checkOperandsOfOpSequence(op_seq));
195
196         // Append converting OpSequence for fp16 but all operands' types are not fp16 still.
197         appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq);
198         appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq);
199       });
200 }
201
202 //
203 // BEFORE
204 //
205 //   OP#0      // model input
206 //    |
207 // [OPERATION] // OpSeq#0
208 //    |
209 //   OP#1      // model output
210 //
211 //
212 // AFTER
213 //
214 //   OP#0      // model input
215 //    |
216 // [FP32TO16]  // OpSeq#1
217 //    |
218 //   OP#2
219 //    |
220 // [OPERATION] // OpSeq#0
221 //    |
222 //   OP#1      // model output
223 //
224 void Fp32ToFp16Converter::appendNewOpSeqForConvertFp32ToFp16(const ir::OpSequenceIndex &op_seq_ind,
225                                                              ir::OpSequence &op_seq)
226 {
227   // OpSeq's input set is included in the first operation's input set
228   const ir::OperandIndexSequence op_seq_inputs = op_seq.getInputs(); // copied
229
230   // NOTE Please do not change sequence of op_seq_inputs. It can change the sequence of inputs of
231   // Subgraph
232   for (const auto &op_seq_input_ind :
233        op_seq_inputs | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
234   {
235     if (checkOperandType(op_seq_input_ind) == false)
236       continue;
237
238     // new operand w/ datatype fp32
239     const auto new_op_ind = newCopiedOperand(op_seq_input_ind);
240
241     // set new lower_info for operand
242     setNewOperandLowerInfo(op_seq_ind, new_op_ind);
243
244     // manipulate input of operation and op_seq
245     // - replace the first operation's input to new operand
246     //   with old operand's removeUse and new operand's appendUse()
247     manipulateInput(op_seq_ind, op_seq_input_ind, new_op_ind);
248
249     // new op
250     const auto new_node_ind = newOperationConvertFp32ToFp16(op_seq_input_ind, new_op_ind);
251
252     // new op_seq
253     const auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
254
255     // set new lower_info for op_seq
256     setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind);
257
258     _list_fp32_to_fp16.insert(new_op_seq_ind);
259
260     VERBOSE(Fp32ToFp16Converter) << "NEW   |Fp32To16]"
261                                  << ir::getStrFromOpSeq(_lowered_graph.op_seqs().at(new_op_seq_ind),
262                                                         _lowered_graph.graph().operations())
263                                  << std::endl;
264   }
265 }
266
267 //
268 // BEFORE
269 //
270 //   OP#0      // model input
271 //    |
272 // [FP32TO16]  // OpSeq#1
273 //    |
274 //   OP#2
275 //    |
276 // [OPERATION] // OpSeq#0
277 //    |
278 //   OP#1      // model output
279 //
280 //
281 // AFTER
282 //
283 //   OP#0      // model input
284 //    |
285 // [FP32TO16]  // OpSeq#1
286 //    |
287 //   OP#2
288 //    |
289 // [OPERATION] // OpSeq#0
290 //    |
291 //   OP#3
292 //    |
293 // [FP16TO32]  // OpSeq#2
294 //    |
295 //   OP#1      // model output
296 //
297 void Fp32ToFp16Converter::appendNewOpSeqForConvertFp16ToFp32(const ir::OpSequenceIndex &op_seq_ind,
298                                                              ir::OpSequence &op_seq)
299 {
300   // OpSeq's output set is included in the last operation's output set
301   const ir::OperandIndexSequence op_seq_outputs = op_seq.getOutputs(); // copied
302
303   // NOTE Please do not change sequence of op_seq_outputs. It can change the sequence of outputs of
304   // Subgraph
305   for (const auto &op_seq_output_ind :
306        op_seq_outputs | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
307   {
308     if (checkOperandType(op_seq_output_ind) == false)
309       continue;
310
311     // new operand w/ datatype fp32
312     const auto new_op_ind = newCopiedOperand(op_seq_output_ind);
313
314     // set new lower_info for operand
315     setNewOperandLowerInfo(op_seq_ind, new_op_ind);
316
317     // manipulate output of operation and op_seq
318     // - replace output of the last operation's output to new operand
319     //    with old operand's unsetDef and new operand's appendDef()
320     manipulateOutput(op_seq_ind, op_seq_output_ind, new_op_ind);
321
322     // new op
323     auto new_node_ind = newOperationConvertFp16ToFp32(op_seq_output_ind, new_op_ind);
324
325     // new op_seq
326     auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
327
328     // set new lower_info for op_seq
329     setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind);
330
331     _list_fp16_to_fp32.insert(new_op_seq_ind);
332
333     VERBOSE(Fp32ToFp16Converter) << "NEW   |Fp16To32]"
334                                  << ir::getStrFromOpSeq(_lowered_graph.op_seqs().at(new_op_seq_ind),
335                                                         _lowered_graph.graph().operations())
336                                  << std::endl;
337   }
338 }
339
340 void Fp32ToFp16Converter::optimize()
341 {
342   printOpSequences("BEFORE opt");
343
344   removeContiguousConvertOpSequences();
345
346   printOpSequences("AFTER removeContiguousConverts");
347
348   // TODO Handle Split from the beginning of the model. ex) MODELS/inception_module
349   //
350   // BEFORE)
351   //
352   //   OP#0---------------------.         // model_input
353   //    |                       |
354   // [FP32TO16]  // OpSeq#0   [FP32TO16]  // OpSeq#1
355   //    |                       |
356   //   OP#1                    OP#2
357   //    |                       |
358   // [OPERATION] // OpSeq#2   [OPERATION] // OpSeq#3
359   //
360   //
361   // AFTER)
362   //
363   //   OP#0      // model_input
364   //    |
365   // [FP32TO16]  // OpSeq#4
366   //    |
367   //   OP#3---------------------------.
368   //    |                             |
369   // [OPERATION] // OpSeq#2   [OPERATION] // OpSeq#3
370 }
371
372 void Fp32ToFp16Converter::convertOperands()
373 {
374   _lowered_graph.op_seqs().iterate(
375       [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
376         const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
377         assert(lower_info != nullptr);
378         // For now, the only acl_cl supports fully fp16
379         if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
380           return;
381
382         // Convert input,output operands' type to fp16
383         convertOperandsOfOpSequence(op_seq);
384       });
385 }
386
387 void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
388 {
389   auto &operands = _lowered_graph.graph().operands();
390   const auto &operations = _lowered_graph.graph().operations();
391   const auto &op_seq_inputs = _lowered_graph.graph().getInputs();
392   const auto &op_seq_outputs = _lowered_graph.graph().getOutputs();
393
394   for (auto &op_idx : op_seq)
395   {
396     const auto &node = operations.at(op_idx);
397     for (auto &ind : node.getInputs() | ir::Remove::UNDEFINED)
398     {
399       if (node.opcode() == ir::OpCode::ConvertFp32ToFp16 || op_seq_inputs.contains(ind))
400         continue;
401
402       auto &obj = operands.at(ind);
403       if (obj.isConstant() || obj.typeInfo().type() != ir::DataType::FLOAT32)
404         continue;
405
406       obj.type(ir::DataType::FLOAT16);
407
408       VERBOSE(Fp32ToFp16Converter) << "Input Operand #" << ind.value() << ": fp16" << std::endl;
409     }
410
411     for (auto &ind : node.getOutputs())
412     {
413       if (node.opcode() == ir::OpCode::ConvertFp16ToFp32 || op_seq_outputs.contains(ind))
414         continue;
415
416       auto &obj = operands.at(ind);
417       if (obj.isConstant() || obj.typeInfo().type() != ir::DataType::FLOAT32)
418         continue;
419
420       obj.type(ir::DataType::FLOAT16);
421
422       VERBOSE(Fp32ToFp16Converter) << "Output Operand #" << ind.value() << ": fp16" << std::endl;
423     }
424   }
425 }
426
427 void Fp32ToFp16Converter::convertDatas()
428 {
429   _lowered_graph.graph().operands().iterate([&](const ir::OperandIndex &ind, ir::Operand &obj) {
430     const auto type = obj.typeInfo().type();
431     if (type == ir::DataType::FLOAT32 && obj.isConstant())
432     {
433       auto data = obj.data();
434       assert(data != nullptr);
435
436       size_t num_elements = obj.operandSize() / ir::sizeOfDataType(type);
437       size_t new_ptr_size = num_elements * sizeof(float16);
438       auto new_ptr = std::make_unique<uint8_t[]>(new_ptr_size);
439       copyDataFromFp32ToFp16(reinterpret_cast<const float *>(data->base()),
440                              reinterpret_cast<float16 *>(new_ptr.get()), num_elements);
441       obj.releaseData();
442
443       auto new_data = std::make_unique<ir::CachedData>(new_ptr.get(), new_ptr_size);
444
445       obj.data(std::move(new_data));
446       obj.type(ir::DataType::FLOAT16);
447       VERBOSE(Fp32ToFp16Converter) << "Constant Operand #" << ind.value() << ": fp16" << std::endl;
448     }
449   });
450 }
451
452 void Fp32ToFp16Converter::printOpSequences(const std::string &pre_msg, const std::string &post_msg)
453 {
454   if (pre_msg.empty() == false)
455   {
456     VERBOSE(Fp32ToFp16Converter) << pre_msg << std::endl;
457   }
458
459   _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, const ir::OpSequence &op_seq) {
460     VERBOSE(Fp32ToFp16Converter) << ir::getStrFromOpSeq(op_seq, _lowered_graph.graph().operations())
461                                  << std::endl;
462   });
463
464   if (post_msg.empty() == false)
465   {
466     VERBOSE(Fp32ToFp16Converter) << post_msg << std::endl;
467   }
468 }
469
470 bool Fp32ToFp16Converter::checkOperandType(const ir::OperandIndex &op_ind) const
471 {
472   const auto &operands = _lowered_graph.graph().operands();
473   const auto &obj = operands.at(op_ind);
474   return (obj.isConstant() == false && obj.typeInfo().type() == ir::DataType::FLOAT32);
475 }
476
477 bool Fp32ToFp16Converter::checkOperandsOfOpSequence(const ir::OpSequence &op_seq) const
478 {
479   const auto &operations = _lowered_graph.graph().operations();
480
481   // the first node's input
482   const auto &first_node_ind = op_seq.operations().at(0);
483   const auto &first_node = operations.at(first_node_ind);
484   const auto &first_node_inputs = first_node.getInputs();
485   for (const auto &op_seq_input_ind : op_seq.getInputs() | ir::Remove::UNDEFINED)
486   {
487     if (first_node_inputs.contains(op_seq_input_ind) == false)
488       return false;
489   }
490
491   // the last node's output
492   size_t last_ind = op_seq.size() - 1;
493   const auto &last_node_ind = op_seq.operations().at(last_ind);
494   const auto &last_node = operations.at(last_node_ind);
495   const auto &last_node_outputs = last_node.getOutputs();
496   for (const auto &op_seq_output_ind : op_seq.getOutputs())
497   {
498     if (last_node_outputs.contains(op_seq_output_ind) == false)
499       return false;
500   }
501
502   return true;
503 }
504
505 ir::OperandIndex Fp32ToFp16Converter::newCopiedOperand(const ir::OperandIndex &op_ind)
506 {
507   auto &operands = _lowered_graph.graph().operands();
508   const auto &obj = operands.at(op_ind);
509   auto new_op_ind = operands.emplace(obj.shape(), obj.typeInfo());
510   return new_op_ind;
511 }
512
513 void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
514                                                  const ir::OperandIndex &new_op_ind)
515 {
516   const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
517   assert(lower_info != nullptr);
518   auto new_lower_info = std::make_unique<ir::operand::LowerInfo>();
519   auto permute_factor = ir::operand::PermuteFactor(lower_info->backend(), lower_info->layout());
520   new_lower_info->addDefPermuteFactor(permute_factor);
521   new_lower_info->addUsePermuteFactor(permute_factor);
522   _lowered_graph.setLowerInfo(new_op_ind, std::move(new_lower_info));
523 }
524
525 void Fp32ToFp16Converter::setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
526                                                     const ir::OpSequenceIndex &new_op_seq_ind)
527 {
528   const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
529   assert(lower_info != nullptr);
530
531   auto new_lower_info =
532       std::make_unique<ir::operation::LowerInfo>(lower_info->backend(), lower_info->layout());
533   _lowered_graph.setLowerInfo(new_op_seq_ind, std::move(new_lower_info));
534 }
535
536 void Fp32ToFp16Converter::manipulateInput(const ir::OpSequenceIndex &op_seq_ind,
537                                           const ir::OperandIndex &op_seq_input_ind,
538                                           const ir::OperandIndex &new_op_ind)
539 {
540   auto &operands = _lowered_graph.graph().operands();
541   auto &operations = _lowered_graph.graph().operations();
542
543   auto &op_seq = _lowered_graph.op_seqs().at(op_seq_ind);
544
545   auto &first_node_ind = op_seq.operations().at(0);
546   auto &first_node = operations.at(first_node_ind);
547   assert(first_node.getInputs().contains(op_seq_input_ind));
548
549   auto &input_obj = operands.at(op_seq_input_ind);
550   assert(input_obj.isConstant() == false);
551
552   auto &new_op_obj = operands.at(new_op_ind);
553
554   // The same inputs having the index as op_seq_input_ind are replaced all at once
555   op_seq.replaceInputs(op_seq_input_ind, new_op_ind);
556   first_node.replaceInputs(op_seq_input_ind, new_op_ind);
557
558   // op_seq_obj doesn't have uses/def
559   input_obj.removeUse(first_node_ind);
560   new_op_obj.insertUse(first_node_ind);
561 }
562
563 void Fp32ToFp16Converter::manipulateOutput(const ir::OpSequenceIndex &op_seq_ind,
564                                            const ir::OperandIndex &op_seq_output_ind,
565                                            const ir::OperandIndex &new_op_ind)
566 {
567   auto &operands = _lowered_graph.graph().operands();
568   auto &operations = _lowered_graph.graph().operations();
569
570   auto &op_seq = _lowered_graph.op_seqs().at(op_seq_ind);
571
572   size_t last_ind = op_seq.size() - 1;
573   auto &last_node_ind = op_seq.operations().at(last_ind);
574   auto &last_node = operations.at(last_node_ind);
575   assert(last_node.getOutputs().contains(op_seq_output_ind));
576
577   auto &output_obj = operands.at(op_seq_output_ind);
578   assert(output_obj.isConstant() == false);
579
580   auto &new_op_obj = operands.at(new_op_ind);
581
582   // The same outputs having the index as op_seq_output_ind are replaced all at once
583   op_seq.replaceOutputs(op_seq_output_ind, new_op_ind);
584   last_node.replaceOutputs(op_seq_output_ind, new_op_ind);
585
586   // op_seq_obj doesn't have uses/def
587   assert(output_obj.getDef() == last_node_ind);
588   output_obj.unsetDef();
589   new_op_obj.setDef(last_node_ind);
590 }
591
592 ir::OperationIndex
593 Fp32ToFp16Converter::newOperationConvertFp32ToFp16(const ir::OperandIndex &op_seq_input_ind,
594                                                    const ir::OperandIndex &new_op_ind)
595 {
596   auto &operands = _lowered_graph.graph().operands();
597   auto &operations = _lowered_graph.graph().operations();
598
599   auto &input_obj = operands.at(op_seq_input_ind);
600   auto &new_op_obj = operands.at(new_op_ind);
601
602   std::unique_ptr<ir::Operation> new_node(
603       new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind}));
604   const auto new_node_ind = operations.push(std::move(new_node));
605
606   input_obj.insertUse(new_node_ind);
607   new_op_obj.setDef(new_node_ind);
608
609   return new_node_ind;
610 }
611
612 ir::OperationIndex
613 Fp32ToFp16Converter::newOperationConvertFp16ToFp32(const ir::OperandIndex &op_seq_output_ind,
614                                                    const ir::OperandIndex &new_op_ind)
615 {
616   auto &operands = _lowered_graph.graph().operands();
617   auto &operations = _lowered_graph.graph().operations();
618
619   auto &output_obj = operands.at(op_seq_output_ind);
620   auto &new_op_obj = operands.at(new_op_ind);
621
622   std::unique_ptr<ir::Operation> new_node(
623       new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind}));
624   const auto new_node_ind = operations.push(std::move(new_node));
625
626   new_op_obj.insertUse(new_node_ind);
627   output_obj.setDef(new_node_ind);
628
629   return new_node_ind;
630 }
631
632 ir::OpSequenceIndex Fp32ToFp16Converter::newOpSequence(const ir::OpSequenceIndex &op_seq_ind,
633                                                        const ir::OperationIndex &node_index)
634 {
635   auto &node = _lowered_graph.graph().operations().at(node_index);
636   const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
637   assert(lower_info != nullptr);
638   auto layout = lower_info->layout();
639
640   auto op_seq = std::make_unique<ir::OpSequence>(layout);
641   op_seq->appendOperation(node_index);
642   op_seq->setOutputs(node.getOutputs());
643   op_seq->setInputs(node.getInputs());
644
645   return _lowered_graph.op_seqs().emplace(std::move(op_seq));
646 }
647
648 // The op_seq(Fp16To32)'s output operand is the next to op_seq (Fp32To16)?
649 // If so, connect Fp16To32's previous OpSeq to Fp32To16's next OpSeq
650 //
651 // Assume that an OpSequence has an operation for easy explaination
652 //
653 // BEFORE)
654 //
655 // [OPERATION] // OpSeq#0
656 //    |
657 //   OP#0
658 //    |
659 // [FP16TO32]  // OpSeq#1
660 //    |
661 //   OP#1
662 //    |
663 // [FP32TO16]  // OpSeq#2
664 //    |
665 //   OP#2
666 //    |
667 // [OPERATION] // OpSeq#3
668 //
669 //
670 // AFTER)
671 //
672 // [OPERATION] // OpSeq#0
673 //    |
674 //   OP#0
675 //    |
676 // [OPERATION] // OpSeq#3
677 //
678 void Fp32ToFp16Converter::removeContiguousConvertOpSequences()
679 {
680   // Prepare InputToOpSeqs map
681   const auto input_to_op_seqs = prepareInputToOpSeqs();
682
683   // Find OpSequences to delete while manipulating input of OpSeq.
684   auto opseq_map_to_delete = findOpSequencesContiguous(input_to_op_seqs);
685
686   // Find Operations to delete
687   auto list_to_delete_op_seqs = getListOpSequences(opseq_map_to_delete);
688   auto list_to_delete_ops = findOperationsToDelete(list_to_delete_op_seqs);
689
690   // Before deleting, manipulateInputs of OpSeq & Operation
691   manipulateContiguousOpSequences(input_to_op_seqs, opseq_map_to_delete);
692
693   // Delete OpSequences & Operations & obj's use/def & operands
694   deleteContiguousOpSequences(list_to_delete_op_seqs, list_to_delete_ops);
695 }
696
697 Fp32ToFp16Converter::OpSeqIndexToOpSeqIndexList
698 Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_seqs) const
699 {
700   const auto &op_seqs = _lowered_graph.op_seqs();
701   OpSeqIndexToOpSeqIndexList opseq_map_to_delete;
702
703   //
704   // Assume that an Operation an OpSequence for easy explaination
705   //
706   // [OPERATION]
707   //    |
708   //   OP#0
709   //    |
710   // [FP16TO32]  // op_seq_ind_fp16_to_fp32 & op_seq_fp16_to_fp32
711   //    |
712   //   OP#1      // output_ind_fp16_fp32
713   //    |
714   // [FP32TO16]  // op_seq_ind
715   //    |
716   //   OP#2
717   //    |
718   // [OPERATION]
719   //
720   for (auto it = _list_fp16_to_fp32.cbegin(); it != _list_fp16_to_fp32.cend(); ++it)
721   {
722     // fp16_to_fp32's input/output num is always 1
723     auto &op_seq_ind_fp16_to_fp32 = *it;
724     auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32);
725     assert(op_seq_fp16_to_fp32.size() == 1);
726     assert(op_seq_fp16_to_fp32.getInputs().size() == 1);
727
728     auto &output_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getOutputs().at(0);
729     auto found_input_in_op_seqs = input_to_op_seqs.find(output_ind_fp16_to_fp32);
730     if (found_input_in_op_seqs == input_to_op_seqs.end())
731     {
732       continue;
733     }
734
735     // DO NOT FORGET THE CASE
736     //
737     //    |
738     // [FP16TO32]
739     //    |
740     //   OP#0---------------------.
741     //    |                       |
742     // [FP32TO16]              [FP32TO16]
743     //    |                       |
744     //   OP#1                    OP#2
745     //    |                       |
746     // [OPERATION]             [OPERATION]
747     //
748     for (auto &op_seq_ind : found_input_in_op_seqs->second)
749     {
750       auto found_in_fp32_to_fp16 = _list_fp32_to_fp16.find(op_seq_ind);
751       if (found_in_fp32_to_fp16 != _list_fp32_to_fp16.end())
752       {
753         if (opseq_map_to_delete.find(op_seq_ind_fp16_to_fp32) == opseq_map_to_delete.end())
754         {
755           opseq_map_to_delete[op_seq_ind_fp16_to_fp32].emplace(op_seq_ind);
756         }
757         else
758         {
759           opseq_map_to_delete[op_seq_ind_fp16_to_fp32].insert(op_seq_ind);
760         }
761
762         VERBOSE(Fp32ToFp16Converter)
763             << "Contiguous from OpSeq#" << op_seq_ind_fp16_to_fp32.value() << "(ToFp32)"
764             << " to OpSeq#" << op_seq_ind.value() << "(ToFp16)" << std::endl;
765       }
766     }
767   }
768
769   return opseq_map_to_delete;
770 }
771
772 Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() const
773 {
774   const auto &op_seqs = _lowered_graph.op_seqs();
775
776   InputToOpSeqs input_to_op_seqs;
777   op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const ir::OpSequence &op_seq) {
778     for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED)
779     {
780       auto it = input_to_op_seqs.find(input);
781       if (it == input_to_op_seqs.end())
782       {
783         input_to_op_seqs[input].emplace(op_seq_idx);
784       }
785       else
786       {
787         input_to_op_seqs[input].insert(op_seq_idx);
788       }
789     }
790   });
791
792   return input_to_op_seqs;
793 }
794
795 Fp32ToFp16Converter::OpSeqIndexList
796 Fp32ToFp16Converter::getListOpSequences(const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete) const
797 {
798   OpSeqIndexList list;
799   for (const auto &it : opseq_map_to_delete)
800   {
801     auto &opseq_ind_fp16_to_fp32 = it.first;
802     if (list.find(opseq_ind_fp16_to_fp32) == list.end())
803     {
804       list.emplace(opseq_ind_fp16_to_fp32);
805     }
806
807     for (auto &opseq_ind_fp32_to_fp16 : it.second)
808     {
809       if (list.find(opseq_ind_fp32_to_fp16) == list.end())
810       {
811         list.emplace(opseq_ind_fp32_to_fp16);
812       }
813     }
814   }
815   return list;
816 }
817
818 ir::OperandIndexSequence
819 Fp32ToFp16Converter::findOperationsToDelete(const OpSeqIndexList &list_to_delete_op_seqs) const
820 {
821   const auto &operations = _lowered_graph.graph().operations();
822   const auto &op_seqs = _lowered_graph.op_seqs();
823
824   ir::OperandIndexSequence list_to_delete_ops;
825   for (const auto &op_seq_ind : list_to_delete_op_seqs)
826   {
827     const auto &op_seq = op_seqs.at(op_seq_ind);
828     assert(op_seq.size() == 1);
829
830     const auto &first_node_ind = op_seq.operations().at(0);
831     const auto &first_node = operations.at(first_node_ind);
832     assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 ||
833            first_node.opcode() == ir::OpCode::ConvertFp16ToFp32);
834
835     for (const auto &ind : first_node.getOutputs())
836     {
837       list_to_delete_ops.append(ind);
838     }
839   }
840
841   return list_to_delete_ops;
842 }
843
844 void Fp32ToFp16Converter::manipulateContiguousOpSequences(
845     const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete)
846 {
847   auto &op_seqs = _lowered_graph.op_seqs();
848
849   //
850   // [OPERATION]
851   //    |
852   //   OP#0      // input_ind_fp16_to_fp32
853   //    |
854   // [FP16TO32]  // op_seq_ind_fp16_to_fp32 & op_seq_fp16_to_fp32
855   //    |
856   //   OP#1
857   //    |
858   // [FP32TO16]  // op_seq_ind_fp32_to_fp16, op_seq_fp32_to_fp16
859   //    |
860   //   OP#2      // output_ind_fp32_to_fp16
861   //    |
862   // [OPERATION] // op_seq_ind_next_to_fp16
863   //
864   for (auto it : opseq_map_to_delete)
865   {
866     // fp16_to_fp32's input/output num is always 1
867     auto &op_seq_ind_fp16_to_fp32 = it.first;
868     auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32);
869     auto &input_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getInputs().at(0);
870
871     for (auto &op_seq_ind_fp32_to_fp16 : it.second)
872     {
873       auto &op_seq_fp32_to_fp16 = op_seqs.at(op_seq_ind_fp32_to_fp16);
874       assert(op_seq_fp32_to_fp16.size() == 1);
875       assert(op_seq_fp32_to_fp16.getInputs().size() == 1);
876
877       auto &output_ind_fp32_to_fp16 = op_seq_fp32_to_fp16.getOutputs().at(0);
878       auto found_next_to_fp16 = input_to_op_seqs.find(output_ind_fp32_to_fp16);
879       assert(found_next_to_fp16 != input_to_op_seqs.end());
880
881       for (auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second)
882       {
883         manipulateInput(op_seq_ind_next_to_fp16, output_ind_fp32_to_fp16, input_ind_fp16_to_fp32);
884       }
885       //
886       // [OPERATION]
887       //    |
888       //   OP#0      // input_ind_fp16_to_fp32
889       //    |
890       // [OPERATION] // op_seq_ind_next_to_fp16
891       //
892     }
893   }
894 }
895
896 void Fp32ToFp16Converter::deleteContiguousOpSequences(
897     const OpSeqIndexList &list_to_delete_op_seqs,
898     const ir::OperandIndexSequence &list_to_delete_ops)
899 {
900   auto &operands = _lowered_graph.graph().operands();
901   auto &operations = _lowered_graph.graph().operations();
902   auto &op_seqs = _lowered_graph.op_seqs();
903
904   for (auto &op_seq_ind : list_to_delete_op_seqs)
905   {
906     auto &op_seq = op_seqs.at(op_seq_ind);
907     assert(op_seq.size() == 1);
908     VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq #" << op_seq_ind.value() << std::endl;
909
910     auto &first_node_ind = op_seq.operations().at(0);
911     auto &first_node = operations.at(first_node_ind);
912     assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 ||
913            first_node.opcode() == ir::OpCode::ConvertFp16ToFp32);
914     VERBOSE(Fp32ToFp16Converter) << "Delete Node #" << first_node_ind.value() << std::endl;
915
916     // Uses
917     for (auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
918     {
919       auto &obj = operands.at(ind);
920       obj.removeUse(first_node_ind);
921       VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Use(Node#"
922                                    << first_node_ind.value() << ") is removed" << std::endl;
923     }
924
925     // Def
926     for (auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
927     {
928       auto &obj = operands.at(ind);
929       assert(obj.getDef() == first_node_ind);
930       obj.unsetDef();
931       VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Def(Node#"
932                                    << first_node_ind.value() << ") is removed" << std::endl;
933     }
934
935     // Operation
936     operations.remove(first_node_ind);
937     VERBOSE(Fp32ToFp16Converter) << "Node#" << first_node_ind.value() << " is removed" << std::endl;
938
939     // OpSequence
940     op_seqs.remove(op_seq_ind);
941     VERBOSE(Fp32ToFp16Converter) << "OpSeq#" << op_seq_ind.value() << " is removed" << std::endl;
942   }
943
944   // Operand
945   for (auto &ind : list_to_delete_ops)
946   {
947     operands.remove(ind);
948     VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << " is removed" << std::endl;
949   }
950 }
951
952 } // namespace compiler
953
954 } // namespace onert