Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / logo / src / Passes / SimplifyDomainConversionPass.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <logo/SimplifyDomainConversionPass.h>
18
19 #include <loco/IR/Algorithm.h>
20 #include <loco/IR/CanonicalDialect.h>
21 #include <loco/IR/CanonicalNode.h>
22
23 #include <stdex/Memory.h>
24
25 #include <set>
26 #include <vector>
27 #include <cassert>
28
29 namespace
30 {
31
32 using namespace loco;
33
34 // TODO Move this helper into loco
35 bool equal(const Permutation<Domain::Feature> *lhs, const Permutation<Domain::Feature> *rhs)
36 {
37   for (const auto &axis :
38        {FeatureAxis::Count, FeatureAxis::Depth, FeatureAxis::Height, FeatureAxis::Width})
39   {
40     if (lhs->axis(axis) != rhs->axis(axis))
41     {
42       return false;
43     }
44   }
45   return true;
46 }
47
48 bool equal(const Permutation<Domain::Filter> *lhs, const Permutation<Domain::Filter> *rhs)
49 {
50   for (const auto &axis :
51        {FilterAxis::Count, FilterAxis::Depth, FilterAxis::Height, FilterAxis::Width})
52   {
53     if (lhs->axis(axis) != rhs->axis(axis))
54     {
55       return false;
56     }
57   }
58   return true;
59 }
60
61 bool equal(const Permutation<Domain::DepthwiseFilter> *lhs,
62            const Permutation<Domain::DepthwiseFilter> *rhs)
63 {
64   for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Multiplier,
65                            DepthwiseFilterAxis::Height, DepthwiseFilterAxis::Width})
66   {
67     if (lhs->axis(axis) != rhs->axis(axis))
68     {
69       return false;
70     }
71   }
72   return true;
73 }
74
75 bool equal(const Permutation<Domain::Matrix> *lhs, const Permutation<Domain::Matrix> *rhs)
76 {
77   for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
78   {
79     if (lhs->axis(axis) != rhs->axis(axis))
80     {
81       return false;
82     }
83   }
84   return true;
85 }
86
87 void set_input_null(loco::Node *node)
88 {
89   if (auto casted = dynamic_cast<loco::FeatureEncode *>(node))
90     casted->input(nullptr);
91   else if (auto casted = dynamic_cast<loco::FeatureDecode *>(node))
92     casted->input(nullptr);
93   else if (auto casted = dynamic_cast<loco::BiasDecode *>(node))
94     casted->input(nullptr);
95   else if (auto casted = dynamic_cast<loco::FilterEncode *>(node))
96     casted->input(nullptr);
97   else if (auto casted = dynamic_cast<loco::FilterDecode *>(node))
98     casted->input(nullptr);
99   else if (auto casted = dynamic_cast<loco::DepthwiseFilterEncode *>(node))
100     casted->input(nullptr);
101   else if (auto casted = dynamic_cast<loco::DepthwiseFilterDecode *>(node))
102     casted->input(nullptr);
103   else if (auto casted = dynamic_cast<loco::MatrixEncode *>(node))
104     casted->input(nullptr);
105   else if (auto casted = dynamic_cast<loco::MatrixDecode *>(node))
106     casted->input(nullptr);
107   else
108     assert(false && "not supported node type");
109 }
110
111 } // namespace
112
113 namespace logo
114 {
115
116 bool SimplifyDomainConversionPass::run(loco::Graph *g)
117 {
118   // TODO Introduce and Use "Pattern Match"
119   struct Collector final : public loco::CanonicalNodeMutableVisitor<void>
120   {
121     // Let's find FeatureDecode followed by FeatureEncode
122     void visit(loco::FeatureEncode *encode_node) final
123     {
124       using namespace loco;
125
126       auto encoder = encode_node->encoder();
127       assert(encoder != nullptr);
128
129       auto decode_node = dynamic_cast<loco::FeatureDecode *>(encode_node->input());
130       if (decode_node == nullptr)
131       {
132         return;
133       }
134       assert(decode_node->input() != nullptr);
135
136       auto decoder = decode_node->decoder();
137       assert(decoder != nullptr);
138
139       // NOTE Work only for permuting codec
140       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
141       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
142
143       if (perm_encoder == nullptr || perm_decoder == nullptr)
144       {
145         return;
146       }
147
148       if (equal(perm_encoder->perm(), perm_decoder->perm()))
149       {
150         forwardCandidates.insert({encode_node, decode_node->input()});
151       }
152     }
153
154     // Let's find `FeatureEncode -- FeatureDecode` pattern
155     void visit(loco::FeatureDecode *decode_node) final
156     {
157       using namespace loco;
158
159       auto encode_node = dynamic_cast<loco::FeatureEncode *>(decode_node->input());
160       if (encode_node == nullptr)
161       {
162         return;
163       }
164       assert(encode_node->input() != nullptr);
165
166       auto encoder = encode_node->encoder();
167       assert(encoder != nullptr);
168
169       auto decoder = decode_node->decoder();
170       assert(decoder != nullptr);
171
172       // NOTE Work only for permuting codec
173       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
174       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
175
176       if (perm_encoder == nullptr || perm_decoder == nullptr)
177       {
178         return;
179       }
180
181       if (equal(perm_encoder->perm(), perm_decoder->perm()))
182       {
183         forwardCandidates.insert({decode_node, encode_node->input()});
184       }
185     }
186
187     // Let's find `FilterEncode -- FilterDecode` pattern
188     void visit(loco::FilterDecode *decode_node) final
189     {
190       using namespace loco;
191
192       auto encode_node = dynamic_cast<loco::FilterEncode *>(decode_node->input());
193       if (encode_node == nullptr)
194       {
195         return;
196       }
197       assert(encode_node->input() != nullptr);
198
199       auto encoder = encode_node->encoder();
200       assert(encoder != nullptr);
201
202       auto decoder = decode_node->decoder();
203       assert(decoder != nullptr);
204
205       // NOTE Work only for permuting codec
206       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Filter> *>(decoder);
207       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Filter> *>(encoder);
208
209       if (perm_encoder == nullptr || perm_decoder == nullptr)
210       {
211         return;
212       }
213
214       if (equal(perm_encoder->perm(), perm_decoder->perm()))
215       {
216         forwardCandidates.insert({decode_node, encode_node->input()});
217       }
218       else
219       {
220         std::vector<loco::TensorAxis> perm_vec;
221         perm_vec.resize(4);
222
223         auto enc_perm = perm_encoder->perm();
224         auto dec_perm = perm_decoder->perm();
225
226         for (const auto &axis :
227              {FilterAxis::Count, FilterAxis::Height, FilterAxis::Width, FilterAxis::Depth})
228         {
229           auto from = enc_perm->axis(axis);
230           auto to = dec_perm->axis(axis);
231           perm_vec[to] = from;
232         }
233
234         transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
235             encode_node, decode_node, encode_node->input(), perm_vec));
236       }
237     }
238
239     // Let's find `BiasEncode -- BiasDecode` pattern
240     void visit(loco::BiasDecode *decode_node) final
241     {
242       if (auto encode_node = dynamic_cast<loco::BiasEncode *>(decode_node->input()))
243       {
244         assert(encode_node->input() != nullptr);
245         forwardCandidates.insert({decode_node, encode_node->input()});
246       }
247     }
248
249     // Let's find `DepthwiseFilterEncode -- DepthwiseFilterDecode` pattern
250     void visit(loco::DepthwiseFilterDecode *decode_node) final
251     {
252       using namespace loco;
253
254       auto encode_node = dynamic_cast<loco::DepthwiseFilterEncode *>(decode_node->input());
255       if (encode_node == nullptr)
256       {
257         return;
258       }
259       assert(encode_node->input() != nullptr);
260
261       auto encoder = encode_node->encoder();
262       assert(encoder != nullptr);
263
264       auto decoder = decode_node->decoder();
265       assert(decoder != nullptr);
266
267       // NOTE Work only for permuting codec
268       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::DepthwiseFilter> *>(decoder);
269       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::DepthwiseFilter> *>(encoder);
270
271       if (perm_encoder == nullptr || perm_decoder == nullptr)
272       {
273         return;
274       }
275
276       if (equal(perm_encoder->perm(), perm_decoder->perm()))
277       {
278         forwardCandidates.insert({decode_node, encode_node->input()});
279       }
280       else
281       {
282         std::vector<TensorAxis> perm_vec;
283         perm_vec.resize(4);
284
285         auto enc_perm = perm_encoder->perm();
286         auto dec_perm = perm_decoder->perm();
287
288         for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Height,
289                                  DepthwiseFilterAxis::Width, DepthwiseFilterAxis::Multiplier})
290         {
291           auto from = enc_perm->axis(axis);
292           auto to = dec_perm->axis(axis);
293           perm_vec[to] = from;
294         }
295
296         transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
297             encode_node, decode_node, encode_node->input(), perm_vec));
298       }
299     }
300
301     // Let's find MatrixDecode followed by MatrixEncode
302     void visit(loco::MatrixEncode *encode_node) final
303     {
304       using namespace loco;
305
306       auto encoder = encode_node->encoder();
307       assert(encoder != nullptr);
308
309       auto decode_node = dynamic_cast<loco::MatrixDecode *>(encode_node->input());
310       if (decode_node == nullptr)
311       {
312         return;
313       }
314       assert(decode_node->input() != nullptr);
315
316       auto decoder = decode_node->decoder();
317       assert(decoder != nullptr);
318
319       // NOTE Work only for permuting codec
320       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
321       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
322
323       if (perm_encoder == nullptr || perm_decoder == nullptr)
324       {
325         return;
326       }
327
328       if (equal(perm_encoder->perm(), perm_decoder->perm()))
329       {
330         forwardCandidates.insert({encode_node, decode_node->input()});
331       }
332     }
333
334     // Let's find MatrixEncode followed by MatrixDecode
335     void visit(loco::MatrixDecode *decode_node) final
336     {
337       using namespace loco;
338
339       auto encode_node = dynamic_cast<loco::MatrixEncode *>(decode_node->input());
340       if (encode_node == nullptr)
341       {
342         return;
343       }
344       assert(encode_node->input() != nullptr);
345
346       auto encoder = encode_node->encoder();
347       assert(encoder != nullptr);
348
349       auto decoder = decode_node->decoder();
350       assert(decoder != nullptr);
351
352       // NOTE Work only for permuting codec
353       auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
354       auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
355
356       if (perm_encoder == nullptr || perm_decoder == nullptr)
357       {
358         return;
359       }
360
361       if (equal(perm_encoder->perm(), perm_decoder->perm()))
362       {
363         forwardCandidates.insert({decode_node, encode_node->input()});
364       }
365       else
366       {
367         std::vector<loco::TensorAxis> perm_vec;
368         perm_vec.resize(2);
369
370         auto enc_perm = perm_encoder->perm();
371         auto dec_perm = perm_decoder->perm();
372
373         for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
374         {
375           auto from = enc_perm->axis(axis);
376           auto to = dec_perm->axis(axis);
377           perm_vec[to] = from;
378         }
379
380         transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
381             encode_node, decode_node, encode_node->input(), perm_vec));
382       }
383     }
384
385     void visit(loco::Node *) final { return; }
386
387     using SimplifyingInfo = std::pair<loco::Node * /* end node of subgraph that will be replaced*/,
388                                       loco::Node * /* input of subgraph */>;
389     std::set<SimplifyingInfo> forwardCandidates;
390
391     struct TransposeCtx
392     {
393       loco::Node *first_node;                 // starting node of subgraph that will be replaced
394       loco::Node *last_node;                  // end node of subgraph that will be replaced
395       loco::Node *input_node;                 // input of subgraph
396       std::vector<loco::TensorAxis> perm_vec; // perm vector for transpose
397
398       TransposeCtx(loco::Node *first, loco::Node *last, loco::Node *input,
399                    std::vector<loco::TensorAxis> perm)
400           : first_node(first), last_node(last), input_node(input), perm_vec(perm)
401       { /* empty */
402       }
403     };
404
405     std::set<std::unique_ptr<TransposeCtx>> transposeCandidates;
406   };
407
408   Collector collector;
409
410   for (auto node : loco::active_nodes(loco::output_nodes(g)))
411   {
412     if (node->dialect() == loco::CanonicalDialect::get())
413     {
414       auto canonical_node = loco::must_cast<loco::CanonicalNode *>(node);
415       canonical_node->accept(&collector);
416     }
417   }
418
419   for (auto p : collector.forwardCandidates)
420   {
421     auto forward_node = g->nodes()->create<loco::Forward>();
422     forward_node->input(p.second);
423     replace(p.first).with(forward_node);
424     set_input_null(p.first);
425   }
426
427   for (auto &ctx : collector.transposeCandidates)
428   {
429     auto transpose_node = g->nodes()->create<loco::TensorTranspose>();
430     {
431       transpose_node->perm()->size(ctx->perm_vec.size());
432
433       for (loco::TensorAxis axis = 0; axis < ctx->perm_vec.size(); axis++)
434         transpose_node->perm()->axis(axis) = ctx->perm_vec[axis];
435     }
436
437     transpose_node->input(ctx->input_node);
438     replace(ctx->last_node).with(transpose_node);
439     set_input_null(ctx->first_node);
440   }
441
442   return (collector.forwardCandidates.size() > 0 or collector.transposeCandidates.size() > 0);
443 }
444
445 } // namespace logo