Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / SubstituteSqueezeToReshapePass.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Profile/CircleNodeOrigin.h>
21
22 namespace
23 {
24
25 /**
26  * @brief return TRUE if all dim is known
27  * @note This pass can be applied even some of dimensions are unknown.
28          For now, do not consider about it and update logic later.
29  */
30 bool can_squeeze_shape(const luci::CircleNode *node)
31 {
32   for (uint32_t r = 0; r < node->rank(); ++r)
33   {
34     if (not node->dim(r).known())
35       return false;
36   }
37   return true;
38 }
39
40 /**
41  * @brief return valid unsigned dim value from 0 ~ (rank-1)
42  * @note  dim can be -rank to (rank-1)
43  */
44 uint32_t valid_unsigned_dim(uint32_t rank, int32_t dim)
45 {
46   int32_t irank = static_cast<int32_t>(rank);
47   return dim >= 0 ? static_cast<uint32_t>(dim) : static_cast<uint32_t>(irank + dim);
48 }
49
50 /**
51  * @brief return TRUE if input dim is 1 for squeeze_dims values
52  */
53 bool is_valid_input(const luci::CircleNode *node, const std::vector<int32_t> &squeeze_dims)
54 {
55   auto rank = node->rank();
56   for (auto dim : squeeze_dims)
57   {
58     auto udim = valid_unsigned_dim(rank, dim);
59     if (node->dim(udim).value() != 1)
60       return false;
61   }
62   return true;
63 }
64
65 /**
66  * @brief return shape vector from input
67  */
68 std::vector<uint32_t> node_shape(const luci::CircleNode *input)
69 {
70   std::vector<uint32_t> shape;
71   uint32_t rank = input->rank();
72   for (uint32_t r = 0; r < rank; ++r)
73     shape.push_back(input->dim(r).value());
74
75   return shape;
76 }
77
78 /**
79  * @brief copy quantparam of src to dst
80  */
81 void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
82 {
83   auto q = src->quantparam();
84   if (q == nullptr)
85     dst->quantparam(nullptr);
86   else
87     dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
88 }
89
90 /**
91  * @brief return CircleConst ptr with values of new_shape
92  */
93 luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
94 {
95   // NOTE dim_size can be 0
96   uint32_t dim_size = static_cast<uint32_t>(new_shape.size());
97
98   auto shape_const = graph->nodes()->create<luci::CircleConst>();
99
100   // const shape/dtype
101   shape_const->dtype(loco::DataType::S32);
102   if (dim_size > 0)
103   {
104     shape_const->rank(1);
105     shape_const->dim(0).set(dim_size);
106   }
107   else
108     shape_const->rank(0);
109   shape_const->shape_status(luci::ShapeStatus::VALID);
110
111   // constant values
112   shape_const->size<loco::DataType::S32>(dim_size);
113   for (uint32_t i = 0; i < dim_size; ++i)
114     shape_const->at<loco::DataType::S32>(i) = new_shape.at(i);
115
116   return shape_const;
117 }
118
119 bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
120 {
121   assert(squeeze != nullptr);
122
123   auto input = loco::must_cast<luci::CircleNode *>(squeeze->input());
124   // we need input node shape and all dim should be known
125   if (input->shape_status() != luci::ShapeStatus::VALID)
126     return false;
127   if (not can_squeeze_shape(input))
128     return false;
129
130   // we will use squeeze shape for new shape
131   if (squeeze->shape_status() != luci::ShapeStatus::VALID)
132     return false;
133
134   auto squeeze_dims = squeeze->squeeze_dims();
135   if (not is_valid_input(input, squeeze_dims))
136     throw std::runtime_error("Invalid values in squeeze_dims: " + squeeze->name());
137
138   auto name = squeeze->name();
139   assert(name.length() > 0);
140
141   auto reshape_shape = node_shape(squeeze);
142   auto graph = squeeze->graph();
143   auto reshape = graph->nodes()->create<luci::CircleReshape>();
144   auto shape_const = create_shape_const(graph, reshape_shape);
145   copy_quantparam(reshape, squeeze);
146   reshape->name(name + "/Reshape");
147   luci::add_origin(reshape, luci::get_origin(squeeze));
148   shape_const->name(name + "/Reshape/shape");
149
150   // graph connection
151   reshape->tensor(input);
152   reshape->shape(shape_const);
153   replace(squeeze).with(reshape);
154
155   return true;
156 }
157
158 } // namespace
159
160 namespace luci
161 {
162
163 /**
164  * BEFORE
165  *           |
166  *      [CircleNode]
167  *           |
168  *    [CircleSqueeze]
169  *           |
170  *      [CircleNode]
171  *           |
172  *
173  * AFTER
174  *               |
175  *          [CircleNode]  [CircleConst]
176  *             |    \             /
177  *  [CircleSqueeze] [CircleReshape]
178  *                        |
179  *                   [CircleNode]
180  *                        |
181  */
182 bool SubstituteSqueezeToReshapePass::run(loco::Graph *g)
183 {
184   bool changed = false;
185   for (auto node : loco::active_nodes(loco::output_nodes(g)))
186   {
187     if (auto squeeze = dynamic_cast<luci::CircleSqueeze *>(node))
188     {
189       if (substitute_squeeze_to_reshape(squeeze))
190         changed = true;
191     }
192   }
193   return changed;
194 }
195
196 } // namespace luci