Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RemoveRedundantTranspose.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/RemoveRedundantTransposePass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 namespace
22 {
23
24 /// @brief Return true if first_perm[second_perm[i]] == i
25 bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *second_perm)
26 {
27   assert(first_perm->rank() == 1);
28   assert(second_perm->rank() == 1);
29   assert(second_perm->size<loco::DataType::S32>() == first_perm->size<loco::DataType::S32>());
30   for (int32_t i = 0; i < static_cast<int32_t>(first_perm->size<loco::DataType::S32>()); i++)
31   {
32     if (first_perm->at<loco::DataType::S32>(second_perm->at<loco::DataType::S32>(i)) != i)
33       return false;
34   }
35   return true;
36 }
37
38 bool remove_consecutive_transpose_function(luci::CircleNode *node)
39 {
40   auto target_node = dynamic_cast<luci::CircleTranspose *>(node);
41   if (target_node == nullptr)
42     return false;
43   auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a());
44   if (pred_node == nullptr)
45     return false;
46   if (loco::succs(pred_node).size() != 1)
47     return false;
48
49   auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
50   if (pred_perm == nullptr)
51     return false;
52
53   auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
54   if (main_perm == nullptr)
55     return false;
56
57   auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a());
58   if (check_perm(pred_perm, main_perm))
59   {
60     replace(node).with(main_node);
61   }
62   else
63   {
64     auto g = main_perm->graph();
65     auto new_const_node = g->nodes()->create<luci::CircleConst>();
66
67     new_const_node->dtype(loco::DataType::S32);
68     new_const_node->rank(1);
69     new_const_node->dim(0) = main_perm->dim(0);
70     new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value());
71     new_const_node->shape_status(luci::ShapeStatus::VALID);
72     for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++)
73     {
74       new_const_node->at<loco::DataType::S32>(i) =
75           pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i));
76     }
77     pred_node->perm(new_const_node);
78     replace(node).with(pred_node);
79   }
80   return true;
81 }
82
83 } // namespace
84
85 namespace luci
86 {
87 /**
88  *  BEFORE
89  *         |
90  *   [CircleNode]     [CircleConst]
91  *    (main_node)      (main_perm)
92  *         \               /
93  *         [CircleTranspose]  [CircleConst]
94  *            (pred_node)      (pred_perm)
95  *                 \               /
96  *                 [CircleTranspose]
97  *                   (target_node)
98  *                         |
99  *
100  *  AFTER
101  *      <Optional Case>
102  *
103  *          |                 |                   |
104  *    [CircleNode]      [CircleConst]             |
105  *     (main_node)     (new_const_node)           |
106  *           \               /           or  [CircleNode]
107  *           [CircleTranspose]                (main_node)
108  *              (pred_node)                       |
109  *                   |                            |
110  *
111  */
112 bool RemoveRedundantTransposePass::run(loco::Graph *g)
113 {
114   bool changed = false;
115   for (auto node : loco::active_nodes(loco::output_nodes(g)))
116   {
117     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
118     if (remove_consecutive_transpose_function(circle_node))
119     {
120       changed = true;
121       break;
122     }
123   }
124   return changed;
125 }
126
127 } // namespace luci