/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <logo/SimplifyDomainConversionPass.h>

#include <loco/IR/Algorithm.h>
#include <loco/IR/CanonicalDialect.h>
#include <loco/IR/CanonicalNode.h>

#include <stdex/Memory.h>

#include <cassert>
#include <initializer_list>
#include <memory>
#include <set>
#include <utility>
#include <vector>

34 // TODO Move this helper into loco
35 bool equal(const Permutation<Domain::Feature> *lhs, const Permutation<Domain::Feature> *rhs)
37 for (const auto &axis :
38 {FeatureAxis::Count, FeatureAxis::Depth, FeatureAxis::Height, FeatureAxis::Width})
40 if (lhs->axis(axis) != rhs->axis(axis))
48 bool equal(const Permutation<Domain::Filter> *lhs, const Permutation<Domain::Filter> *rhs)
50 for (const auto &axis :
51 {FilterAxis::Count, FilterAxis::Depth, FilterAxis::Height, FilterAxis::Width})
53 if (lhs->axis(axis) != rhs->axis(axis))
61 bool equal(const Permutation<Domain::DepthwiseFilter> *lhs,
62 const Permutation<Domain::DepthwiseFilter> *rhs)
64 for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Multiplier,
65 DepthwiseFilterAxis::Height, DepthwiseFilterAxis::Width})
67 if (lhs->axis(axis) != rhs->axis(axis))
75 bool equal(const Permutation<Domain::Matrix> *lhs, const Permutation<Domain::Matrix> *rhs)
77 for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
79 if (lhs->axis(axis) != rhs->axis(axis))
87 void set_input_null(loco::Node *node)
89 if (auto casted = dynamic_cast<loco::FeatureEncode *>(node))
90 casted->input(nullptr);
91 else if (auto casted = dynamic_cast<loco::FeatureDecode *>(node))
92 casted->input(nullptr);
93 else if (auto casted = dynamic_cast<loco::BiasDecode *>(node))
94 casted->input(nullptr);
95 else if (auto casted = dynamic_cast<loco::FilterEncode *>(node))
96 casted->input(nullptr);
97 else if (auto casted = dynamic_cast<loco::FilterDecode *>(node))
98 casted->input(nullptr);
99 else if (auto casted = dynamic_cast<loco::DepthwiseFilterEncode *>(node))
100 casted->input(nullptr);
101 else if (auto casted = dynamic_cast<loco::DepthwiseFilterDecode *>(node))
102 casted->input(nullptr);
103 else if (auto casted = dynamic_cast<loco::MatrixEncode *>(node))
104 casted->input(nullptr);
105 else if (auto casted = dynamic_cast<loco::MatrixDecode *>(node))
106 casted->input(nullptr);
108 assert(false && "not supported node type");
116 bool SimplifyDomainConversionPass::run(loco::Graph *g)
118 // TODO Introduce and Use "Pattern Match"
119 struct Collector final : public loco::CanonicalNodeMutableVisitor<void>
121 // Let's find FeatureDecode followed by FeatureEncode
122 void visit(loco::FeatureEncode *encode_node) final
124 using namespace loco;
126 auto encoder = encode_node->encoder();
127 assert(encoder != nullptr);
129 auto decode_node = dynamic_cast<loco::FeatureDecode *>(encode_node->input());
130 if (decode_node == nullptr)
134 assert(decode_node->input() != nullptr);
136 auto decoder = decode_node->decoder();
137 assert(decoder != nullptr);
139 // NOTE Work only for permuting codec
140 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
141 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
143 if (perm_encoder == nullptr || perm_decoder == nullptr)
148 if (equal(perm_encoder->perm(), perm_decoder->perm()))
150 forwardCandidates.insert({encode_node, decode_node->input()});
154 // Let's find `FeatureEncode -- FeatureDecode` pattern
155 void visit(loco::FeatureDecode *decode_node) final
157 using namespace loco;
159 auto encode_node = dynamic_cast<loco::FeatureEncode *>(decode_node->input());
160 if (encode_node == nullptr)
164 assert(encode_node->input() != nullptr);
166 auto encoder = encode_node->encoder();
167 assert(encoder != nullptr);
169 auto decoder = decode_node->decoder();
170 assert(decoder != nullptr);
172 // NOTE Work only for permuting codec
173 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Feature> *>(decoder);
174 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Feature> *>(encoder);
176 if (perm_encoder == nullptr || perm_decoder == nullptr)
181 if (equal(perm_encoder->perm(), perm_decoder->perm()))
183 forwardCandidates.insert({decode_node, encode_node->input()});
187 // Let's find `FilterEncode -- FilterDecode` pattern
188 void visit(loco::FilterDecode *decode_node) final
190 using namespace loco;
192 auto encode_node = dynamic_cast<loco::FilterEncode *>(decode_node->input());
193 if (encode_node == nullptr)
197 assert(encode_node->input() != nullptr);
199 auto encoder = encode_node->encoder();
200 assert(encoder != nullptr);
202 auto decoder = decode_node->decoder();
203 assert(decoder != nullptr);
205 // NOTE Work only for permuting codec
206 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Filter> *>(decoder);
207 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Filter> *>(encoder);
209 if (perm_encoder == nullptr || perm_decoder == nullptr)
214 if (equal(perm_encoder->perm(), perm_decoder->perm()))
216 forwardCandidates.insert({decode_node, encode_node->input()});
220 std::vector<loco::TensorAxis> perm_vec;
223 auto enc_perm = perm_encoder->perm();
224 auto dec_perm = perm_decoder->perm();
226 for (const auto &axis :
227 {FilterAxis::Count, FilterAxis::Height, FilterAxis::Width, FilterAxis::Depth})
229 auto from = enc_perm->axis(axis);
230 auto to = dec_perm->axis(axis);
234 transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
235 encode_node, decode_node, encode_node->input(), perm_vec));
239 // Let's find `BiasEncode -- BiasDecode` pattern
240 void visit(loco::BiasDecode *decode_node) final
242 if (auto encode_node = dynamic_cast<loco::BiasEncode *>(decode_node->input()))
244 assert(encode_node->input() != nullptr);
245 forwardCandidates.insert({decode_node, encode_node->input()});
249 // Let's find `DepthwiseFilterEncode -- DepthwiseFilterDecode` pattern
250 void visit(loco::DepthwiseFilterDecode *decode_node) final
252 using namespace loco;
254 auto encode_node = dynamic_cast<loco::DepthwiseFilterEncode *>(decode_node->input());
255 if (encode_node == nullptr)
259 assert(encode_node->input() != nullptr);
261 auto encoder = encode_node->encoder();
262 assert(encoder != nullptr);
264 auto decoder = decode_node->decoder();
265 assert(decoder != nullptr);
267 // NOTE Work only for permuting codec
268 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::DepthwiseFilter> *>(decoder);
269 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::DepthwiseFilter> *>(encoder);
271 if (perm_encoder == nullptr || perm_decoder == nullptr)
276 if (equal(perm_encoder->perm(), perm_decoder->perm()))
278 forwardCandidates.insert({decode_node, encode_node->input()});
282 std::vector<TensorAxis> perm_vec;
285 auto enc_perm = perm_encoder->perm();
286 auto dec_perm = perm_decoder->perm();
288 for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Height,
289 DepthwiseFilterAxis::Width, DepthwiseFilterAxis::Multiplier})
291 auto from = enc_perm->axis(axis);
292 auto to = dec_perm->axis(axis);
296 transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
297 encode_node, decode_node, encode_node->input(), perm_vec));
301 // Let's find MatrixDecode followed by MatrixEncode
302 void visit(loco::MatrixEncode *encode_node) final
304 using namespace loco;
306 auto encoder = encode_node->encoder();
307 assert(encoder != nullptr);
309 auto decode_node = dynamic_cast<loco::MatrixDecode *>(encode_node->input());
310 if (decode_node == nullptr)
314 assert(decode_node->input() != nullptr);
316 auto decoder = decode_node->decoder();
317 assert(decoder != nullptr);
319 // NOTE Work only for permuting codec
320 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
321 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
323 if (perm_encoder == nullptr || perm_decoder == nullptr)
328 if (equal(perm_encoder->perm(), perm_decoder->perm()))
330 forwardCandidates.insert({encode_node, decode_node->input()});
334 // Let's find MatrixEncode followed by MatrixDecode
335 void visit(loco::MatrixDecode *decode_node) final
337 using namespace loco;
339 auto encode_node = dynamic_cast<loco::MatrixEncode *>(decode_node->input());
340 if (encode_node == nullptr)
344 assert(encode_node->input() != nullptr);
346 auto encoder = encode_node->encoder();
347 assert(encoder != nullptr);
349 auto decoder = decode_node->decoder();
350 assert(decoder != nullptr);
352 // NOTE Work only for permuting codec
353 auto perm_decoder = dynamic_cast<const PermutingDecoder<Domain::Matrix> *>(decoder);
354 auto perm_encoder = dynamic_cast<const PermutingEncoder<Domain::Matrix> *>(encoder);
356 if (perm_encoder == nullptr || perm_decoder == nullptr)
361 if (equal(perm_encoder->perm(), perm_decoder->perm()))
363 forwardCandidates.insert({decode_node, encode_node->input()});
367 std::vector<loco::TensorAxis> perm_vec;
370 auto enc_perm = perm_encoder->perm();
371 auto dec_perm = perm_decoder->perm();
373 for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
375 auto from = enc_perm->axis(axis);
376 auto to = dec_perm->axis(axis);
380 transposeCandidates.insert(stdex::make_unique<TransposeCtx>(
381 encode_node, decode_node, encode_node->input(), perm_vec));
385 void visit(loco::Node *) final { return; }
387 using SimplifyingInfo = std::pair<loco::Node * /* end node of subgraph that will be replaced*/,
388 loco::Node * /* input of subgraph */>;
389 std::set<SimplifyingInfo> forwardCandidates;
393 loco::Node *first_node; // starting node of subgraph that will be replaced
394 loco::Node *last_node; // end node of subgraph that will be replaced
395 loco::Node *input_node; // input of subgraph
396 std::vector<loco::TensorAxis> perm_vec; // perm vector for transpose
398 TransposeCtx(loco::Node *first, loco::Node *last, loco::Node *input,
399 std::vector<loco::TensorAxis> perm)
400 : first_node(first), last_node(last), input_node(input), perm_vec(perm)
405 std::set<std::unique_ptr<TransposeCtx>> transposeCandidates;
410 for (auto node : loco::active_nodes(loco::output_nodes(g)))
412 if (node->dialect() == loco::CanonicalDialect::get())
414 auto canonical_node = loco::must_cast<loco::CanonicalNode *>(node);
415 canonical_node->accept(&collector);
419 for (auto p : collector.forwardCandidates)
421 auto forward_node = g->nodes()->create<loco::Forward>();
422 forward_node->input(p.second);
423 replace(p.first).with(forward_node);
424 set_input_null(p.first);
427 for (auto &ctx : collector.transposeCandidates)
429 auto transpose_node = g->nodes()->create<loco::TensorTranspose>();
431 transpose_node->perm()->size(ctx->perm_vec.size());
433 for (loco::TensorAxis axis = 0; axis < ctx->perm_vec.size(); axis++)
434 transpose_node->perm()->axis(axis) = ctx->perm_vec[axis];
437 transpose_node->input(ctx->input_node);
438 replace(ctx->last_node).with(transpose_node);
439 set_input_null(ctx->first_node);
442 return (collector.forwardCandidates.size() > 0 or collector.transposeCandidates.size() > 0);