Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / moco / service / src / Service / TFShapeInferenceRule.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "moco/Service/TFShapeInferenceRule.h"
18
19 #include <moco/Support/TFShapeInferenceHelper.h>
20
21 #include "moco/IR/TFDialect.h"
22 #include "moco/IR/TFNode.h"
23
24 #include <loco/IR/NodeShape.h>
25 #include <loco/Service/ShapeInference.h>
26
27 #include <oops/UserExn.h>
28
29 #include <cassert>
30 #include <cmath>
31
32 namespace
33 {
34
35 class ShapeInferenceAlgorithm final : public moco::TFNodeVisitor<loco::NodeShape>
36 {
37 public:
38   ShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
39   {
40     // DO NOTHING
41   }
42
43 private:
44   const loco::ShapeInferenceRule::Context *_ctx;
45
46 private:
47   bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
48   loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
49
50 private:
51   loco::NodeShape binary_node_shape(const moco::TFNode::Node *node)
52   {
53     // This helper works only for binary node.
54     assert(node->arity() == 2);
55
56     auto lhs_shape = node_shape(node->arg(0));
57     auto rhs_shape = node_shape(node->arg(1));
58
59     loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>();
60     loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>();
61     loco::TensorShape sum_tensorshape = moco::broadcast_shape(lhs_tensorshape, rhs_tensorshape);
62
63     loco::NodeShape sum_shape({sum_tensorshape});
64
65     return sum_shape;
66   }
67
68   loco::NodeShape node_shape_with_check(const moco::TFNode::Node *node)
69   {
70     auto nodeshape = node_shape(node);
71     assert(nodeshape.domain() == loco::Domain::Tensor);
72
73     return nodeshape;
74   }
75
76   bool valid_scalar_value(moco::TFConst *node)
77   {
78     auto nodeshape = node_shape(node);
79     if (nodeshape.domain() != loco::Domain::Tensor)
80     {
81       return false;
82     }
83     if (node->dtype() != loco::DataType::S32)
84     {
85       return false;
86     }
87
88     auto tensor_shape = nodeshape.as<loco::TensorShape>();
89     if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
90     {
91       return false;
92     }
93
94     return true;
95   }
96
97   int32_t scalar_value(moco::TFConst *node)
98   {
99     auto nodeshape = node_shape(node);
100     assert(node->dtype() == loco::DataType::S32);
101
102     auto tensor_shape = nodeshape.as<loco::TensorShape>();
103     assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
104
105     return node->at<loco::DataType::S32>(0);
106   }
107
108 public:
109   loco::NodeShape visit(const moco::TFAdd *node) final { return binary_node_shape(node); }
110
111   loco::NodeShape visit(const moco::TFAvgPool *node) final
112   {
113     auto value_shape = node_shape(node->value());
114     assert(value_shape.domain() != loco::Domain::Unknown);
115
116     moco::PlaneInference infer_plane_shape;
117
118     infer_plane_shape.padding(node->padding());
119     infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
120     infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
121
122     auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
123     auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
124     auto output_feature_shape = input_feature_shape;
125     auto output_plane_shape = infer_plane_shape(input_plane_shape);
126
127     moco::update(output_feature_shape).with(output_plane_shape);
128
129     return moco::as_tensor_shape(output_feature_shape, node->data_layout());
130   }
131
132   loco::NodeShape visit(const moco::TFBiasAdd *node) final
133   {
134     return node_shape_with_check(node->value());
135   }
136
137   loco::NodeShape visit(const moco::TFConcatV2 *node) final
138   {
139     // axis shape should be available
140     auto axis_node = node->axis();
141     auto axis_shape = node_shape(axis_node);
142     assert(axis_shape.domain() != loco::Domain::Unknown);
143
144     // check all input shapes and all ranks should be same
145     auto value_a = node->values(0);
146     auto value_a_shape = node_shape(value_a);
147     assert(value_a_shape.domain() == loco::Domain::Tensor);
148     auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
149     uint32_t a_rank = value_a_tensor_shape.rank();
150
151     uint32_t num_values = node->num_values();
152     for (uint32_t ni = 1; ni < num_values; ++ni)
153     {
154       auto value_b = node->values(ni);
155       auto value_b_shape = node_shape(value_b);
156       assert(value_b_shape.domain() == loco::Domain::Tensor);
157       auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
158       assert(a_rank == value_b_tensor_shape.rank());
159     }
160
161     int32_t axis_value = 0;
162     bool axis_available = false;
163     {
164       // check for axis is TFConst
165       auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
166       if (tfconst != nullptr)
167       {
168         if (valid_scalar_value(tfconst))
169         {
170           axis_value = scalar_value(tfconst);
171           axis_available = true;
172         }
173       }
174     }
175     if (!axis_available)
176     {
177       // TODO may need to refine error message
178       throw oops::UserExn("ConcatV2 node does not have axis input", node->name());
179     }
180
181     uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
182     loco::TensorShape output_tensor_shape = value_a_tensor_shape;
183
184     for (uint32_t index = 0; index < a_rank; ++index)
185     {
186       if (value_a_tensor_shape.dim(index).known())
187       {
188         uint32_t dim = value_a_tensor_shape.dim(index).value();
189         uint32_t dim_acc = dim;
190
191         for (uint32_t ni = 1; ni < num_values; ++ni)
192         {
193           auto value_b = node->values(ni);
194           auto value_b_shape = node_shape(value_b);
195           assert(value_b_shape.domain() == loco::Domain::Tensor);
196           auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
197           assert(value_b_tensor_shape.dim(index).known());
198           if (index == axis_absolute)
199             dim_acc += value_b_tensor_shape.dim(index).value();
200           else
201             assert(dim == value_b_tensor_shape.dim(index).value());
202         }
203         output_tensor_shape.dim(index) = dim_acc;
204       }
205       else
206         output_tensor_shape.dim(index).unset();
207     }
208     return loco::NodeShape(output_tensor_shape);
209   }
210
211   loco::NodeShape visit(const moco::TFConst *node) final
212   {
213     loco::TensorShape output_tensor_shape;
214
215     uint32_t rank = node->rank();
216     output_tensor_shape.rank(rank);
217     for (uint32_t index = 0; index < rank; ++index)
218     {
219       if (node->dim(index).known())
220         output_tensor_shape.dim(index) = node->dim(index).value();
221       else
222         output_tensor_shape.dim(index).unset();
223     }
224
225     return loco::NodeShape(output_tensor_shape);
226   }
227
228   loco::NodeShape visit(const moco::TFConv2D *node) final
229   {
230     auto input_shape = moco::node_shape(node->input());
231     auto ker_shape = moco::node_shape(node->filter());
232     auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
233     auto node_stride = moco::stride_of(node->strides(), node->data_layout());
234     auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
235
236     moco::PlaneInference infer_plane_shape;
237
238     infer_plane_shape.padding(node->padding());
239     infer_plane_shape.stride(node_stride);
240     infer_plane_shape.window(node_window);
241
242     auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
243     auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
244     // output count is from input count, depth is from kernel 'O' which is dim(3)
245     auto output_feature_shape = input_feature_shape;
246     output_feature_shape.depth() = ker_tensor_shape.dim(3).value();
247
248     auto output_plane_shape = infer_plane_shape(input_plane_shape);
249
250     moco::update(output_feature_shape).with(output_plane_shape);
251
252     return moco::as_tensor_shape(output_feature_shape, node->data_layout());
253   }
254
255   loco::NodeShape visit(const moco::TFConv2DBackpropInput *node) final
256   {
257     // TFConv2DBackpropInput's first input, named 'input_sizes', actually contains shape of node
258     // output's feature map. We can get shape of TFConv2DBackpropInput by just copying this.
259     // TODO Support when 'input_sizes' is not TFConst, or support constant folding
260     auto input_sizes_node = dynamic_cast<moco::TFConst *>(node->input_sizes());
261     if (input_sizes_node == nullptr)
262     {
263       // we are now supporting somekind of constant folding for this node, wait till it is finished
264       loco::NodeShape unknown;
265       return unknown;
266     }
267
268     // Let's support S32 for time being
269     // TODO Support other integer types
270     assert(input_sizes_node->dtype() == loco::DataType::S32);
271     assert(input_sizes_node->size<loco::DataType::S32>() == 4);
272
273     // copy!
274     loco::TensorShape ofm_tensor_shape;
275     ofm_tensor_shape.rank(4);
276     for (uint32_t i = 0; i < 4; ++i)
277     {
278       int32_t dim = input_sizes_node->at<loco::DataType::S32>(i);
279       assert(dim > 0);
280       ofm_tensor_shape.dim(i) = (uint32_t)dim;
281     }
282
283     return loco::NodeShape(ofm_tensor_shape);
284   }
285
286   loco::NodeShape visit(const moco::TFDepthwiseConv2dNative *node) final
287   {
288     auto input_shape = moco::node_shape(node->input()); // NHWC
289     auto ker_shape = moco::node_shape(node->filter());
290     auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
291     auto node_stride = moco::stride_of(node->strides(), node->data_layout());
292     auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
293
294     moco::PlaneInference infer_plane_shape;
295
296     infer_plane_shape.padding(node->padding());
297     infer_plane_shape.stride(node_stride);
298     infer_plane_shape.window(node_window);
299
300     auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
301     auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
302     // output count is from input count, depth is from kernel 'CM' which is dim(2) * dim(3)
303     auto output_feature_shape = input_feature_shape;
304     output_feature_shape.depth() =
305         loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
306
307     auto output_plane_shape = infer_plane_shape(input_plane_shape);
308
309     moco::update(output_feature_shape).with(output_plane_shape);
310
311     return moco::as_tensor_shape(output_feature_shape, node->data_layout());
312   }
313
314   loco::NodeShape visit(const moco::TFFakeQuantWithMinMaxVars *node) final
315   {
316     return node_shape_with_check(node->inputs());
317   }
318
319   loco::NodeShape visit(const moco::TFFusedBatchNorm *node) final
320   {
321     return node_shape_with_check(node->x());
322   }
323
324   loco::NodeShape visit(const moco::TFIdentity *node) final
325   {
326     return node_shape_with_check(node->input());
327   }
328
329   loco::NodeShape visit(const moco::TFMaximum *node) final { return binary_node_shape(node); }
330
331   loco::NodeShape visit(const moco::TFMaxPool *node) final
332   {
333     auto input_shape = node_shape(node->input());
334     assert(input_shape.domain() != loco::Domain::Unknown);
335
336     moco::PlaneInference infer_plane_shape;
337
338     infer_plane_shape.padding(node->padding());
339     infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
340     infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
341
342     auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
343     auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
344     auto output_feature_shape = input_feature_shape;
345     auto output_plane_shape = infer_plane_shape(input_plane_shape);
346
347     moco::update(output_feature_shape).with(output_plane_shape);
348
349     return moco::as_tensor_shape(output_feature_shape, node->data_layout());
350   }
351
352   loco::NodeShape visit(const moco::TFMean *node) final
353   {
354     auto input_shape = node_shape(node->input());
355     auto reduction_indices = node->reduction_indices();
356
357     // Get constant values if reduction_indices is const
358     std::vector<int32_t> reduction_values;
359     if (auto tfconst = dynamic_cast<moco::TFConst *>(reduction_indices))
360     {
361       assert(tfconst->dtype() == loco::DataType::S32);
362       auto const_size = tfconst->size<loco::DataType::S32>();
363       for (uint32_t i = 0; i < const_size; ++i)
364       {
365         int32_t axis = tfconst->at<loco::DataType::S32>(i);
366         if (axis < 0)
367           axis += input_shape.as<loco::TensorShape>().rank();
368         reduction_values.push_back(axis);
369       }
370     }
371     else
372     {
373       // we cannot find a valid reduction indices value
374       loco::NodeShape unknown;
375       return unknown;
376     }
377
378     loco::TensorShape output_shape;
379     auto input_tensor_shape = input_shape.as<loco::TensorShape>();
380
381     if (node->keep_dims())
382     {
383       output_shape.rank(input_tensor_shape.rank());
384       for (uint32_t i = 0; i < input_tensor_shape.rank(); ++i)
385         output_shape.dim(i) = input_tensor_shape.dim(i);
386       for (uint32_t i = 0; i < reduction_values.size(); ++i)
387         output_shape.dim(reduction_values.at(i)) = 1;
388     }
389     else
390     {
391       std::vector<bool> check_reduce(input_tensor_shape.rank(), false);
392       for (uint32_t i = 0; i < reduction_values.size(); ++i)
393         check_reduce.at(reduction_values.at(i)) = true;
394
395       uint32_t reduce_cnt = 0;
396       for (uint32_t i = 0; i < check_reduce.size(); ++i)
397         if (check_reduce.at(i))
398           ++reduce_cnt;
399
400       output_shape.rank(input_tensor_shape.rank() - reduce_cnt);
401       for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
402         if (check_reduce.at(i) == false)
403           output_shape.dim(j++) = i;
404     }
405
406     return loco::NodeShape(output_shape);
407   }
408
409   loco::NodeShape visit(const moco::TFMul *node) final { return binary_node_shape(node); }
410
411   loco::NodeShape visit(const moco::TFPack *node) final
412   {
413     loco::NodeShape unknown;
414
415     auto input_shape_0 = node_shape(node->values(0));
416     if (input_shape_0.domain() != loco::Domain::Tensor)
417     {
418       // TODO fix this for other cases
419       // We support only valid tensor shape for now
420       return unknown;
421     }
422     loco::TensorShape tensor_shape_0 = input_shape_0.as<loco::TensorShape>();
423
424     // all input shapes should be same
425     auto num_values = node->N();
426     for (uint32_t i = 1; i < num_values; ++i)
427     {
428       auto input_shape = node_shape(node->values(i));
429       if (input_shape.domain() != loco::Domain::Tensor)
430       {
431         // TODO ditto
432         return unknown;
433       }
434
435       loco::TensorShape tensor_shape = input_shape.as<loco::TensorShape>();
436       if (!(input_shape_0 == input_shape))
437       {
438         throw oops::UserExn("All input values shape should be same", node->name());
439       }
440     }
441
442     // output rank will be +1 of rank of the input
443     // axis should be in range of [-r, r), where r is rank of the output
444     auto axis = node->axis();
445     int32_t rank = static_cast<int32_t>(tensor_shape_0.rank());
446     assert(rank >= 0);
447     int32_t rank_output = rank + 1;
448     if (axis < -rank_output || rank_output <= axis)
449     {
450       throw oops::UserExn("axis is out of range", node->name());
451     }
452
453     auto axis_stack = (axis >= 0) ? axis : rank_output + axis;
454
455     loco::TensorShape output_tensor_shape;
456
457     output_tensor_shape.rank(rank_output);
458     for (int32_t r = 0; r < axis_stack; ++r)
459     {
460       output_tensor_shape.dim(r).set(tensor_shape_0.dim(r).value());
461     }
462     output_tensor_shape.dim(axis_stack).set(num_values);
463     for (int32_t r = axis_stack; r < rank; ++r)
464     {
465       output_tensor_shape.dim(r + 1).set(tensor_shape_0.dim(r).value());
466     }
467
468     return loco::NodeShape(output_tensor_shape);
469   }
470
471   loco::NodeShape visit(const moco::TFPad *node) final
472   {
473     auto input_shape = node_shape(node->input());
474     assert(input_shape.domain() == loco::Domain::Tensor);
475
476     auto const_paddings = loco::must_cast<moco::TFConst *>(node->paddings());
477     assert(const_paddings->dtype() == loco::DataType::S32);
478     assert(const_paddings->rank() == 2);
479
480     loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
481     loco::TensorShape output_tensor_shape;
482
483     output_tensor_shape.rank(input_tensor_shape.rank());
484     for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
485     {
486       output_tensor_shape.dim(axis) = input_tensor_shape.dim(axis).value() +
487                                       const_paddings->at<loco::DataType::S32>(axis * 2) +
488                                       const_paddings->at<loco::DataType::S32>(axis * 2 + 1);
489     }
490
491     return loco::NodeShape{output_tensor_shape};
492   }
493
494   loco::NodeShape visit(const moco::TFPlaceholder *node) final
495   {
496     loco::TensorShape output_tensor_shape;
497
498     uint32_t rank = node->rank();
499     output_tensor_shape.rank(rank);
500     for (uint32_t index = 0; index < rank; ++index)
501     {
502       if (node->dim(index).known())
503         output_tensor_shape.dim(index) = node->dim(index).value();
504       else
505         output_tensor_shape.dim(index).unset();
506     }
507
508     return loco::NodeShape(output_tensor_shape);
509   }
510
511   loco::NodeShape visit(const moco::TFRealDiv *node) final { return binary_node_shape(node); }
512
513   loco::NodeShape visit(const moco::TFRelu *node) final
514   {
515     return node_shape_with_check(node->features());
516   }
517
518   loco::NodeShape visit(const moco::TFRelu6 *node) final
519   {
520     return node_shape_with_check(node->features());
521   }
522
523   loco::NodeShape visit(const moco::TFReshape *node) final
524   {
525     loco::NodeShape unknown;
526
527     // For now, we only consider Fixed Reshape, i.e. Reshape with determined
528     //      'shape' input. So here we only support case when 'shape' input of
529     //      TFReshape is TFConst. If 'shape' input is not TFConst, another
530     //      transform (e.g. constant folding) should be done beforehand to make
531     //      it TFConst.
532     // TODO Support dynamic Reshape
533     // Note that 'shape()' here is 'shape' input, not node's shape information
534     auto const_shape_input = dynamic_cast<moco::TFConst *>(node->shape());
535     if (!const_shape_input)
536     {
537       // 'shape' input of TFReshape is not TFConst, we can not do shape inference
538       return unknown;
539     }
540
541     // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1]
542     assert(const_shape_input->dtype() == loco::DataType::S32);
543     assert(const_shape_input->rank() == 1);
544
545     auto shape_rank = const_shape_input->dim(0).value();
546     assert(shape_rank > 0);
547
548     loco::TensorShape output_shape;
549     output_shape.rank(shape_rank);
550     for (uint32_t axis = 0; axis < shape_rank; ++axis)
551     {
552       auto shape_dim = const_shape_input->at<loco::DataType::S32>(axis);
553       if (shape_dim == -1)
554       {
555         // Reshape's new shape has wildcard dimension, i.e. dynamic reshape
556         return unknown;
557       }
558       assert(shape_dim >= 1);
559       output_shape.dim(axis) = shape_dim;
560     }
561
562     // TODO Compare 'tensor' input and validate coherency?
563     // Not sure this is appropriate stage for this task.
564
565     return loco::NodeShape(output_shape);
566   }
567
568   loco::NodeShape visit(const moco::TFRsqrt *node) final
569   {
570     return node_shape_with_check(node->x());
571   }
572
573   loco::NodeShape visit(const moco::TFShape *node) final
574   {
575     auto input_shape = node_shape(node->input());
576     auto input_tensor_shape = input_shape.as<loco::TensorShape>();
577
578     loco::TensorShape output_shape;
579
580     // Note that input shape becomes node(TFShape)'s value
581     output_shape.rank(1);
582     output_shape.dim(0) = input_tensor_shape.rank();
583
584     return loco::NodeShape(output_shape);
585   }
586
587   loco::NodeShape visit(const moco::TFSoftmax *node) final
588   {
589     return node_shape_with_check(node->logits());
590   }
591
592   loco::NodeShape visit(const moco::TFSqrt *node) final { return node_shape_with_check(node->x()); }
593
594   loco::NodeShape visit(const moco::TFSquaredDifference *node) final
595   {
596     return binary_node_shape(node);
597   }
598
599   loco::NodeShape visit(const moco::TFSqueeze *node) final
600   {
601     auto input_shape = node_shape(node->input());
602
603     // TODO Not sure Squeeze only get input as Tensor
604     // Note that tensor_shape() has assertion in it
605     auto input_tensor_shape = input_shape.as<loco::TensorShape>();
606
607     auto squeeze_dims_vec = node->squeeze_dims();
608     std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
609
610     loco::TensorShape output_shape;
611     uint32_t output_rank = 0;
612
613     if (squeeze_dims.empty())
614     {
615       // Remove all dimensions whose value is 1
616       for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
617       {
618         assert(input_tensor_shape.dim(axis).known());
619         auto dim = input_tensor_shape.dim(axis).value();
620         if (dim != 1)
621         {
622           assert(dim > 1);
623           output_shape.rank(++output_rank);
624           output_shape.dim(output_rank - 1) = dim;
625         }
626       }
627     }
628     else
629     {
630       uint32_t input_rank = input_tensor_shape.rank();
631
632       // Sanity check for 'squeeze_dims'
633       auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
634         if (!(squeeze_dims.size() < input_rank))
635           return false;
636         for (auto squeeze_dim : squeeze_dims)
637         {
638           if (!(squeeze_dim >= -(int64_t)input_rank))
639             return false;
640           if (!(squeeze_dim < (int64_t)input_rank))
641             return false;
642         }
643         return true;
644       };
645
646       if (!is_valid_squeeze_dims())
647       {
648         throw oops::UserExn("Invalid squeeze dimension", node->name());
649       }
650
651       // Resolve negative squeeze dimension
652       std::set<int64_t> resolved_squeeze_dims;
653       for (auto squeeze_dim : squeeze_dims)
654       {
655         if (squeeze_dim < 0)
656           resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
657         else
658           resolved_squeeze_dims.insert(squeeze_dim);
659       }
660
661       // Remove squeeze dimensions only
662       for (uint32_t axis = 0; axis < input_rank; ++axis)
663       {
664         assert(input_tensor_shape.dim(axis).known());
665         auto dim = input_tensor_shape.dim(axis).value();
666         if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
667         {
668           // Not squeeze dim
669           output_shape.rank(++output_rank);
670           output_shape.dim(output_rank - 1) = dim;
671         }
672         else
673         {
674           // Is squeeze dim
675           assert(dim == 1);
676           // DO NOTHING
677         }
678       }
679     }
680
681     assert(output_shape.rank() > 0);
682
683     return loco::NodeShape(output_shape);
684   }
685
686   loco::NodeShape visit(const moco::TFStopGradient *node) final
687   {
688     return node_shape_with_check(node->input());
689   }
690
691   loco::NodeShape visit(const moco::TFStridedSlice *node) final
692   {
693     loco::NodeShape unknown;
694     auto input_shape = node_shape(node->input());
695     if (input_shape.domain() != loco::Domain::Tensor)
696     {
697       // TODO fix this for other cases
698       // We support only tensor shape for now
699       return unknown;
700     }
701
702     // TODO support full mask features: see import codes also
703     // Limited attributes for now
704     assert(node->begin_mask() == 0);
705     assert(node->end_mask() == 0);
706     assert(node->ellipsis_mask() == 0);
707     assert(node->shrink_axis_mask() == 1);
708
709     auto const_begin = loco::must_cast<moco::TFConst *>(node->begin());
710     auto const_end = loco::must_cast<moco::TFConst *>(node->end());
711     auto const_strides = loco::must_cast<moco::TFConst *>(node->strides());
712
713     assert(dynamic_cast<moco::TFConst *>(node->input()) != nullptr);
714     assert(const_begin != nullptr);
715     assert(const_end != nullptr);
716     assert(const_strides != nullptr);
717
718     auto input_tensor_shape = input_shape.as<loco::TensorShape>();
719     auto input_rank = input_tensor_shape.rank();
720     auto output_rank = input_rank;
721
722     // TODO support strides with > 1
723     uint32_t elements = const_strides->size<loco::DataType::S32>();
724     for (uint32_t e = 0; e < elements; ++e)
725       assert(const_strides->at<loco::DataType::S32>(e) == 1);
726
727     // lets apply begin ~ end range from input shape
728     loco::TensorShape output_shape_range;
729
730     output_shape_range.rank(input_rank);
731     for (uint32_t r = 0; r < input_rank; ++r)
732     {
733       // TODO apply begin/end mask
734       // TODO apply ellipsis mask
735       // TODO apply strides
736       auto end = const_end->at<loco::DataType::S32>(r);
737       auto begin = const_begin->at<loco::DataType::S32>(r);
738       auto size = end - begin;
739       output_shape_range.dim(r).set(size);
740     }
741
742     // get final tensor shape from applying shrink mask to output_shape_range
743     loco::TensorShape output_tensor_shape;
744
745     if (node->shrink_axis_mask() != 0)
746     {
747       for (uint32_t rs = 0; rs < input_rank; ++rs)
748       {
749         int32_t bit = 1 << rs;
750         int32_t mask = node->shrink_axis_mask();
751         if (bit & mask)
752         {
753           // shrink one dimension
754           assert(output_rank > 0);
755           output_rank = output_rank - 1;
756         }
757       }
758       output_tensor_shape.rank(output_rank);
759       for (uint32_t rs = 0, rd = 0; rs < input_rank; ++rs)
760       {
761         int32_t bit = 1 << rs;
762         int32_t mask = node->shrink_axis_mask();
763         if ((bit & mask) == 0)
764         {
765           // use this dimension
766           output_tensor_shape.dim(rd).set(output_shape_range.dim(rs).value());
767           rd++;
768         }
769         // else this dimension is shrink-ed
770       }
771     }
772     else
773     {
774       output_tensor_shape = output_shape_range;
775     }
776
777     return loco::NodeShape(output_tensor_shape);
778   }
779
780   loco::NodeShape visit(const moco::TFSub *node) final { return binary_node_shape(node); }
781
782   loco::NodeShape visit(const moco::TFTanh *node) final { return node_shape_with_check(node->x()); }
783
784   // For virtual nodes
785   loco::NodeShape visit(const moco::TFPush *node) { return node_shape_with_check(node->from()); }
786
787 public:
788   loco::NodeShape visit(const moco::TFNode *) final
789   {
790     loco::NodeShape unknown;
791     return unknown;
792   }
793 };
794
795 } // namespace
796
797 namespace
798 {
799 namespace compat
800 {
801
802 struct Context final : public loco::ShapeInferenceRule::Context
803 {
804   bool known(const loco::Node *node) const final { return loco::shape_known(node); }
805   loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); }
806 };
807
808 class Sink final : public loco::ShapeInferenceRule::Sink
809 {
810 public:
811   enum Status
812   {
813     Unknown,
814     Okay,
815     Fail,
816   };
817
818 public:
819   const Status &status(void) const { return _status; }
820   const loco::NodeShape &shape(void) const { return _shape; }
821
822 public:
823   void okay(const loco::NodeShape &shape) final
824   {
825     _status = Okay;
826     _shape = shape;
827   }
828
829   void fail(void) final
830   {
831     // Notify failrue
832     _status = Fail;
833   }
834
835 private:
836   Status _status = Unknown;
837   loco::NodeShape _shape;
838 };
839
840 } // namespace compat
841 } // namespace
842
843 namespace moco
844 {
845
846 bool TFShapeInferenceRule::support(const API &api) const
847 {
848   return api == API::V1 or api == API::V2;
849 }
850
851 bool TFShapeInferenceRule::recognize(const loco::Dialect *d) const
852 {
853   // handle only TensorFlow dialect
854   return TFDialect::get() == d;
855 }
856
857 bool TFShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
858 {
859   ::compat::Context ctx;
860   ::compat::Sink sink;
861
862   infer(&ctx, node, &sink);
863
864   assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
865
866   if (sink.status() == ::compat::Sink::Fail)
867   {
868     return false;
869   }
870
871   shape = sink.shape();
872
873   return true;
874 }
875
876 void TFShapeInferenceRule::infer(const Context *ctx, const loco::Node *node, Sink *sink) const
877 {
878   assert(node->dialect() == TFDialect::get());
879   assert(dynamic_cast<const TFNode *>(node) != nullptr);
880
881   ShapeInferenceAlgorithm alg{ctx};
882   auto shape = loco::must_cast<const TFNode *>(node)->accept(&alg);
883
884   if (shape.domain() == loco::Domain::Unknown)
885     sink->fail();
886   else
887     sink->okay(shape);
888 }
889
890 } // namespace moco