13e514a7822e37df6f9ec4d83d59c801b2f41430
[platform/core/ml/nnfw.git] / compiler / moco / support / src / TFShapeInferenceHelper.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "moco/Support/TFShapeInferenceHelper.h"
18
19 #include <loco/Service/ShapeInference.h>
20
21 #include <oops/InternalExn.h>
22
23 #include <cassert>
24
25 namespace
26 {
27
28 // TODO Use codes in loco and remove duplicate broadcast_shape() and related
29 /**
30  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
31  *
32  * HOW TO USE:
33  *
34  *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
35  */
36 class TensorShapeExpander
37 {
38 public:
39   TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
40   {
41     // DO NOTHING
42   }
43
44 public:
45   loco::TensorShape to(uint32_t output_rank)
46   {
47     auto const &input_shape = _shape;
48     uint32_t const input_rank = input_shape.rank();
49
50     assert(input_rank <= output_rank && "Cannot shrink rank");
51     uint32_t const axis_shift = output_rank - input_rank;
52
53     loco::TensorShape output_shape;
54
55     output_shape.rank(output_rank);
56     for (uint32_t axis = 0; axis < output_rank; ++axis)
57     {
58       output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
59     }
60
61     return output_shape;
62   }
63
64 private:
65   const loco::TensorShape _shape;
66 };
67
68 /**
69  * @breif  Expand shape x and y to same rank by align right and filling with 1
70  */
71 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
72 {
73   auto x_rank = x.rank();
74   auto y_rank = y.rank();
75
76   if (x_rank == y_rank)
77     return;
78
79   TensorShapeExpander x_exp(x);
80   TensorShapeExpander y_exp(y);
81
82   auto xy_rank = std::max(x_rank, y_rank);
83
84   x = x_rank > y_rank ? x : x_exp.to(xy_rank);
85   y = y_rank > x_rank ? y : y_exp.to(xy_rank);
86 }
87
88 /**
89  * @breif  Returns shape of expanded dimension of input x and y having same rank
90  */
91 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
92 {
93   assert(x.rank() == y.rank());
94
95   auto rank = x.rank();
96
97   loco::TensorShape output_shape;
98
99   output_shape.rank(rank);
100   for (uint32_t axis = 0; axis < rank; ++axis)
101   {
102     assert(x.dim(axis).known() && y.dim(axis).known());
103
104     auto x_dim = x.dim(axis).value();
105     auto y_dim = y.dim(axis).value();
106
107     // each dimension of x and y should be same or one must be 1 if different
108     if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
109     {
110       // TODO may need to refine message
111       INTERNAL_EXN("ShapeInference: Input shapes don't match");
112     }
113
114     output_shape.dim(axis) = std::max(x_dim, y_dim);
115   }
116
117   return output_shape;
118 }
119
120 } // namespace
121
122 namespace moco
123 {
124
125 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
126 {
127   auto x_match = x;
128   auto y_match = y;
129
130   expand_rank(x_match, y_match);
131
132   auto output_shape = expand_dimension(x_match, y_match);
133
134   return output_shape;
135 }
136
137 } // namespace moco
138
139 namespace moco
140 {
141
142 loco::NodeShape node_shape(const loco::Node *node)
143 {
144   loco::NodeShape nodeshape; // default domain is Unknown
145
146   if (loco::shape_known(node))
147   {
148     nodeshape = loco::shape_get(node);
149   }
150
151   return nodeshape;
152 }
153
154 bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
155 {
156   nodeshape = node_shape(node);
157   return (nodeshape.domain() != loco::Domain::Unknown);
158 }
159
160 loco::TensorShape as_tensor_shape(const loco::FeatureShape &feature_shape,
161                                   const TFDataLayout &data_layout)
162 {
163   loco::TensorShape tensor_shape;
164
165   tensor_shape.rank(4);
166   if (data_layout == "NHWC")
167   {
168     tensor_shape.dim(0) = feature_shape.count();
169     tensor_shape.dim(1) = feature_shape.height();
170     tensor_shape.dim(2) = feature_shape.width();
171     tensor_shape.dim(3) = feature_shape.depth();
172   }
173   else if (data_layout == "NCHW")
174   {
175     tensor_shape.dim(0) = feature_shape.count();
176     tensor_shape.dim(1) = feature_shape.depth();
177     tensor_shape.dim(2) = feature_shape.height();
178     tensor_shape.dim(3) = feature_shape.width();
179   }
180   else
181   {
182     // TODO support for other data_layout if needed
183     INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
184   }
185
186   return tensor_shape;
187 }
188
189 loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
190                                     const TFDataLayout &data_layout)
191 {
192   if (nodeshape.domain() == loco::Domain::Feature)
193     return nodeshape.as<loco::FeatureShape>();
194
195   loco::FeatureShape feature_shape;
196
197   // only convert from tensor to feature
198   if (nodeshape.domain() != loco::Domain::Tensor)
199   {
200     INTERNAL_EXN("ShapeInference: Invalid shape information");
201   }
202
203   loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
204
205   if (tensor_shape.rank() != 4)
206   {
207     INTERNAL_EXN("ShapeInference: Rank is not 4");
208   }
209
210   if (data_layout == "NHWC")
211   {
212     feature_shape.count() = tensor_shape.dim(0);
213     feature_shape.height() = tensor_shape.dim(1);
214     feature_shape.width() = tensor_shape.dim(2);
215     feature_shape.depth() = tensor_shape.dim(3);
216   }
217   else if (data_layout == "NCHW")
218   {
219     feature_shape.count() = tensor_shape.dim(0);
220     feature_shape.depth() = tensor_shape.dim(1);
221     feature_shape.height() = tensor_shape.dim(2);
222     feature_shape.width() = tensor_shape.dim(3);
223   }
224   else
225   {
226     // TODO support for other data_layout if needed
227     INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
228   }
229
230   return feature_shape;
231 }
232
233 } // namespace moco
234
235 namespace moco
236 {
237
238 PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
239 {
240   PlaneShape plane_shape;
241
242   plane_shape.height = feature_shape.height();
243   plane_shape.width = feature_shape.width();
244
245   return plane_shape;
246 }
247
248 FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
249 {
250   return FeatureShapeUpdater{&feature_shape};
251 }
252
253 } // namespace moco
254
255 namespace
256 {
257
258 /**
259  * @brief Class to represent TensorFlow "data_format" attr.
260  */
261 enum class DataLayout
262 {
263   NHWC,
264   NCHW,
265 };
266
267 DataLayout as_data_layout(const std::string &tf_layout_str)
268 {
269   if (tf_layout_str == "NHWC")
270     return DataLayout::NHWC;
271   else if (tf_layout_str == "NCHW")
272     return DataLayout::NCHW;
273   else
274     /// @note data layout tag in TensorFlow is 'data_format'
275     INTERNAL_EXN_V("ShapeInference: Unknown data_format", tf_layout_str);
276 }
277
278 } // namespace
279
280 namespace moco
281 {
282
283 loco::Stride<2> stride_of(const TFStrides &strides, const TFDataLayout &datalayout)
284 {
285   loco::Stride<2> stride;
286
287   auto data_layout = as_data_layout(datalayout);
288   if (data_layout == DataLayout::NHWC)
289   {
290     stride.vertical(strides[1]);
291     stride.horizontal(strides[2]);
292   }
293   else if (data_layout == DataLayout::NCHW)
294   {
295     stride.vertical(strides[2]);
296     stride.horizontal(strides[3]);
297   }
298   else
299   {
300     // TODO add more datalayout supports if needed
301     INTERNAL_EXN("ShapeInference: Unknown data_format");
302   }
303
304   return stride;
305 }
306
307 loco::Window<2> window_of(const TFKSize &ksize, const TFDataLayout &datalayout)
308 {
309   loco::Window<2> window;
310
311   auto data_layout = as_data_layout(datalayout);
312   if (data_layout == DataLayout::NHWC)
313   {
314     window.vertical(ksize[1]);
315     window.horizontal(ksize[2]);
316   }
317   else if (data_layout == DataLayout::NCHW)
318   {
319     window.vertical(ksize[2]);
320     window.horizontal(ksize[3]);
321   }
322   else
323   {
324     // TODO add more datalayout supports if needed
325     INTERNAL_EXN("ShapeInference: Unknown data_format");
326   }
327
328   return window;
329 }
330
331 loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &datalayout)
332 {
333   loco::Window<2> window;
334
335   if (datalayout == "HWIO")
336   {
337     window.vertical(shape.dim(0).value());
338     window.horizontal(shape.dim(1).value());
339   }
340   else if (datalayout == "HWCM")
341   {
342     window.vertical(shape.dim(0).value());
343     window.horizontal(shape.dim(1).value());
344   }
345   else
346   {
347     // TODO add more datalayout supports if needed
348     INTERNAL_EXN_V("ShapeInference: Unknown data_format", datalayout);
349   }
350
351   return window;
352 }
353
354 } // namespace moco