Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / moco / pass / src / Passes / ConstantFoldStridedSlice.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "moco/Pass/Passes/ConstantFoldStridedSlice.h"
18
19 #include "ConstantFoldHelper.h"
20 #include "TensorSliceEnumerator.h"
21
22 #include <moco/IR/Nodes/TFStridedSlice.h>
23 #include <moco/IR/Nodes/TFConst.h>
24
25 #include <moco/Support/NodeAs.h>
26 #include <moco/Support/TFShapeInferenceHelper.h>
27
28 #include <oops/UserExn.h>
29
30 #include <cassert>
31 #include <vector>
32
33 namespace
34 {
35
36 loco::TensorShape calc_output_shape(moco::TFStridedSlice *node)
37 {
38   auto const_input = loco::must_cast<moco::TFConst *>(node->input());
39   auto const_begin = loco::must_cast<moco::TFConst *>(node->begin());
40   auto const_end = loco::must_cast<moco::TFConst *>(node->end());
41   auto input_rank = const_input->rank();
42   auto output_rank = input_rank;
43   loco::TensorShape output_shape_range;
44
45   output_shape_range.rank(input_rank);
46   for (uint32_t r = 0; r < input_rank; ++r)
47   {
48     // TODO apply begin/end mask
49     // TODO apply ellipsis mask
50     // TODO apply strides
51     auto end = const_end->at<loco::DataType::S32>(r);
52     auto begin = const_begin->at<loco::DataType::S32>(r);
53     auto size = end - begin;
54     output_shape_range.dim(r).set(size);
55   }
56
57   loco::TensorShape output_tensor_shape;
58   if (node->shrink_axis_mask() != 0)
59   {
60     for (uint32_t rs = 0; rs < input_rank; ++rs)
61     {
62       int32_t bit = 1 << rs;
63       int32_t mask = node->shrink_axis_mask();
64       if (bit & mask)
65       {
66         // shrink one dimension
67         assert(output_rank > 0);
68         output_rank = output_rank - 1;
69       }
70     }
71     output_tensor_shape.rank(output_rank);
72     for (uint32_t rs = 0, rd = 0; rs < input_rank; ++rs)
73     {
74       int32_t bit = 1 << rs;
75       int32_t mask = node->shrink_axis_mask();
76       if ((bit & mask) == 0)
77       {
78         // use this dimension
79         output_tensor_shape.dim(rd).set(output_shape_range.dim(rs).value());
80         rd++;
81       }
82       // else this dimension is shrink-ed
83     }
84   }
85   else
86   {
87     output_tensor_shape = output_shape_range;
88   }
89
90   return output_tensor_shape;
91 }
92
93 moco::u32v_t vector_from_const(moco::TFConst *tfconst)
94 {
95   moco::u32v_t result;
96
97   auto rank = tfconst->rank();
98   assert(rank == 1);
99   auto dim = tfconst->dim(0).value();
100
101   result.resize(dim);
102   for (uint32_t r = 0; r < dim; ++r)
103   {
104     auto val = tfconst->at<loco::DataType::S32>(r);
105     result.at(r) = val;
106   }
107
108   return result;
109 }
110
111 moco::u32v_t operator-(const moco::u32v_t &lhs, const moco::u32v_t &rhs)
112 {
113   assert(lhs.size() == rhs.size());
114
115   moco::u32v_t res;
116   res.resize(lhs.size());
117   for (uint32_t r = 0; r < lhs.size(); r++)
118   {
119     res.at(r) = lhs.at(r) - rhs.at(r);
120   }
121   return res;
122 }
123
124 template <typename T> T tfconst_at(const moco::TFConst *tfconst, const moco::u32v_t &pos);
125
126 template <> int32_t tfconst_at<int32_t>(const moco::TFConst *tfconst, const moco::u32v_t &pos)
127 {
128   uint32_t rank = tfconst->rank();
129   assert(rank == pos.size());
130   uint32_t element = 0;
131   for (uint32_t r = 0; r < rank; ++r)
132   {
133     uint32_t dim = tfconst->dim(r).value();
134     element = element * dim + pos.at(r);
135   }
136   return tfconst->at<loco::DataType::S32>(element);
137 }
138
139 template <> float tfconst_at<float>(const moco::TFConst *tfconst, const moco::u32v_t &pos)
140 {
141   uint32_t rank = tfconst->rank();
142   assert(rank == pos.size());
143   uint32_t element = 0;
144   for (uint32_t r = 0; r < rank; ++r)
145   {
146     uint32_t dim = tfconst->dim(r).value();
147     element = element * dim + pos.at(r);
148   }
149   return tfconst->at<loco::DataType::FLOAT32>(element);
150 }
151
152 void tfconst_at(moco::TFConst *tfconst, const moco::u32v_t &pos, int32_t value)
153 {
154   // tfconst->rank() can be smaller than pos.size()
155   // i.e., tfconst: shape[3] and pos[0,1]
156   //                where shape[3] is output result shape
157   //                [0,1] is position of input const
158   uint32_t rank = pos.size();
159   uint32_t element = 0;
160   for (uint32_t r = 0; r < rank; ++r)
161   {
162     // this is like expand the shape from [3] to [1,3] to use same formula as in reading
163     uint32_t dim = tfconst->rank() < r ? tfconst->dim(r).value() : 1;
164     element = element * dim + pos.at(r);
165   }
166
167   tfconst->at<loco::DataType::S32>(element) = value;
168 }
169
170 void tfconst_at(moco::TFConst *tfconst, const moco::u32v_t &pos, float value)
171 {
172   uint32_t rank = pos.size();
173   uint32_t element = 0;
174   for (uint32_t r = 0; r < rank; ++r)
175   {
176     uint32_t dim = tfconst->rank() < r ? tfconst->dim(r).value() : 1;
177     element = element * dim + pos.at(r);
178   }
179
180   tfconst->at<loco::DataType::FLOAT32>(element) = value;
181 }
182
183 bool constantfold_stridedslice(moco::TFStridedSlice *node)
184 {
185   auto const_input = dynamic_cast<moco::TFConst *>(node->input());
186   if (const_input == nullptr)
187   {
188     // input is not TFConst, there's nothing to do
189     return false;
190   }
191
192   // TODO support full mask features: see import codes also
193   assert(node->begin_mask() == 0);
194   assert(node->end_mask() == 0);
195   assert(node->ellipsis_mask() == 0);
196   assert(node->shrink_axis_mask() == 1);
197
198   // TODO support other dtypes
199   assert(const_input->dtype() == loco::DataType::S32 ||
200          const_input->dtype() == loco::DataType::FLOAT32);
201
202   auto const_begin = dynamic_cast<moco::TFConst *>(node->begin());
203   auto const_end = dynamic_cast<moco::TFConst *>(node->end());
204   auto const_strides = dynamic_cast<moco::TFConst *>(node->strides());
205   if (const_begin == nullptr || const_end == nullptr || const_strides == nullptr)
206   {
207     return false;
208   }
209
210   // NOTE need shape but cannot depend on shape inference service module
211   auto tensor_shape = calc_output_shape(node);
212   auto input_shape = moco::tensor_shape(const_input);
213
214   auto graph = node->graph();
215
216   // Create our target TFConst node with shape from begin~end/strides
217   auto const_sliced = moco::new_const(graph, tensor_shape, const_input->dtype());
218
219   // Copy sliced elements using TensorSliceEnumerator
220   moco::TensorSliceEnumerator etor;
221   auto v_begin = vector_from_const(const_begin);
222   auto v_end = vector_from_const(const_end);
223   moco::u32v_t v_cursor;
224   moco::u32v_t v_offset;
225
226   etor.shape(input_shape);
227   etor.begin(v_begin);
228   etor.end(v_end);
229
230   for (etor.start(); etor.valid(); etor.advance())
231   {
232     v_cursor = etor.cursor();
233     v_offset = v_cursor - v_begin;
234
235     if (const_input->dtype() == loco::DataType::S32)
236     {
237       int32_t value = tfconst_at<int32_t>(const_input, v_cursor);
238       tfconst_at(const_sliced, v_offset, value);
239     }
240     else if (const_input->dtype() == loco::DataType::FLOAT32)
241     {
242       float value = tfconst_at<float>(const_input, v_cursor);
243       tfconst_at(const_sliced, v_offset, value);
244     }
245   }
246
247   // replace
248   loco::replace(node).with(const_sliced);
249
250   return true;
251 }
252
253 } // namespace
254
255 namespace moco
256 {
257
258 /**
259  * @note This will Replace TFStridedSlice with TFConst when 'input' is TFConst
260  *
261  *       Before
262  *                 A --- TFStridedSlice --- C
263  *                 B --/
264  *       After
265  *                 A --- TFStridedSlice
266  *                 B --/
267  *                       TFConst ---------- C
268  *       Where
269  *                 A,B : inputs of TFStridedSlice
270  *                 C : a node that uses TFStridedSlice as an input
271  *                 TFStridedSlice is disconnected from C
272  *                 Nodes are drawn multiple times to simplify the diagram
273  *       Limits
274  *                 Only limit set of inputs are supported for now
275  */
276 bool ConstantFoldStridedSlice::run(loco::Graph *graph)
277 {
278   bool changed = false;
279   for (auto node : loco::active_nodes(loco::output_nodes(graph)))
280   {
281     if (auto sslice_node = as<moco::TFStridedSlice>(node))
282     {
283       if (constantfold_stridedslice(sslice_node))
284         changed = true;
285     }
286   }
287
288   return changed;
289 }
290
291 } // namespace moco