f4bb1036401fc012f47420f95ec821636f7fcffd
[platform/core/ml/nnfw.git] / compiler / exo / src / Dialect / Service / TFLShapeInferenceRule.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TFLShapeInferenceRule.h"
18
19 #include "Dialect/IR/TFLNodes.h"
20 #include "Dialect/IR/TFLDialect.h"
21 #include "Dialect/IR/TFLNodeVisitor.h"
22
23 #include "Check.h"
24
25 #include <oops/InternalExn.h>
26
27 #include <algorithm>
28 #include <cassert>
29 #include <stdexcept>
30
31 namespace
32 {
33
34 // Call this for TFLAvgPool2D and TFLMaxPool2D only
35 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
36 {
37   EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known");
38
39   auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
40   assert(ifm_shape.rank() == 4);
41
42   uint32_t input_height = ifm_shape.dim(1).value();
43   uint32_t input_width = ifm_shape.dim(2).value();
44   uint32_t stride_height = node->stride()->h();
45   uint32_t stride_width = node->stride()->w();
46   uint32_t window_height = node->filter()->h();
47   uint32_t window_width = node->filter()->w();
48   uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1
49   uint32_t dilation_width = 1;
50   uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
51   uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
52
53   uint32_t output_height = 0;
54   uint32_t output_width = 0;
55
56   if (node->padding() == locoex::Padding::VALID)
57   {
58     output_height = (input_height + stride_height - effective_window_height) / stride_height;
59     output_width = (input_width + stride_width - effective_window_width) / stride_width;
60   }
61   else if (node->padding() == locoex::Padding::SAME)
62   {
63     output_height = (input_height + stride_height - 1) / stride_height;
64     output_width = (input_width + stride_width - 1) / stride_width;
65   }
66   else
67     EXO_ASSERT(false, "Wrong padding type");
68
69   loco::TensorShape ofm_shape;
70   ofm_shape.rank(4);
71   ofm_shape.dim(0) = ifm_shape.dim(0);
72   ofm_shape.dim(1) = output_height;
73   ofm_shape.dim(2) = output_width;
74   ofm_shape.dim(3) = ifm_shape.dim(3);
75
76   return loco::NodeShape{ofm_shape};
77 }
78
79 /**
80  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
81  *
82  * HOW TO USE:
83  *
84  *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
85  */
86 class TensorShapeExpander
87 {
88 public:
89   TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
90   {
91     // DO NOTHING
92   }
93
94 public:
95   loco::TensorShape to(uint32_t output_rank)
96   {
97     auto const &input_shape = _shape;
98     uint32_t const input_rank = input_shape.rank();
99
100     assert(input_rank <= output_rank && "Cannot shrink rank");
101     uint32_t const axis_shift = output_rank - input_rank;
102
103     loco::TensorShape output_shape;
104
105     output_shape.rank(output_rank);
106     for (uint32_t axis = 0; axis < output_rank; ++axis)
107     {
108       output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
109     }
110
111     return output_shape;
112   }
113
114 private:
115   const loco::TensorShape _shape;
116 };
117
118 /**
119  * @breif  Expand shape x and y to same rank by align right and filling with 1
120  */
121 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
122 {
123   auto x_rank = x.rank();
124   auto y_rank = y.rank();
125
126   if (x_rank == y_rank)
127     return;
128
129   TensorShapeExpander x_exp(x);
130   TensorShapeExpander y_exp(y);
131
132   auto xy_rank = std::max(x_rank, y_rank);
133
134   x = x_rank > y_rank ? x : x_exp.to(xy_rank);
135   y = y_rank > x_rank ? y : y_exp.to(xy_rank);
136 }
137
138 /**
139  * @breif  Returns shape of expanded dimension of input x and y having same rank
140  */
141 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
142 {
143   assert(x.rank() == y.rank());
144
145   auto rank = x.rank();
146
147   loco::TensorShape output_shape;
148
149   output_shape.rank(rank);
150   for (uint32_t axis = 0; axis < rank; ++axis)
151   {
152     assert(x.dim(axis).known() && y.dim(axis).known());
153
154     auto x_dim = x.dim(axis).value();
155     auto y_dim = y.dim(axis).value();
156
157     // each dimension of x and y should be same or one must be 1 if different
158     if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
159       INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
160
161     output_shape.dim(axis) = std::max(x_dim, y_dim);
162   }
163
164   return output_shape;
165 }
166
167 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
168 {
169   auto x_match = x;
170   auto y_match = y;
171
172   expand_rank(x_match, y_match);
173
174   auto output_shape = expand_dimension(x_match, y_match);
175
176   return output_shape;
177 }
178
179 /**
180  * @brief Class to infer the shape of TFLNode
181  *
182  * @note All TFLNode's inputs and outputs are always loco::Domain::Tensor
183  */
184 class ShapeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::NodeShape>
185 {
186 public:
187   loco::NodeShape visit(const locoex::TFLAdd *node) final
188   {
189     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
190     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
191
192     auto output_shape = broadcast_shape(x_shape, y_shape);
193
194     return loco::NodeShape{output_shape};
195   }
196
197   loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
198   {
199     return infer_pool_2d_shape(node);
200   }
201
202   loco::NodeShape visit(const locoex::TFLConcatenation *node) final
203   {
204     // TODO Support when TFLConcatenation has 0 input
205     assert(node->numValues() > 0);
206
207     auto axis = node->axis();
208     auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
209
210     loco::TensorShape output_shape;
211
212     output_shape.rank(first_shape.rank());
213     for (uint32_t i = 0; i < output_shape.rank(); ++i)
214       output_shape.dim(i) = first_shape.dim(i);
215
216     for (uint32_t i = 1; i < node->numValues(); ++i)
217     {
218       auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
219
220       for (uint32_t j = 0; j < output_shape.rank(); ++j)
221       {
222         if (j == axis)
223           output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
224         else
225           assert(output_shape.dim(j) == input_shape.dim(j));
226       }
227     }
228
229     return loco::NodeShape{output_shape};
230   }
231
232   loco::NodeShape visit(const locoex::TFLConst *node) final
233   {
234     loco::TensorShape shape;
235
236     shape.rank(node->rank());
237     for (uint32_t axis = 0; axis < node->rank(); axis++)
238       shape.dim(axis) = node->dim(axis);
239
240     return loco::NodeShape{shape};
241   }
242
243   loco::NodeShape visit(const locoex::TFLConv2D *node) final
244   {
245     auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
246     auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
247
248     assert(ifm_shape.rank() == 4);
249     assert(ker_shape.rank() == 4);
250     assert(ifm_shape.dim(3) == ker_shape.dim(3));
251
252     uint32_t input_height = ifm_shape.dim(1).value();
253     uint32_t input_width = ifm_shape.dim(2).value();
254     uint32_t stride_height = node->stride()->h();
255     uint32_t stride_width = node->stride()->w();
256     uint32_t ker_height = ker_shape.dim(1).value();
257     uint32_t ker_width = ker_shape.dim(2).value();
258     uint32_t dilation_height = 1;
259     uint32_t dilation_width = 1;
260     uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
261     uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
262
263     uint32_t output_height = 0;
264     uint32_t output_width = 0;
265
266     if (node->padding() == locoex::Padding::VALID)
267     {
268       output_height = (input_height + stride_height - effective_ker_height) / stride_height;
269       output_width = (input_width + stride_width - effective_ker_width) / stride_width;
270     }
271     else if (node->padding() == locoex::Padding::SAME)
272     {
273       output_height = (input_height + stride_height - 1) / stride_height;
274       output_width = (input_width + stride_width - 1) / stride_width;
275     }
276     else
277       EXO_ASSERT(false, "Wrong padding type");
278
279     loco::TensorShape ofm_shape;
280     ofm_shape.rank(4);
281     ofm_shape.dim(0) = ifm_shape.dim(0);
282     ofm_shape.dim(1) = output_height;
283     ofm_shape.dim(2) = output_width;
284     ofm_shape.dim(3) = ker_shape.dim(0);
285
286     return loco::NodeShape{ofm_shape};
287   }
288
289   loco::NodeShape visit(const locoex::TFLDepthwiseConv2D *node) final
290   {
291     auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
292     auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
293
294     assert(ifm_shape.rank() == 4);
295     assert(ker_shape.rank() == 4);
296     assert(ker_shape.dim(0).value() == 1);
297
298     uint32_t input_height = ifm_shape.dim(1).value();
299     uint32_t input_width = ifm_shape.dim(2).value();
300     uint32_t stride_height = node->stride()->h();
301     uint32_t stride_width = node->stride()->w();
302     uint32_t ker_height = ker_shape.dim(1).value();
303     uint32_t ker_width = ker_shape.dim(2).value();
304     uint32_t dilation_height = 1;
305     uint32_t dilation_width = 1;
306     uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
307     uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
308
309     uint32_t output_height = 0;
310     uint32_t output_width = 0;
311
312     if (node->padding() == locoex::Padding::VALID)
313     {
314       output_height = (input_height + stride_height - effective_ker_height) / stride_height;
315       output_width = (input_width + stride_width - effective_ker_width) / stride_width;
316     }
317     else if (node->padding() == locoex::Padding::SAME)
318     {
319       output_height = (input_height + stride_height - 1) / stride_height;
320       output_width = (input_width + stride_width - 1) / stride_width;
321     }
322     else
323       EXO_ASSERT(false, "Wrong padding type");
324
325     loco::TensorShape ofm_shape;
326     ofm_shape.rank(4);
327     ofm_shape.dim(0) = ifm_shape.dim(0);
328     ofm_shape.dim(1) = output_height;
329     ofm_shape.dim(2) = output_width;
330     ofm_shape.dim(3) = ker_shape.dim(3);
331
332     return loco::NodeShape{ofm_shape};
333   }
334
335   loco::NodeShape visit(const locoex::TFLDiv *node) final
336   {
337     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
338     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
339
340     auto output_shape = broadcast_shape(x_shape, y_shape);
341
342     return loco::NodeShape{output_shape};
343   }
344
345   loco::NodeShape visit(const locoex::TFLFullyConnected *node) final
346   {
347     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
348     auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
349
350     // Checking shape capability for multiplication
351     EXO_ASSERT(input_shape.rank() == 2, "NYI for input shape rank > 2");
352     EXO_ASSERT(weights_shape.rank() == 2, "Incompatible weights rank for fully connected");
353     EXO_ASSERT(input_shape.dim(1) == weights_shape.dim(1),
354                "Incompatible shapes for fully connected");
355
356     loco::TensorShape out_shape;
357     out_shape.rank(2);
358
359     out_shape.dim(0) = input_shape.dim(0);
360     out_shape.dim(1) = weights_shape.dim(0);
361
362     return loco::NodeShape{out_shape};
363   }
364
365   loco::NodeShape visit(const locoex::TFLMaximum *node) final
366   {
367     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
368     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
369
370     auto output_shape = broadcast_shape(x_shape, y_shape);
371
372     return loco::NodeShape{output_shape};
373   }
374
375   loco::NodeShape visit(const locoex::TFLMaxPool2D *node) final
376   {
377     return infer_pool_2d_shape(node);
378   }
379
380   loco::NodeShape visit(const locoex::TFLMean *node) final
381   {
382     const loco::DataType S32 = loco::DataType::S32;
383
384     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
385     auto reduction_indices = dynamic_cast<locoex::TFLConst *>(node->reduction_indices());
386
387     { // Exceptions
388       // TODO support non-const case
389       EXO_ASSERT(reduction_indices, "Only support constant reduction_indices");
390       // TODO support other data type
391       EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
392     }
393
394     std::vector<int32_t> reduction_values;
395
396     for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
397     {
398       int32_t axis = reduction_indices->at<S32>(i);
399       if (axis < 0)
400         axis += input_shape.rank();
401       if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
402         INTERNAL_EXN_V("Invalid reduction axis for MEAN", oops::to_uint32(axis));
403       reduction_values.push_back(axis);
404     }
405
406     loco::TensorShape output_shape;
407
408     if (node->keep_dims())
409     {
410       output_shape.rank(input_shape.rank());
411       for (uint32_t i = 0; i < input_shape.rank(); ++i)
412         output_shape.dim(i) = input_shape.dim(i);
413       for (uint32_t i = 0; i < reduction_values.size(); ++i)
414         output_shape.dim(reduction_values.at(i)) = 1;
415     }
416     else
417     {
418       std::vector<bool> check_reduce(input_shape.rank(), false);
419       for (uint32_t i = 0; i < reduction_values.size(); ++i)
420         check_reduce.at(reduction_values.at(i)) = true;
421
422       uint32_t reduce_cnt = 0;
423       for (uint32_t i = 0; i < check_reduce.size(); ++i)
424         if (check_reduce.at(i))
425           ++reduce_cnt;
426
427       output_shape.rank(input_shape.rank() - reduce_cnt);
428       for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
429         if (check_reduce.at(i) == false)
430           output_shape.dim(j++) = i;
431     }
432
433     return loco::NodeShape{output_shape};
434   }
435
436   loco::NodeShape visit(const locoex::TFLMul *node) final
437   {
438     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
439     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
440
441     auto output_shape = broadcast_shape(x_shape, y_shape);
442
443     return loco::NodeShape{output_shape};
444   }
445
446   loco::NodeShape visit(const locoex::TFLRelu *node) final
447   {
448     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
449
450     return loco::NodeShape{input_shape};
451   }
452
453   loco::NodeShape visit(const locoex::TFLRelu6 *node) final
454   {
455     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
456
457     return loco::NodeShape{input_shape};
458   }
459
460   /**
461    * @note  TFLReshape has new shape info in two places: 2nd input and attribute.
462    *        This shape inference forces both to exist, and match each other.
463    *        When this condition satisfied, it return the inferred shape
464    *
465    * TODO Change this policy when not appropriate
466    */
467   loco::NodeShape visit(const locoex::TFLReshape *node) final
468   {
469     const loco::DataType S32 = loco::DataType::S32;
470
471     loco::TensorShape shape_by_input;
472     {
473       EXO_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
474
475       // Only support node's shape() is TFLConst with S32
476       // TODO support other node with other types
477       auto const_shape_node = dynamic_cast<locoex::TFLConst *>(node->shape());
478       EXO_ASSERT(const_shape_node, "Only support TFLConst for shape of TFLReshape");
479       EXO_ASSERT(const_shape_node->dtype() == S32, "Only support int32 TFLConst");
480
481       if (const_shape_node->rank() != 1)
482         INTERNAL_EXN_V("Only support rank 1 TFLConst", oops::to_uint32(const_shape_node->rank()));
483
484       shape_by_input.rank(const_shape_node->dim(0).value());
485
486       for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
487       {
488         EXO_ASSERT(const_shape_node->at<S32>(axis) > 0, "Dimension should be > 0")
489         shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
490       }
491     }
492
493     loco::TensorShape shape_by_attr;
494     {
495       shape_by_attr.rank(node->newShape()->rank());
496
497       for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
498       {
499         EXO_ASSERT(node->newShape()->dim(axis) > 0, "Dimension should be > 0")
500         shape_by_attr.dim(axis) = node->newShape()->dim(axis);
501       }
502     }
503
504     EXO_ASSERT(shape_by_input == shape_by_attr,
505                "Warning: Two new shape information mismatched for TFLReshape");
506
507     return loco::NodeShape{shape_by_input};
508   }
509
510   loco::NodeShape visit(const locoex::TFLRsqrt *node) final
511   {
512     auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
513
514     return loco::NodeShape{input_shape};
515   }
516
517   // TODO TFLSoftmax
518
519   loco::NodeShape visit(const locoex::TFLSqrt *node) final
520   {
521     auto input_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
522
523     return loco::NodeShape{input_shape};
524   }
525
526   loco::NodeShape visit(const locoex::TFLSquaredDifference *node) final
527   {
528     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
529     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
530
531     auto output_shape = broadcast_shape(x_shape, y_shape);
532
533     return loco::NodeShape{output_shape};
534   }
535
536   loco::NodeShape visit(const locoex::TFLSub *node) final
537   {
538     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
539     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
540
541     auto output_shape = broadcast_shape(x_shape, y_shape);
542
543     return loco::NodeShape{output_shape};
544   }
545
546   // TODO TFLTanh
547
548   /// @brief Returns output shape of transpose. Use loco::ConstGen and locoex::TFLConst for ConstT.
549   template <class ConstT>
550   loco::TensorShape output_shape_of_transpose(loco::TensorShape input_shape,
551                                               const ConstT *perm_node)
552   {
553     loco::TensorShape output_shape;
554     output_shape.rank(input_shape.rank());
555
556     assert(perm_node->dtype() == loco::DataType::S32);
557     assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
558
559     for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
560     {
561       auto new_dim = perm_node->template at<loco::DataType::S32>(out_axis);
562       output_shape.dim(new_dim) = input_shape.dim(out_axis);
563     }
564
565     return output_shape;
566   }
567
568   loco::NodeShape visit(const locoex::TFLTranspose *node) final
569   {
570     auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
571
572     auto canon_perm = dynamic_cast<loco::ConstGen *>(node->perm());
573     auto tfl_perm = dynamic_cast<locoex::TFLConst *>(node->perm());
574
575     if (canon_perm)
576     {
577       return loco::NodeShape{output_shape_of_transpose(input_shape, canon_perm)};
578     }
579     else if (tfl_perm)
580     {
581       return loco::NodeShape{output_shape_of_transpose(input_shape, tfl_perm)};
582     }
583     else
584       INTERNAL_EXN("perm of TFLTranspose should be either ConstGen or TFLConst");
585   }
586
587   loco::NodeShape visit(const locoex::TFLTransposeConv *node) final
588   {
589     // TransposeConv's output shape is written in its 'inputSizes' argument
590     auto input_sizes_const = dynamic_cast<locoex::TFLConst *>(node->inputSizes());
591     EXO_ASSERT(input_sizes_const, "Only support when TFLTransposeConv's inputSizes is TFLConst")
592     EXO_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
593     EXO_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
594                "Only support rank 1 with 4 entries")
595
596     loco::TensorShape shape;
597
598     shape.rank(4);
599     for (uint32_t axis = 0; axis < 4; ++axis)
600       shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
601
602     return loco::NodeShape{shape};
603   }
604 };
605
606 } // namespace
607
608 namespace locoex
609 {
610
611 bool TFLShapeInferenceRule::recognize(const loco::Dialect *d) const
612 {
613   return TFLDialect::get() == d;
614 }
615
616 bool TFLShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
617 {
618   assert(node->dialect() == TFLDialect::get());
619   assert(dynamic_cast<const TFLNode *>(node) != nullptr);
620
621   ShapeInferenceAlgorithm alg;
622   shape = dynamic_cast<const TFLNode *>(node)->accept(&alg);
623
624   return true;
625 }
626
627 } // namespace locoex