Imported Upstream version 1.21.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / Fp32ToFp16Converter.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #if 0 // This file is temporarily unused
18
19 #include "Fp32ToFp16Converter.h"
20 #include "ir/operation/ConvertFp32ToFp16.h"
21 #include "ir/operation/ConvertFp16ToFp32.h"
22 #include "util/logging.h"
23
24 #include <Half.h>
25
26 using float16 = Half;
27
28 namespace
29 {
30
31 const std::string kAclClBackendConfigId = "acl_cl";
32
33 void copyDataFromFp32ToFp16(const float *from, float16 *into, size_t num_elements)
34 {
35   for (size_t i = 0; i < num_elements; ++i)
36   {
37     into[i] = static_cast<float16>(from[i]);
38   }
39 }
40
41 } // namespace
42
43 namespace onert
44 {
45
46 namespace compiler
47 {
48
49 Fp32ToFp16Converter::Fp32ToFp16Converter(compiler::LoweredGraph &lowered_graph)
50   : _lowered_graph{lowered_graph}
51 {
52   VERBOSE(Fp32ToFp16Converter) << "Fp16 Enable on" << std::endl;
53 }
54
55 // For example, two OpSequences are there and each OpSequence has an Operation
56 //
57 //   OP#0      // model input
58 //    |
59 // [OPERATION] // OpSeq#0
60 //    |
61 //   OP#1
62 //    |
63 // [OPERATION] // OpSeq#1
64 //    |
65 //   OP#2      // model output
66 //
67 //
68 // AFTER `appendOpSequences()`,
69 // note that model_input and model_output are not changed.
70 //
71 //   OP#0
72 //    |
73 // [FP32TO16]  // OpSeq#2
74 //    |
75 //   OP#3
76 //    |
77 // [OPERATION] // OpSeq#0
78 //    |
79 //   OP#4
80 //    |
81 // [FP16TO32]  // OpSeq#3
82 //    |
83 //   OP#1
84 //    |
85 // [FP32TO16]  // OpSeq#4
86 //    |
87 //   OP#5
88 //    |
89 // [OPERATION] // OpSeq#1
90 //    |
91 //   OP#6
92 //    |
93 // [FP16TO32]  // OpSeq#5
94 //    |
95 //   OP#2
96 //
97 //
98 // AFTER `optimize()`,
99 //
100 //   OP#0
101 //    |
102 // [FP32TO16]  // OpSeq#2
103 //    |
104 //   OP#3
105 //    |
106 // [OPERATION] // OpSeq#0
107 //    |
108 //   OP#4
109 //    |
110 // [OPERATION] // OpSeq#1
111 //    |
112 //   OP#6
113 //    |
114 // [FP16TO32]  // OpSeq#5
115 //    |
116 //   OP#2
117 //
118 //
119 // AFTER `convertOperands()`,
120 //
121 //   OP#0      // model_input, not fp16
122 //    |
123 // [FP32TO16]  // OpSeq#2
124 //    |
125 //   OP#3      // fp16
126 //    |
127 // [OPERATION] // OpSeq#0
128 //    |
129 //   OP#4      // fp16
130 //    |
131 // [OPERATION] // OpSeq#1
132 //    |
133 //   OP#6      // fp16
134 //    |
135 // [FP16TO32]  // OpSeq#5
136 //    |
137 //   OP#2      // model_output, notfp16
138 //
139 //
140 // AFTER `convertDatas()`,
141 //
142 //   OP#0      // model_input, not fp16
143 //    |
144 // [FP32TO16]  // OpSeq#2
145 //    |
146 //   OP#3      // fp16
147 //    |
148 // [OPERATION] // OpSeq#0, constants are fp16
149 //    |
150 //   OP#4      // fp16
151 //    |
152 // [OPERATION] // OpSeq#1, constants are fp16
153 //    |
154 //   OP#6      // fp16
155 //    |
156 // [FP16TO32]  // OpSeq#5
157 //    |
158 //   OP#2      // model_output, notfp16
159 //
160 void Fp32ToFp16Converter::run()
161 {
162   // Append new OpSequence which includes ConvertFp32ToFp16
163   //   and append new OpSequence which includes ConvertFp16ToFp32
164   appendOpSequences();
165
166   // Remove unnecessary converting operations
167   optimize();
168
169   // Convert operands' data types from fp32 to fp16
170   convertOperands();
171
172   // Convert Datas
173   convertDatas();
174
175   // Print the result
176   printOpSequences("FINAL OpSequences");
177 }
178
179 void Fp32ToFp16Converter::appendOpSequences()
180 {
181   _lowered_graph.op_seqs().iterate(
182     [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
183       const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
184       assert(lower_info != nullptr);
185
186       // For now, the only acl_cl supports fully fp16 type
187       // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat
188       // operations.
189       //      To do this, we could check the support by `operation by operation`. After that, we
190       //      would partition an op_seq if it contains unsupported operations.
191       if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
192         return;
193
194       // OpSeq's input set should be included in the first operation's input set or
195       // OpSeq's output set should be included in the last operation's output set
196       assert(checkOperandsOfOpSequence(op_seq));
197
198       // Append converting OpSequence for fp16 but all operands' types are not fp16 still.
199       appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq);
200       appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq);
201     });
202 }
203
204 //
205 // BEFORE
206 //
207 //   OP#0      // model input
208 //    |
209 // [OPERATION] // OpSeq#0
210 //    |
211 //   OP#1      // model output
212 //
213 //
214 // AFTER
215 //
216 //   OP#0      // model input
217 //    |
218 // [FP32TO16]  // OpSeq#1
219 //    |
220 //   OP#2
221 //    |
222 // [OPERATION] // OpSeq#0
223 //    |
224 //   OP#1      // model output
225 //
226 void Fp32ToFp16Converter::appendNewOpSeqForConvertFp32ToFp16(const ir::OpSequenceIndex &op_seq_ind,
227                                                              ir::OpSequence &op_seq)
228 {
229   // OpSeq's input set is included in the first operation's input set
230   const ir::OperandIndexSequence op_seq_inputs = op_seq.getInputs(); // copied
231
232   // NOTE Please do not change sequence of op_seq_inputs. It can change the sequence of inputs of
233   // Subgraph
234   for (const auto &op_seq_input_ind :
235        op_seq_inputs | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
236   {
237     if (checkOperandType(op_seq_input_ind) == false)
238       continue;
239
240     // new operand w/ datatype fp32
241     const auto new_op_ind = newCopiedOperand(op_seq_input_ind);
242
243     // set new lower_info for operand
244     setNewOperandLowerInfo(op_seq_ind, new_op_ind);
245
246     // manipulate input of operation and op_seq
247     // - replace the first operation's input to new operand
248     //   with old operand's removeUse and new operand's appendUse()
249     manipulateInput(op_seq_ind, op_seq_input_ind, new_op_ind);
250
251     // new op
252     const auto new_node_ind = newOperationConvertFp32ToFp16(op_seq_input_ind, new_op_ind);
253
254     // new op_seq
255     const auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
256
257     // set new lower_info for op_seq
258     setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind);
259
260     _list_fp32_to_fp16.insert(new_op_seq_ind);
261
262     VERBOSE(Fp32ToFp16Converter) << "NEW   |Fp32To16]"
263                                  << ir::getStrFromOpSeq(_lowered_graph.op_seqs().at(new_op_seq_ind),
264                                                         _lowered_graph.graph().operations())
265                                  << std::endl;
266   }
267 }
268
269 //
270 // BEFORE
271 //
272 //   OP#0      // model input
273 //    |
274 // [FP32TO16]  // OpSeq#1
275 //    |
276 //   OP#2
277 //    |
278 // [OPERATION] // OpSeq#0
279 //    |
280 //   OP#1      // model output
281 //
282 //
283 // AFTER
284 //
285 //   OP#0      // model input
286 //    |
287 // [FP32TO16]  // OpSeq#1
288 //    |
289 //   OP#2
290 //    |
291 // [OPERATION] // OpSeq#0
292 //    |
293 //   OP#3
294 //    |
295 // [FP16TO32]  // OpSeq#2
296 //    |
297 //   OP#1      // model output
298 //
299 void Fp32ToFp16Converter::appendNewOpSeqForConvertFp16ToFp32(const ir::OpSequenceIndex &op_seq_ind,
300                                                              ir::OpSequence &op_seq)
301 {
302   // OpSeq's output set is included in the last operation's output set
303   const ir::OperandIndexSequence op_seq_outputs = op_seq.getOutputs(); // copied
304
305   // NOTE Please do not change sequence of op_seq_outputs. It can change the sequence of outputs of
306   // Subgraph
307   for (const auto &op_seq_output_ind :
308        op_seq_outputs | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
309   {
310     if (checkOperandType(op_seq_output_ind) == false)
311       continue;
312
313     // new operand w/ datatype fp32
314     const auto new_op_ind = newCopiedOperand(op_seq_output_ind);
315
316     // set new lower_info for operand
317     setNewOperandLowerInfo(op_seq_ind, new_op_ind);
318
319     // manipulate output of operation and op_seq
320     // - replace output of the last operation's output to new operand
321     //    with old operand's unsetDef and new operand's appendDef()
322     manipulateOutput(op_seq_ind, op_seq_output_ind, new_op_ind);
323
324     // new op
325     auto new_node_ind = newOperationConvertFp16ToFp32(op_seq_output_ind, new_op_ind);
326
327     // new op_seq
328     auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
329
330     // set new lower_info for op_seq
331     setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind);
332
333     _list_fp16_to_fp32.insert(new_op_seq_ind);
334
335     VERBOSE(Fp32ToFp16Converter) << "NEW   |Fp16To32]"
336                                  << ir::getStrFromOpSeq(_lowered_graph.op_seqs().at(new_op_seq_ind),
337                                                         _lowered_graph.graph().operations())
338                                  << std::endl;
339   }
340 }
341
342 void Fp32ToFp16Converter::optimize()
343 {
344   printOpSequences("BEFORE opt");
345
346   removeContiguousConvertOpSequences();
347
348   printOpSequences("AFTER removeContiguousConverts");
349
350   // TODO Handle Split from the beginning of the model. ex) MODELS/inception_module
351   //
352   // BEFORE)
353   //
354   //   OP#0---------------------.         // model_input
355   //    |                       |
356   // [FP32TO16]  // OpSeq#0   [FP32TO16]  // OpSeq#1
357   //    |                       |
358   //   OP#1                    OP#2
359   //    |                       |
360   // [OPERATION] // OpSeq#2   [OPERATION] // OpSeq#3
361   //
362   //
363   // AFTER)
364   //
365   //   OP#0      // model_input
366   //    |
367   // [FP32TO16]  // OpSeq#4
368   //    |
369   //   OP#3---------------------------.
370   //    |                             |
371   // [OPERATION] // OpSeq#2   [OPERATION] // OpSeq#3
372 }
373
374 void Fp32ToFp16Converter::convertOperands()
375 {
376   _lowered_graph.op_seqs().iterate(
377     [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
378       const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
379       assert(lower_info != nullptr);
380       // For now, the only acl_cl supports fully fp16
381       if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
382         return;
383
384       // Convert input,output operands' type to fp16
385       convertOperandsOfOpSequence(op_seq);
386     });
387 }
388
389 void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
390 {
391   auto &operands = _lowered_graph.graph().operands();
392   const auto &operations = _lowered_graph.graph().operations();
393   const auto &op_seq_inputs = _lowered_graph.graph().getInputs();
394   const auto &op_seq_outputs = _lowered_graph.graph().getOutputs();
395
396   for (auto &op_idx : op_seq)
397   {
398     const auto &node = operations.at(op_idx);
399     for (auto &ind : node.getInputs() | ir::Remove::UNDEFINED)
400     {
401       if (node.opcode() == ir::OpCode::ConvertFp32ToFp16 || op_seq_inputs.contains(ind))
402         continue;
403
404       auto &obj = operands.at(ind);
405       if (obj.isConstant() || obj.typeInfo().type() != ir::DataType::FLOAT32)
406         continue;
407
408       obj.type(ir::DataType::FLOAT16);
409
410       VERBOSE(Fp32ToFp16Converter) << "Input Operand " << ind << ": fp16" << std::endl;
411     }
412
413     for (auto &ind : node.getOutputs())
414     {
415       if (node.opcode() == ir::OpCode::ConvertFp16ToFp32 || op_seq_outputs.contains(ind))
416         continue;
417
418       auto &obj = operands.at(ind);
419       if (obj.isConstant() || obj.typeInfo().type() != ir::DataType::FLOAT32)
420         continue;
421
422       obj.type(ir::DataType::FLOAT16);
423
424       VERBOSE(Fp32ToFp16Converter) << "Output Operand " << ind << ": fp16" << std::endl;
425     }
426   }
427 }
428
429 void Fp32ToFp16Converter::convertDatas()
430 {
431   _lowered_graph.graph().operands().iterate([&](const ir::OperandIndex &ind, ir::Operand &obj) {
432     const auto type = obj.typeInfo().type();
433     if (type == ir::DataType::FLOAT32 && obj.isConstant())
434     {
435       auto data = obj.data();
436       assert(data != nullptr);
437
438       size_t num_elements = obj.operandSize() / ir::sizeOfDataType(type);
439       size_t new_ptr_size = num_elements * sizeof(float16);
440       auto new_ptr = std::make_unique<uint8_t[]>(new_ptr_size);
441       copyDataFromFp32ToFp16(reinterpret_cast<const float *>(data->base()),
442                              reinterpret_cast<float16 *>(new_ptr.get()), num_elements);
443       obj.releaseData();
444
445       auto new_data = std::make_unique<ir::CachedData>(new_ptr.get(), new_ptr_size);
446
447       obj.data(std::move(new_data));
448       obj.type(ir::DataType::FLOAT16);
449       VERBOSE(Fp32ToFp16Converter) << "Constant Operand " << ind << ": fp16" << std::endl;
450     }
451   });
452 }
453
454 void Fp32ToFp16Converter::printOpSequences(const std::string &pre_msg, const std::string &post_msg)
455 {
456   if (pre_msg.empty() == false)
457   {
458     VERBOSE(Fp32ToFp16Converter) << pre_msg << std::endl;
459   }
460
461   _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, const ir::OpSequence &op_seq) {
462     VERBOSE(Fp32ToFp16Converter) << ir::getStrFromOpSeq(op_seq, _lowered_graph.graph().operations())
463                                  << std::endl;
464   });
465
466   if (post_msg.empty() == false)
467   {
468     VERBOSE(Fp32ToFp16Converter) << post_msg << std::endl;
469   }
470 }
471
472 bool Fp32ToFp16Converter::checkOperandType(const ir::OperandIndex &op_ind) const
473 {
474   const auto &operands = _lowered_graph.graph().operands();
475   const auto &obj = operands.at(op_ind);
476   return (obj.isConstant() == false && obj.typeInfo().type() == ir::DataType::FLOAT32);
477 }
478
479 bool Fp32ToFp16Converter::checkOperandsOfOpSequence(const ir::OpSequence &op_seq) const
480 {
481   const auto &operations = _lowered_graph.graph().operations();
482
483   // the first node's input
484   const auto &first_node_ind = op_seq.operations().at(0);
485   const auto &first_node = operations.at(first_node_ind);
486   const auto &first_node_inputs = first_node.getInputs();
487   for (const auto &op_seq_input_ind : op_seq.getInputs() | ir::Remove::UNDEFINED)
488   {
489     if (first_node_inputs.contains(op_seq_input_ind) == false)
490       return false;
491   }
492
493   // the last node's output
494   size_t last_ind = op_seq.size() - 1;
495   const auto &last_node_ind = op_seq.operations().at(last_ind);
496   const auto &last_node = operations.at(last_node_ind);
497   const auto &last_node_outputs = last_node.getOutputs();
498   for (const auto &op_seq_output_ind : op_seq.getOutputs())
499   {
500     if (last_node_outputs.contains(op_seq_output_ind) == false)
501       return false;
502   }
503
504   return true;
505 }
506
507 ir::OperandIndex Fp32ToFp16Converter::newCopiedOperand(const ir::OperandIndex &op_ind)
508 {
509   auto &operands = _lowered_graph.graph().operands();
510   const auto &obj = operands.at(op_ind);
511   auto new_op_ind = operands.emplace(obj.shape(), obj.typeInfo());
512   return new_op_ind;
513 }
514
515 void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
516                                                  const ir::OperandIndex &new_op_ind)
517 {
518   const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
519   assert(lower_info != nullptr);
520   auto new_lower_info = std::make_unique<compiler::OperandLowerInfo>();
521   auto permute_factor = compiler::PermuteFactor(lower_info->backend(), lower_info->layout());
522   new_lower_info->addDefPermuteFactor(permute_factor);
523   new_lower_info->addUsePermuteFactor(permute_factor);
524   _lowered_graph.setLowerInfo(new_op_ind, std::move(new_lower_info));
525 }
526
527 void Fp32ToFp16Converter::setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
528                                                    const ir::OpSequenceIndex &new_op_seq_ind)
529 {
530   const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
531   assert(lower_info != nullptr);
532
533   auto new_lower_info =
534     std::make_unique<compiler::OperationLowerInfo>(lower_info->backend(), lower_info->layout());
535   _lowered_graph.setLowerInfo(new_op_seq_ind, std::move(new_lower_info));
536 }
537
538 void Fp32ToFp16Converter::manipulateInput(const ir::OpSequenceIndex &op_seq_ind,
539                                           const ir::OperandIndex &op_seq_input_ind,
540                                           const ir::OperandIndex &new_op_ind)
541 {
542   auto &operands = _lowered_graph.graph().operands();
543   auto &operations = _lowered_graph.graph().operations();
544
545   auto &op_seq = _lowered_graph.op_seqs().at(op_seq_ind);
546
547   auto &first_node_ind = op_seq.operations().at(0);
548   auto &first_node = operations.at(first_node_ind);
549   assert(first_node.getInputs().contains(op_seq_input_ind));
550
551   auto &input_obj = operands.at(op_seq_input_ind);
552   assert(input_obj.isConstant() == false);
553
554   auto &new_op_obj = operands.at(new_op_ind);
555
556   // The same inputs having the index as op_seq_input_ind are replaced all at once
557   op_seq.replaceInputs(op_seq_input_ind, new_op_ind);
558   first_node.replaceInputs(op_seq_input_ind, new_op_ind);
559
560   // op_seq_obj doesn't have uses/def
561   input_obj.removeUse(first_node_ind);
562   new_op_obj.insertUse(first_node_ind);
563 }
564
565 void Fp32ToFp16Converter::manipulateOutput(const ir::OpSequenceIndex &op_seq_ind,
566                                            const ir::OperandIndex &op_seq_output_ind,
567                                            const ir::OperandIndex &new_op_ind)
568 {
569   auto &operands = _lowered_graph.graph().operands();
570   auto &operations = _lowered_graph.graph().operations();
571
572   auto &op_seq = _lowered_graph.op_seqs().at(op_seq_ind);
573
574   size_t last_ind = op_seq.size() - 1;
575   auto &last_node_ind = op_seq.operations().at(last_ind);
576   auto &last_node = operations.at(last_node_ind);
577   assert(last_node.getOutputs().contains(op_seq_output_ind));
578
579   auto &output_obj = operands.at(op_seq_output_ind);
580   assert(output_obj.isConstant() == false);
581
582   auto &new_op_obj = operands.at(new_op_ind);
583
584   // The same outputs having the index as op_seq_output_ind are replaced all at once
585   op_seq.replaceOutputs(op_seq_output_ind, new_op_ind);
586   last_node.replaceOutputs(op_seq_output_ind, new_op_ind);
587
588   // op_seq_obj doesn't have uses/def
589   assert(output_obj.getDef() == last_node_ind);
590   output_obj.unsetDef();
591   new_op_obj.setDef(last_node_ind);
592 }
593
594 ir::OperationIndex
595 Fp32ToFp16Converter::newOperationConvertFp32ToFp16(const ir::OperandIndex &op_seq_input_ind,
596                                                    const ir::OperandIndex &new_op_ind)
597 {
598   auto &operands = _lowered_graph.graph().operands();
599   auto &operations = _lowered_graph.graph().operations();
600
601   auto &input_obj = operands.at(op_seq_input_ind);
602   auto &new_op_obj = operands.at(new_op_ind);
603
604   std::unique_ptr<ir::Operation> new_node(
605     new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind}));
606   const auto new_node_ind = operations.push(std::move(new_node));
607
608   input_obj.insertUse(new_node_ind);
609   new_op_obj.setDef(new_node_ind);
610
611   return new_node_ind;
612 }
613
614 ir::OperationIndex
615 Fp32ToFp16Converter::newOperationConvertFp16ToFp32(const ir::OperandIndex &op_seq_output_ind,
616                                                    const ir::OperandIndex &new_op_ind)
617 {
618   auto &operands = _lowered_graph.graph().operands();
619   auto &operations = _lowered_graph.graph().operations();
620
621   auto &output_obj = operands.at(op_seq_output_ind);
622   auto &new_op_obj = operands.at(new_op_ind);
623
624   std::unique_ptr<ir::Operation> new_node(
625     new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind}));
626   const auto new_node_ind = operations.push(std::move(new_node));
627
628   new_op_obj.insertUse(new_node_ind);
629   output_obj.setDef(new_node_ind);
630
631   return new_node_ind;
632 }
633
634 ir::OpSequenceIndex Fp32ToFp16Converter::newOpSequence(const ir::OpSequenceIndex &op_seq_ind,
635                                                        const ir::OperationIndex &node_index)
636 {
637   auto &node = _lowered_graph.graph().operations().at(node_index);
638   const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
639   assert(lower_info != nullptr);
640   auto layout = lower_info->layout();
641
642   auto op_seq = std::make_unique<ir::OpSequence>(layout);
643   op_seq->appendOperation(node_index);
644   op_seq->setOutputs(node.getOutputs());
645   op_seq->setInputs(node.getInputs());
646
647   return _lowered_graph.op_seqs().emplace(std::move(op_seq));
648 }
649
650 // The op_seq(Fp16To32)'s output operand is the next to op_seq (Fp32To16)?
651 // If so, connect Fp16To32's previous OpSeq to Fp32To16's next OpSeq
652 //
653 // Assume that an OpSequence has an operation for easy explaination
654 //
655 // BEFORE)
656 //
657 // [OPERATION] // OpSeq#0
658 //    |
659 //   OP#0
660 //    |
661 // [FP16TO32]  // OpSeq#1
662 //    |
663 //   OP#1
664 //    |
665 // [FP32TO16]  // OpSeq#2
666 //    |
667 //   OP#2
668 //    |
669 // [OPERATION] // OpSeq#3
670 //
671 //
672 // AFTER)
673 //
674 // [OPERATION] // OpSeq#0
675 //    |
676 //   OP#0
677 //    |
678 // [OPERATION] // OpSeq#3
679 //
680 void Fp32ToFp16Converter::removeContiguousConvertOpSequences()
681 {
682   // Prepare InputToOpSeqs map
683   const auto input_to_op_seqs = prepareInputToOpSeqs();
684
685   // Find OpSequences to delete while manipulating input of OpSeq.
686   auto opseq_map_to_delete = findOpSequencesContiguous(input_to_op_seqs);
687
688   // Find Operations to delete
689   auto list_to_delete_op_seqs = getListOpSequences(opseq_map_to_delete);
690   auto list_to_delete_ops = findOperationsToDelete(list_to_delete_op_seqs);
691
692   // Before deleting, manipulateInputs of OpSeq & Operation
693   manipulateContiguousOpSequences(input_to_op_seqs, opseq_map_to_delete);
694
695   // Delete OpSequences & Operations & obj's use/def & operands
696   deleteContiguousOpSequences(list_to_delete_op_seqs, list_to_delete_ops);
697 }
698
699 Fp32ToFp16Converter::OpSeqIndexToOpSeqIndexList
700 Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_seqs) const
701 {
702   const auto &op_seqs = _lowered_graph.op_seqs();
703   OpSeqIndexToOpSeqIndexList opseq_map_to_delete;
704
705   //
706   // Assume that an Operation an OpSequence for easy explaination
707   //
708   // [OPERATION]
709   //    |
710   //   OP#0
711   //    |
712   // [FP16TO32]  // op_seq_ind_fp16_to_fp32 & op_seq_fp16_to_fp32
713   //    |
714   //   OP#1      // output_ind_fp16_fp32
715   //    |
716   // [FP32TO16]  // op_seq_ind
717   //    |
718   //   OP#2
719   //    |
720   // [OPERATION]
721   //
722   for (auto it = _list_fp16_to_fp32.cbegin(); it != _list_fp16_to_fp32.cend(); ++it)
723   {
724     // fp16_to_fp32's input/output num is always 1
725     auto &op_seq_ind_fp16_to_fp32 = *it;
726     auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32);
727     assert(op_seq_fp16_to_fp32.size() == 1);
728     assert(op_seq_fp16_to_fp32.getInputs().size() == 1);
729
730     auto &output_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getOutputs().at(0);
731     auto found_input_in_op_seqs = input_to_op_seqs.find(output_ind_fp16_to_fp32);
732     if (found_input_in_op_seqs == input_to_op_seqs.end())
733     {
734       continue;
735     }
736
737     // DO NOT FORGET THE CASE
738     //
739     //    |
740     // [FP16TO32]
741     //    |
742     //   OP#0---------------------.
743     //    |                       |
744     // [FP32TO16]              [FP32TO16]
745     //    |                       |
746     //   OP#1                    OP#2
747     //    |                       |
748     // [OPERATION]             [OPERATION]
749     //
750     for (auto &op_seq_ind : found_input_in_op_seqs->second)
751     {
752       auto found_in_fp32_to_fp16 = _list_fp32_to_fp16.find(op_seq_ind);
753       if (found_in_fp32_to_fp16 != _list_fp32_to_fp16.end())
754       {
755         if (opseq_map_to_delete.find(op_seq_ind_fp16_to_fp32) == opseq_map_to_delete.end())
756         {
757           opseq_map_to_delete[op_seq_ind_fp16_to_fp32].emplace(op_seq_ind);
758         }
759         else
760         {
761           opseq_map_to_delete[op_seq_ind_fp16_to_fp32].insert(op_seq_ind);
762         }
763
764         VERBOSE(Fp32ToFp16Converter) << "Contiguous from " << op_seq_ind_fp16_to_fp32 << "(ToFp32)"
765                                      << " to " << op_seq_ind << "(ToFp16)" << std::endl;
766       }
767     }
768   }
769
770   return opseq_map_to_delete;
771 }
772
773 Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() const
774 {
775   const auto &op_seqs = _lowered_graph.op_seqs();
776
777   InputToOpSeqs input_to_op_seqs;
778   op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const ir::OpSequence &op_seq) {
779     for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED)
780     {
781       auto it = input_to_op_seqs.find(input);
782       if (it == input_to_op_seqs.end())
783       {
784         input_to_op_seqs[input].emplace(op_seq_idx);
785       }
786       else
787       {
788         input_to_op_seqs[input].insert(op_seq_idx);
789       }
790     }
791   });
792
793   return input_to_op_seqs;
794 }
795
796 Fp32ToFp16Converter::OpSeqIndexList
797 Fp32ToFp16Converter::getListOpSequences(const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete) const
798 {
799   OpSeqIndexList list;
800   for (const auto &it : opseq_map_to_delete)
801   {
802     auto &opseq_ind_fp16_to_fp32 = it.first;
803     if (list.find(opseq_ind_fp16_to_fp32) == list.end())
804     {
805       list.emplace(opseq_ind_fp16_to_fp32);
806     }
807
808     for (auto &opseq_ind_fp32_to_fp16 : it.second)
809     {
810       if (list.find(opseq_ind_fp32_to_fp16) == list.end())
811       {
812         list.emplace(opseq_ind_fp32_to_fp16);
813       }
814     }
815   }
816   return list;
817 }
818
819 ir::OperandIndexSequence
820 Fp32ToFp16Converter::findOperationsToDelete(const OpSeqIndexList &list_to_delete_op_seqs) const
821 {
822   const auto &operations = _lowered_graph.graph().operations();
823   const auto &op_seqs = _lowered_graph.op_seqs();
824
825   ir::OperandIndexSequence list_to_delete_ops;
826   for (const auto &op_seq_ind : list_to_delete_op_seqs)
827   {
828     const auto &op_seq = op_seqs.at(op_seq_ind);
829     assert(op_seq.size() == 1);
830
831     const auto &first_node_ind = op_seq.operations().at(0);
832     const auto &first_node = operations.at(first_node_ind);
833     assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 ||
834            first_node.opcode() == ir::OpCode::ConvertFp16ToFp32);
835
836     for (const auto &ind : first_node.getOutputs())
837     {
838       list_to_delete_ops.append(ind);
839     }
840   }
841
842   return list_to_delete_ops;
843 }
844
845 void Fp32ToFp16Converter::manipulateContiguousOpSequences(
846   const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete)
847 {
848   auto &op_seqs = _lowered_graph.op_seqs();
849
850   //
851   // [OPERATION]
852   //    |
853   //   OP#0      // input_ind_fp16_to_fp32
854   //    |
855   // [FP16TO32]  // op_seq_ind_fp16_to_fp32 & op_seq_fp16_to_fp32
856   //    |
857   //   OP#1
858   //    |
859   // [FP32TO16]  // op_seq_ind_fp32_to_fp16, op_seq_fp32_to_fp16
860   //    |
861   //   OP#2      // output_ind_fp32_to_fp16
862   //    |
863   // [OPERATION] // op_seq_ind_next_to_fp16
864   //
865   for (auto it : opseq_map_to_delete)
866   {
867     // fp16_to_fp32's input/output num is always 1
868     auto &op_seq_ind_fp16_to_fp32 = it.first;
869     auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32);
870     auto &input_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getInputs().at(0);
871
872     for (auto &op_seq_ind_fp32_to_fp16 : it.second)
873     {
874       auto &op_seq_fp32_to_fp16 = op_seqs.at(op_seq_ind_fp32_to_fp16);
875       assert(op_seq_fp32_to_fp16.size() == 1);
876       assert(op_seq_fp32_to_fp16.getInputs().size() == 1);
877
878       auto &output_ind_fp32_to_fp16 = op_seq_fp32_to_fp16.getOutputs().at(0);
879       auto found_next_to_fp16 = input_to_op_seqs.find(output_ind_fp32_to_fp16);
880       assert(found_next_to_fp16 != input_to_op_seqs.end());
881
882       for (auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second)
883       {
884         manipulateInput(op_seq_ind_next_to_fp16, output_ind_fp32_to_fp16, input_ind_fp16_to_fp32);
885       }
886       //
887       // [OPERATION]
888       //    |
889       //   OP#0      // input_ind_fp16_to_fp32
890       //    |
891       // [OPERATION] // op_seq_ind_next_to_fp16
892       //
893     }
894   }
895 }
896
897 void Fp32ToFp16Converter::deleteContiguousOpSequences(
898   const OpSeqIndexList &list_to_delete_op_seqs, const ir::OperandIndexSequence &list_to_delete_ops)
899 {
900   auto &operands = _lowered_graph.graph().operands();
901   auto &operations = _lowered_graph.graph().operations();
902   auto &op_seqs = _lowered_graph.op_seqs();
903
904   for (auto &op_seq_ind : list_to_delete_op_seqs)
905   {
906     auto &op_seq = op_seqs.at(op_seq_ind);
907     assert(op_seq.size() == 1);
908     VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq " << op_seq_ind << std::endl;
909
910     auto &first_node_ind = op_seq.operations().at(0);
911     auto &first_node = operations.at(first_node_ind);
912     assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 ||
913            first_node.opcode() == ir::OpCode::ConvertFp16ToFp32);
914     VERBOSE(Fp32ToFp16Converter) << "Delete Node " << first_node_ind << std::endl;
915
916     // Uses
917     for (auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
918     {
919       auto &obj = operands.at(ind);
920       obj.removeUse(first_node_ind);
921       VERBOSE(Fp32ToFp16Converter)
922         << "Operand " << ind << "'s Use(Node" << first_node_ind << ") is removed" << std::endl;
923     }
924
925     // Def
926     for (auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
927     {
928       auto &obj = operands.at(ind);
929       assert(obj.getDef() == first_node_ind);
930       obj.unsetDef();
931       VERBOSE(Fp32ToFp16Converter)
932         << "Operand " << ind << "'s Def(Node" << first_node_ind << ") is removed" << std::endl;
933     }
934
935     // Operation
936     operations.remove(first_node_ind);
937     VERBOSE(Fp32ToFp16Converter) << "Node" << first_node_ind << " is removed" << std::endl;
938
939     // OpSequence
940     op_seqs.remove(op_seq_ind);
941     VERBOSE(Fp32ToFp16Converter) << "OpSeq" << op_seq_ind << " is removed" << std::endl;
942   }
943
944   // Operand
945   for (auto &ind : list_to_delete_ops)
946   {
947     operands.remove(ind);
948     VERBOSE(Fp32ToFp16Converter) << "Operand " << ind << " is removed" << std::endl;
949   }
950 }
951
952 } // namespace compiler
953
954 } // namespace onert
955
956 #endif